[PyTorch] Autoformat c10 (#56830)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56830

Opt c10 into formatting on GitHub and format everything. This is a trial run before turning on formatting for more of the codebase, and eventually all of it.

Test Plan: CI

Reviewed By: zertosh

Differential Revision: D27979080

fbshipit-source-id: a80f0c48691c08ae8ca0af06377b87e6a2351151
Scott Wolchok 2021-04-30 21:22:23 -07:00 committed by Facebook GitHub Bot
parent 3c4d57c18b
commit 44cc873fba
162 changed files with 17273 additions and 13736 deletions
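Everything below is a mechanical clang-format pass. As a rough illustration of the dominant pattern (a sketch only, assuming the repository's ~80-column limit; mirrored from one of the hunks below rather than quoted from any single file), a declaration that overflows the limit, such as

virtual void reportMemoryUsage(void* ptr, int64_t alloc_size, Device device) = 0;

is re-wrapped with one parameter per line:

virtual void reportMemoryUsage(
    void* ptr,
    int64_t alloc_size,
    Device device) = 0;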

View File

@ -18,7 +18,6 @@ class Foo : public intrusive_ptr_target {
int param;
};
class Bar : public std::enable_shared_from_this<Bar> {
public:
Bar(int param_) : param(param_) {}
@ -48,7 +47,7 @@ BENCHMARK(BM_SharedPtrCtorDtor);
static void BM_IntrusivePtrArray(benchmark::State& state) {
intrusive_ptr<Foo> var = make_intrusive<Foo>(0);
const size_t kLength = state.range(0);
std::vector<intrusive_ptr<Foo> > vararray(kLength);
std::vector<intrusive_ptr<Foo>> vararray(kLength);
while (state.KeepRunning()) {
for (const auto i : c10::irange(kLength)) {
vararray[i] = var;
@ -64,7 +63,7 @@ BENCHMARK(BM_IntrusivePtrArray)->RangeMultiplier(2)->Range(16, 4096);
static void BM_SharedPtrArray(benchmark::State& state) {
std::shared_ptr<Bar> var = std::make_shared<Bar>(0);
const size_t kLength = state.range(0);
std::vector<std::shared_ptr<Bar> > vararray(kLength);
std::vector<std::shared_ptr<Bar>> vararray(kLength);
while (state.KeepRunning()) {
for (const auto i : c10::irange(kLength)) {
vararray[i] = var;
@ -78,5 +77,4 @@ static void BM_SharedPtrArray(benchmark::State& state) {
BENCHMARK(BM_SharedPtrArray)->RangeMultiplier(2)->Range(16, 4096);
} // namespace
BENCHMARK_MAIN();

View File

@ -12,10 +12,11 @@ at::DataPtr InefficientStdFunctionContext::makeDataPtr(
void* ptr,
const std::function<void(void*)>& deleter,
Device device) {
return {ptr,
new InefficientStdFunctionContext({ptr, deleter}),
&deleteInefficientStdFunctionContext,
device};
return {
ptr,
new InefficientStdFunctionContext({ptr, deleter}),
&deleteInefficientStdFunctionContext,
device};
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)

View File

@ -94,7 +94,9 @@ class C10_API DataPtr {
* be; be sure to read the source code of the Allocator
* in question to confirm this.
*/
C10_NODISCARD bool compare_exchange_deleter(DeleterFnPtr expected_deleter, DeleterFnPtr new_deleter) {
C10_NODISCARD bool compare_exchange_deleter(
DeleterFnPtr expected_deleter,
DeleterFnPtr new_deleter) {
return ptr_.compare_exchange_deleter(expected_deleter, new_deleter);
}
Device device() const {
@ -215,8 +217,8 @@ struct AllocatorRegisterer {
}
};
#define REGISTER_ALLOCATOR(t, f) \
namespace { \
#define REGISTER_ALLOCATOR(t, f) \
namespace { \
static AllocatorRegisterer<t> g_allocator_d(f); \
}
@ -227,12 +229,18 @@ struct C10_API MemoryReportingInfoBase : public c10::DebugInfoBase {
virtual ~MemoryReportingInfoBase() {}
// Negative alloc_size corresponds to freeing of the memory
virtual void reportMemoryUsage(void* ptr, int64_t alloc_size, Device device) = 0;
virtual void reportMemoryUsage(
void* ptr,
int64_t alloc_size,
Device device) = 0;
virtual bool memoryProfilingEnabled() const = 0;
};
C10_API bool memoryProfilingEnabled();
C10_API void reportMemoryUsageToProfiler(void* ptr, int64_t alloc_size, Device device);
C10_API void reportMemoryUsageToProfiler(
void* ptr,
int64_t alloc_size,
Device device);
} // namespace c10

View File

@ -253,7 +253,7 @@ static inline bool isSparse(Backend b) {
}
static inline bool isSparseCsr(Backend b) {
switch(b) {
switch (b) {
case Backend::SparseCsrCPU:
case Backend::SparseCsrCUDA:
return true;

View File

@ -29,9 +29,11 @@ namespace c10 {
* }
* EXPECT_EQ(3, execute(TORCH_FN(add), 1, 2));
*/
template<class FuncType_, FuncType_* func_ptr_>
template <class FuncType_, FuncType_* func_ptr_>
struct CompileTimeFunctionPointer final {
static_assert(guts::is_function_type<FuncType_>::value, "TORCH_FN can only wrap function types.");
static_assert(
guts::is_function_type<FuncType_>::value,
"TORCH_FN can only wrap function types.");
using FuncType = FuncType_;
static constexpr FuncType* func_ptr() {
@ -39,11 +41,16 @@ struct CompileTimeFunctionPointer final {
}
};
template<class T> struct is_compile_time_function_pointer : std::false_type {};
template<class FuncType, FuncType* func_ptr>
struct is_compile_time_function_pointer<CompileTimeFunctionPointer<FuncType, func_ptr>> : std::true_type {};
template <class T>
struct is_compile_time_function_pointer : std::false_type {};
template <class FuncType, FuncType* func_ptr>
struct is_compile_time_function_pointer<
CompileTimeFunctionPointer<FuncType, func_ptr>> : std::true_type {};
}
} // namespace c10
#define TORCH_FN_TYPE(func) ::c10::CompileTimeFunctionPointer<std::remove_pointer_t<std::remove_reference_t<decltype(func)>>, func>
#define TORCH_FN_TYPE(func) \
::c10::CompileTimeFunctionPointer< \
std::remove_pointer_t<std::remove_reference_t<decltype(func)>>, \
func>
#define TORCH_FN(func) TORCH_FN_TYPE(func)()
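For illustration, a minimal usage sketch of the macros above (not part of this commit; it relies only on the declarations shown in this hunk, with `add` as a made-up example function matching the header's own doc comment):

int add(int a, int b) {
  return a + b;
}

// TORCH_FN_TYPE(add) is CompileTimeFunctionPointer<int(int, int), add>:
// the function pointer is carried in the type, so it costs nothing at runtime.
using AddFn = TORCH_FN_TYPE(add);
static_assert(
    c10::is_compile_time_function_pointer<AddFn>::value,
    "TORCH_FN_TYPE yields a CompileTimeFunctionPointer");

int call_add() {
  return AddFn::func_ptr()(1, 2); // calls add(1, 2) and returns 3
}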

View File

@ -47,4 +47,4 @@ void CopyBytes(
ptr(nbytes, src, src_device, dst, dst_device);
}
}
} // namespace c10

View File

@ -1,5 +1,5 @@
#include <c10/util/typeid.h>
#include <c10/core/DefaultDtype.h>
#include <c10/util/typeid.h>
namespace c10 {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@ -7,7 +7,8 @@ static auto default_dtype = caffe2::TypeMeta::Make<float>();
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static auto default_dtype_as_scalartype = default_dtype.toScalarType();
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static auto default_complex_dtype = caffe2::TypeMeta::Make<c10::complex<float>>();
static auto default_complex_dtype =
caffe2::TypeMeta::Make<c10::complex<float>>();
void set_default_dtype(caffe2::TypeMeta dtype) {
default_dtype = dtype;

View File

@ -1,7 +1,7 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/core/ScalarType.h>
#include <c10/macros/Macros.h>
namespace caffe2 {
class TypeMeta;

View File

@ -13,19 +13,27 @@ struct TensorOptions;
struct DefaultTensorOptions {
DefaultTensorOptions() = default;
caffe2::TypeMeta dtype() const noexcept { return dtype_; }
Device device() const noexcept { return device_; }
Layout layout() const noexcept { return layout_; }
bool requires_grad() const noexcept { return requires_grad_; }
caffe2::TypeMeta dtype() const noexcept {
return dtype_;
}
Device device() const noexcept {
return device_;
}
Layout layout() const noexcept {
return layout_;
}
bool requires_grad() const noexcept {
return requires_grad_;
}
// Defined in TensorOptions.h
inline DefaultTensorOptions& merge(const TensorOptions& options);
private:
caffe2::TypeMeta dtype_ = caffe2::TypeMeta::Make<float>(); // 64-bit
Device device_ = at::kCPU; // 32-bit
Layout layout_ = at::kStrided; // 8-bit
bool requires_grad_ = false; // 8-bit
Device device_ = at::kCPU; // 32-bit
Layout layout_ = at::kStrided; // 8-bit
bool requires_grad_ = false; // 8-bit
};
inline const DefaultTensorOptions& getDefaultTensorOptions() {

View File

@ -6,25 +6,24 @@
#include <array>
#include <exception>
#include <ostream>
#include <regex>
#include <string>
#include <tuple>
#include <vector>
#include <regex>
// Check if compiler has working std::regex implementation
//
// Test below is adapted from https://stackoverflow.com/a/41186162
#if defined(_MSVC_LANG) && _MSVC_LANG >= 201103L
// Compiler has working regex. MSVC has erroneous __cplusplus.
// Compiler has working regex. MSVC has erroneous __cplusplus.
#elif __cplusplus >= 201103L && \
(!defined(__GLIBCXX__) || (__cplusplus >= 201402L) || \
(defined(_GLIBCXX_REGEX_DFS_QUANTIFIERS_LIMIT) || \
defined(_GLIBCXX_REGEX_STATE_LIMIT) || \
(defined(_GLIBCXX_RELEASE) && \
_GLIBCXX_RELEASE > 4)))
// Compiler has working regex.
(defined(_GLIBCXX_REGEX_DFS_QUANTIFIERS_LIMIT) || \
defined(_GLIBCXX_REGEX_STATE_LIMIT) || \
(defined(_GLIBCXX_RELEASE) && _GLIBCXX_RELEASE > 4)))
// Compiler has working regex.
#else
static_assert(false, "Compiler does not have proper regex support.");
static_assert(false, "Compiler does not have proper regex support.");
#endif
namespace c10 {
@ -58,7 +57,8 @@ DeviceType parse_type(const std::string& device_string) {
if (device != types.end()) {
return device->second;
}
TORCH_CHECK(false,
TORCH_CHECK(
false,
"Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, mlc, xla, vulkan, meta device type at start of device string: ",
device_string);
}
@ -71,16 +71,22 @@ Device::Device(const std::string& device_string) : Device(Type::CPU) {
static const std::regex regex("([a-zA-Z_]+)(?::([1-9]\\d*|0))?");
std::smatch match;
TORCH_CHECK(
std::regex_match(device_string, match, regex),
"Invalid device string: '", device_string, "'");
std::regex_match(device_string, match, regex),
"Invalid device string: '",
device_string,
"'");
type_ = parse_type(match[1].str());
if (match[2].matched) {
try {
index_ = c10::stoi(match[2].str());
} catch (const std::exception &) {
TORCH_CHECK(false,
"Could not parse device index '", match[2].str(),
"' in device string '", device_string, "'");
} catch (const std::exception&) {
TORCH_CHECK(
false,
"Could not parse device index '",
match[2].str(),
"' in device string '",
device_string,
"'");
}
}
validate();
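For illustration, a small sketch of the parsing above (not part of this commit; the `type()`/`index()` accessors are assumed from Device.h):

void device_string_example() {
  c10::Device d("cuda:1");
  // parse_type() maps "cuda" to DeviceType::CUDA; the regex captures index 1.
  TORCH_INTERNAL_ASSERT(d.type() == c10::DeviceType::CUDA);
  TORCH_INTERNAL_ASSERT(d.index() == 1);
  // "cuda:01" does not match ([1-9]\d*|0), so the TORCH_CHECK above would throw.
}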

View File

@ -107,16 +107,18 @@ struct C10_API Device final {
// performance in micro-benchmarks.
// This is safe to do, because backends that use the DeviceIndex
// have a later check when we actually try to switch to that device.
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(index_ == -1 || index_ >= 0,
"Device index must be -1 or non-negative, got ", (int)index_);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!is_cpu() || index_ <= 0,
"CPU device index must be -1 or zero, got ", (int)index_);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
index_ == -1 || index_ >= 0,
"Device index must be -1 or non-negative, got ",
(int)index_);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!is_cpu() || index_ <= 0,
"CPU device index must be -1 or zero, got ",
(int)index_);
}
};
C10_API std::ostream& operator<<(
std::ostream& stream,
const Device& device);
C10_API std::ostream& operator<<(std::ostream& stream, const Device& device);
} // namespace c10
@ -136,10 +138,11 @@ struct hash<c10::Device> {
// half of the resulting integer.
//
// Technically, by C/C++ integer promotion rules, we only need one of the
// uint32_t casts to the result type, but we put in both for explicitness's sake.
uint32_t bits =
static_cast<uint32_t>(static_cast<uint8_t>(d.type())) << 16
| static_cast<uint32_t>(static_cast<uint8_t>(d.index()));
// uint32_t casts to the result type, but we put in both for explicitness's
// sake.
uint32_t bits = static_cast<uint32_t>(static_cast<uint8_t>(d.type()))
<< 16 |
static_cast<uint32_t>(static_cast<uint8_t>(d.index()));
return std::hash<uint32_t>{}(bits);
}
};
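A worked instance of the packing above (illustrative arithmetic only; assumes DeviceType::CUDA == 1, per DeviceType.h):

// For Device(DeviceType::CUDA, 3): type() == 1 and index() == 3, so
//   bits = (uint32_t(1) << 16) | uint32_t(3) = 0x00010003 = 65539
// and the hash is std::hash<uint32_t>{}(65539).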

View File

@ -17,7 +17,7 @@ namespace c10 {
/// want to setup a guard (i.e., are looking for the moral equivalent
/// of optional<DeviceGuard>), see OptionalDeviceGuard.
class DeviceGuard {
public:
public:
/// No default constructor; see Note [Omitted default constructor from RAII]
explicit DeviceGuard() = delete;
@ -25,7 +25,10 @@ public:
explicit DeviceGuard(Device device) : guard_(device) {}
/// This constructor is for testing only.
explicit DeviceGuard(Device device, const impl::DeviceGuardImplInterface* impl) : guard_(device, impl) {}
explicit DeviceGuard(
Device device,
const impl::DeviceGuardImplInterface* impl)
: guard_(device, impl) {}
/// Copy is disallowed
DeviceGuard(const DeviceGuard&) = delete;
@ -48,7 +51,9 @@ public:
}
/// This method is for testing only.
void reset_device(at::Device device, const impl::DeviceGuardImplInterface* impl) {
void reset_device(
at::Device device,
const impl::DeviceGuardImplInterface* impl) {
guard_.reset_device(device, impl);
}
@ -69,7 +74,7 @@ public:
return guard_.current_device();
}
private:
private:
impl::InlineDeviceGuard<impl::VirtualGuardImpl> guard_;
};
@ -79,8 +84,8 @@ private:
* Morally, a OptionalDeviceGuard is equivalent to optional<DeviceGuard>, but
* with extra constructors and methods as appropriate.
*
* Besides its obvious use (optionally applying a DeviceGuard), OptionalDeviceGuard
* is often also used for the following idiom:
* Besides its obvious use (optionally applying a DeviceGuard),
* OptionalDeviceGuard is often also used for the following idiom:
*
* OptionalDeviceGuard g;
* for (const auto& t : tensors) {
@ -117,7 +122,7 @@ private:
* DeviceGuard will still reset the device to original_device_.
*/
class OptionalDeviceGuard {
public:
public:
/// Create an uninitialized guard. Set the guard later using reset_device.
explicit OptionalDeviceGuard() : guard_() {}
@ -129,7 +134,10 @@ public:
explicit OptionalDeviceGuard(optional<Device> device) : guard_(device) {}
/// Constructor for testing only.
explicit OptionalDeviceGuard(Device device, const impl::DeviceGuardImplInterface* impl) : guard_(device, impl) {}
explicit OptionalDeviceGuard(
Device device,
const impl::DeviceGuardImplInterface* impl)
: guard_(device, impl) {}
/// Copy is disallowed
OptionalDeviceGuard(const OptionalDeviceGuard&) = delete;
@ -149,7 +157,9 @@ public:
}
/// For testing only
void reset_device(at::Device device, const impl::DeviceGuardImplInterface* impl) {
void reset_device(
at::Device device,
const impl::DeviceGuardImplInterface* impl) {
guard_.reset_device(device, impl);
}
@ -164,7 +174,7 @@ public:
return guard_.current_device();
}
private:
private:
impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl> guard_;
};
@ -173,7 +183,8 @@ private:
// Design note: in principle, we could avoid these wrappers using:
//
// using DeviceGuard = impl::InlineDeviceGuard<impl::VirtualGuardImpl>;
// using OptionalDeviceGuard = impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl>;
// using OptionalDeviceGuard =
// impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl>;
//
// But the error messages are worse, and our users can't just look at the
// header file to find out what's going on. Furthermore, for specializations
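For illustration, a minimal RAII sketch (not part of this commit; it uses only the constructor shown above):

void run_on(c10::Device dst) {
  // Switch the current device to `dst`; the original device is restored when
  // `guard` is destroyed, even if an exception is thrown.
  c10::DeviceGuard guard(dst);
  // ... launch work on `dst` here ...
}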

View File

@ -38,7 +38,8 @@ std::string DeviceTypeName(DeviceType d, bool lower_case) {
case DeviceType::Meta:
return lower_case ? "meta" : "META";
default:
TORCH_CHECK(false,
TORCH_CHECK(
false,
"Unknown device: ",
static_cast<int16_t>(d),
". If you have recently updated the caffe2.proto file to add a new "

View File

@ -7,8 +7,8 @@
#include <c10/macros/Macros.h>
#include <ostream>
#include <functional>
#include <ostream>
namespace c10 {
@ -51,7 +51,8 @@ constexpr DeviceType kXPU = DeviceType::XPU;
constexpr int COMPILE_TIME_MAX_DEVICE_TYPES =
static_cast<int>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);
static_assert(COMPILE_TIME_MAX_DEVICE_TYPES <= 16,
static_assert(
COMPILE_TIME_MAX_DEVICE_TYPES <= 16,
"Hey! You seem to be adding a lot of new DeviceTypes. The intent was "
"for this constant to reflect the actual number of DeviceTypes we support "
"in PyTorch; it's important that this number is not too large as we "
@ -61,9 +62,7 @@ static_assert(COMPILE_TIME_MAX_DEVICE_TYPES <= 16,
"types registration, please be aware that you are affecting code that "
"this number is small. Try auditing uses of this constant.");
C10_API std::string DeviceTypeName(
DeviceType d,
bool lower_case = false);
C10_API std::string DeviceTypeName(DeviceType d, bool lower_case = false);
C10_API bool isValidDeviceType(DeviceType d);
@ -72,7 +71,8 @@ C10_API std::ostream& operator<<(std::ostream& stream, DeviceType type);
} // namespace c10
namespace std {
template <> struct hash<c10::DeviceType> {
template <>
struct hash<c10::DeviceType> {
std::size_t operator()(c10::DeviceType k) const {
return std::hash<int>()(static_cast<int>(k));
}
@ -80,5 +80,5 @@ template <> struct hash<c10::DeviceType> {
} // namespace std
namespace torch {
using c10::DeviceType;
using c10::DeviceType;
}

View File

@ -1,11 +1,11 @@
#pragma once
#include <vector>
#include <iostream>
#include <string>
#include <c10/macros/Macros.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <iostream>
#include <string>
#include <vector>
namespace c10 {
@ -58,7 +58,7 @@ enum class DispatchKey : uint8_t {
HIP, // NB: I think this is not actually used, due to Note [Masquerading as
// CUDA]
FPGA, // Xilinx support lives out of tree at
// https://gitlab.com/pytorch-complex/vitis_kernels
// https://gitlab.com/pytorch-complex/vitis_kernels
MSNPU, // unused externally, but tested at
// test/cpp_extensions/msnpu_extension.cpp
XLA, // lives out of tree at https://github.com/pytorch/xla
@ -177,7 +177,8 @@ enum class DispatchKey : uint8_t {
// But this work is currently blocked since it adds an extra dispatch
// for all ops and it's non-trivial overhead at model level(a few percents).
// Thus our current approach takes advantage of the fact every kernel go
// through VariableType kernel first and pulls the `at::AutoDispatchBelowInplaceOrView` guard of functional ops
// through VariableType kernel first and pulls the
// `at::AutoDispatchBelowInplaceOrView` guard of functional ops
// up to the `VariableType` kernel. Thus we only add the extra dispatch
// to view/inplace ops to minimize its perf impact to real models.
InplaceOrView,
@ -213,7 +214,8 @@ enum class DispatchKey : uint8_t {
AutogradXLA,
AutogradXPU,
AutogradMLC,
AutogradNestedTensor, // lives out of tree at https://github.com/pytorch/nestedtensor
AutogradNestedTensor, // lives out of tree at
// https://github.com/pytorch/nestedtensor
// Here are some reserved pre-autograd keys for user-defined backends, see
// Note [Private use DispatchKey]
AutogradPrivateUse1,
@ -224,7 +226,7 @@ enum class DispatchKey : uint8_t {
// Autocasting precedes VariableTypeId, to ensure casts are autograd-exposed
// and inputs are saved for backward in the post-autocast type.
//AutocastCPU,
// AutocastCPU,
AutocastCUDA,
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~ WRAPPERS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
@ -278,9 +280,10 @@ enum class DispatchKey : uint8_t {
// See Note [Alias Dispatch Key : Autograd]
Autograd,
CompositeImplicitAutograd, // registered at build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp
CompositeImplicitAutograd, // registered at
// build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp
CompositeExplicitAutograd, // registered at
// build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp
// build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp
// Define an alias key to represent end of alias dispatch keys.
// If you add new alias keys after Autograd, please also update it here.
@ -316,15 +319,15 @@ enum class DispatchKey : uint8_t {
// We provide two classes of private user tensor id: regular DispatchKeys
// and Autograd DispatchKeys. DispatchKeys serve the role of ordinary "backend"
// DispatchKeys; if you were adding support for a new type of accelerator, you
// would use a backend DispatchKey, and ideally automatically reuse AutogradOther
// definitions already defined in PyTorch. AutogradPrivateUse DispatchKeys serve
// as "wrapper" DispatchKeys: they are only necessary for tensors that compose
// multiple internal tensors, and for cases when the built-in autograd formulas
// for operators are not appropriate.
// would use a backend DispatchKey, and ideally automatically reuse
// AutogradOther definitions already defined in PyTorch. AutogradPrivateUse
// DispatchKeys serve as "wrapper" DispatchKeys: they are only necessary for
// tensors that compose multiple internal tensors, and for cases when the
// built-in autograd formulas for operators are not appropriate.
static_assert(
static_cast<uint8_t>(DispatchKey::NumDispatchKeys) < 64,
"DispatchKey is used as index into 64-bit bitmask; you must have less than 64 entries");
static_cast<uint8_t>(DispatchKey::NumDispatchKeys) < 64,
"DispatchKey is used as index into 64-bit bitmask; you must have less than 64 entries");
C10_API const char* toString(DispatchKey);
C10_API std::ostream& operator<<(std::ostream&, DispatchKey);
@ -345,10 +348,10 @@ inline bool isAliasDispatchKey(DispatchKey k) {
} // namespace c10
namespace torch {
// Expose the constant, but not the TYPE (DispatchKey is an implementation
// detail!)
using c10::kAutograd;
}
// Expose the constant, but not the TYPE (DispatchKey is an implementation
// detail!)
using c10::kAutograd;
} // namespace torch
// NB: You really shouldn't use this instance; this enum is guaranteed
// to be pretty small so a regular array should be acceptable.
@ -362,4 +365,4 @@ struct hash<c10::DispatchKey> {
return static_cast<size_t>(x);
}
};
}
} // namespace std

View File

@ -3,9 +3,10 @@
namespace c10 {
// backend_dispatch_keyset should include all runtime backend keys.
// Alias key DispatchKey::CompositeExplicitAutograd maps to backend_dispatch_keyset
// NestedTensor has been explicitly removed due to incompatibility with some
// kernels, such as structured kernels, that use the DefaultBackend key.
// Alias key DispatchKey::CompositeExplicitAutograd maps to
// backend_dispatch_keyset NestedTensor has been explicitly removed due to
// incompatibility with some kernels, such as structured kernels, that use the
// DefaultBackend key.
constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends |
DispatchKeySet({
DispatchKey::CPU,
@ -23,9 +24,11 @@ bool isBackendDispatchKey(DispatchKey t) {
return t != DispatchKey::Undefined && backend_dispatch_keyset.has(t);
}
// math_dispatch_keyset contains all keys in backend_dispatch_keyset and autograd_dispatch_keyset
// Alias key DispatchKey::CompositeImplicitAutograd maps to math_dispatch_keyset.
constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset | autograd_dispatch_keyset;
// math_dispatch_keyset contains all keys in backend_dispatch_keyset and
// autograd_dispatch_keyset Alias key DispatchKey::CompositeImplicitAutograd
// maps to math_dispatch_keyset.
constexpr DispatchKeySet math_dispatch_keyset =
backend_dispatch_keyset | autograd_dispatch_keyset;
DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
@ -41,8 +44,8 @@ DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
}
}
// for a given autograd key, return the (guaranteed nonempty) set of associated backend keys.
// for a non-autograd key, return the empty keyset.
// for a given autograd key, return the (guaranteed nonempty) set of associated
// backend keys. for a non-autograd key, return the empty keyset.
DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
switch (t) {
case DispatchKey::AutogradCPU:
@ -72,7 +75,7 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
DispatchKeySet getAutocastRelatedKeySetFromBackend(DispatchKey t) {
switch (t) {
//case DispatchKey::CPU:
// case DispatchKey::CPU:
// return DispatchKeySet(DispatchKey::AutocastCPU);
case DispatchKey::CUDA:
return DispatchKeySet(DispatchKey::AutocastCUDA);
@ -82,8 +85,8 @@ DispatchKeySet getAutocastRelatedKeySetFromBackend(DispatchKey t) {
}
DispatchKeySet getAutogradRelatedKeySetFromBackend(DispatchKey t) {
return DispatchKeySet({
DispatchKey::InplaceOrView, getAutogradKeyFromBackend(t)});
return DispatchKeySet(
{DispatchKey::InplaceOrView, getAutogradKeyFromBackend(t)});
}
bool isIncludedInAlias(DispatchKey k, DispatchKey alias) {
@ -116,4 +119,4 @@ std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) {
return os;
}
}
} // namespace c10

View File

@ -1,9 +1,9 @@
#pragma once
#include <c10/core/DispatchKey.h>
#include <c10/util/llvmMathExtras.h>
#include <c10/util/Exception.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/llvmMathExtras.h>
#include <ostream>
namespace c10 {
@ -33,30 +33,29 @@ namespace c10 {
//
// An undefined tensor is one with an empty tensor type set.
class DispatchKeySet final {
public:
public:
enum Full { FULL };
enum FullAfter { FULL_AFTER };
enum Raw { RAW };
// NB: default constructor representation as zero is MANDATORY as
// use of DispatchKeySet in TLS requires this.
constexpr DispatchKeySet()
: repr_(0) {}
constexpr DispatchKeySet() : repr_(0) {}
constexpr DispatchKeySet(Full)
: repr_(std::numeric_limits<decltype(repr_)>::max()) {}
: repr_(std::numeric_limits<decltype(repr_)>::max()) {}
constexpr DispatchKeySet(FullAfter, DispatchKey t)
// LSB after t are OK, but not t itself.
: repr_((1ULL << (static_cast<uint8_t>(t) - 1)) - 1) {}
// LSB after t are OK, but not t itself.
: repr_((1ULL << (static_cast<uint8_t>(t) - 1)) - 1) {}
// Public version of DispatchKeySet(uint64_t) API; external users
// must be explicit when they do this!
constexpr DispatchKeySet(Raw, uint64_t x)
: repr_(x) {}
constexpr DispatchKeySet(Raw, uint64_t x) : repr_(x) {}
explicit constexpr DispatchKeySet(DispatchKey t)
: repr_(t == DispatchKey::Undefined
? 0
: 1ULL << (static_cast<uint8_t>(t) - 1)) {}
: repr_(
t == DispatchKey::Undefined
? 0
: 1ULL << (static_cast<uint8_t>(t) - 1)) {}
explicit constexpr DispatchKeySet(std::initializer_list<DispatchKey> ks)
: repr_(0) {
: repr_(0) {
for (auto k : ks) {
repr_ |= DispatchKeySet(k).repr_;
}
@ -105,7 +104,9 @@ public:
bool empty() const {
return repr_ == 0;
}
uint64_t raw_repr() { return repr_; }
uint64_t raw_repr() {
return repr_;
}
// Return the type id in this set with the highest priority (i.e.,
// is the largest in the DispatchKey enum). Intuitively, this
// type id is the one that should handle dispatch (assuming there
@ -119,18 +120,20 @@ public:
}
DispatchKey highestPriorityBackendTypeId() const {
return (*this & ((1ULL << static_cast<uint8_t>(DispatchKey::EndOfBackendKeys)) - 1))
.highestPriorityTypeId();
return (*this &
((1ULL << static_cast<uint8_t>(DispatchKey::EndOfBackendKeys)) - 1))
.highestPriorityTypeId();
}
private:
private:
constexpr DispatchKeySet(uint64_t repr) : repr_(repr) {}
uint64_t repr_ = 0;
public:
public:
// STL iterator for DispatchKeySet. Iterates through all DispatchKeys in the
// set. The iterator is only invalidated by the destruction of the underlying
// DispatchKeySet as the iterator stores a pointer to the raw representation of
// the DispatchKeySet.
// DispatchKeySet as the iterator stores a pointer to the raw representation
// of the DispatchKeySet.
class iterator {
public:
using self_type = iterator;
@ -138,13 +141,15 @@ public:
using value_type = DispatchKey;
using difference_type = ptrdiff_t;
explicit iterator(const uint64_t *data_ptr, uint8_t i=0) : data_ptr_(data_ptr), i_(i) {
explicit iterator(const uint64_t* data_ptr, uint8_t i = 0)
: data_ptr_(data_ptr), i_(i) {
// Go to the first key in the set
++(*this);
}
self_type& operator++() {
TORCH_INTERNAL_ASSERT(i_ <= static_cast<uint8_t>(DispatchKey::NumDispatchKeys));
TORCH_INTERNAL_ASSERT(
i_ <= static_cast<uint8_t>(DispatchKey::NumDispatchKeys));
// Create a masked version of the set representation to ignore previous
// keys that we've iterated through.
@ -153,7 +158,7 @@ public:
// If there are no keys, set to end iterator value
if (firstKeyIndex == std::numeric_limits<uint64_t>::max() ||
i_ == static_cast<uint8_t>(DispatchKey::NumDispatchKeys)) {
i_ == static_cast<uint8_t>(DispatchKey::NumDispatchKeys)) {
i_ = static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
return *this;
}
@ -163,29 +168,38 @@ public:
}
self_type operator++(int) {
self_type previous_iterator = *this;
self_type previous_iterator = *this;
++(*this);
return previous_iterator;
}
bool operator==(const self_type& rhs) const { return i_ == rhs.i_; }
bool operator!=(const self_type& rhs) const { return i_ != rhs.i_; }
DispatchKey operator*() const { return static_cast<DispatchKey> (i_); }
bool operator==(const self_type& rhs) const {
return i_ == rhs.i_;
}
bool operator!=(const self_type& rhs) const {
return i_ != rhs.i_;
}
DispatchKey operator*() const {
return static_cast<DispatchKey>(i_);
}
private:
const uint64_t *data_ptr_;
const uint64_t* data_ptr_;
uint8_t i_;
};
public:
// Returns iterator to the first key in the set. If no keys are in the
// set, then will return the end iterator.
iterator begin() const { return iterator(&repr_); }
iterator begin() const {
return iterator(&repr_);
}
// We do not need to iterate beyond NumDispatchKeys so we will treat this as
// the end iterator. NumDispatchKeys will always be strictly less than 64.
iterator end() const { return iterator(&repr_, static_cast<uint8_t>(DispatchKey::NumDispatchKeys)); }
iterator end() const {
return iterator(&repr_, static_cast<uint8_t>(DispatchKey::NumDispatchKeys));
}
};
C10_API std::string toString(DispatchKeySet);
@ -208,7 +222,7 @@ constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({
});
constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({
//DispatchKey::AutocastCPU,
// DispatchKey::AutocastCPU,
DispatchKey::AutocastCUDA,
});
@ -224,40 +238,36 @@ constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
});
constexpr DispatchKeySet autograd_dispatch_keyset_with_InplaceOrView =
autograd_dispatch_keyset | DispatchKeySet(DispatchKey::InplaceOrView);
autograd_dispatch_keyset | DispatchKeySet(DispatchKey::InplaceOrView);
// backend dispatch keys that map to DispatchKey::AutogradOther
// NB: keys in this set also get associated with CompositeImplicitAutograd
constexpr DispatchKeySet autogradother_backends = DispatchKeySet({
DispatchKey::HIP,
DispatchKey::FPGA,
DispatchKey::MSNPU,
DispatchKey::Vulkan,
DispatchKey::Metal,
DispatchKey::QuantizedCPU,
DispatchKey::QuantizedCUDA,
DispatchKey::CustomRNGKeyId,
DispatchKey::MkldnnCPU,
DispatchKey::SparseCPU,
DispatchKey::SparseCUDA,
DispatchKey::SparseHIP,
DispatchKey::SparseCsrCPU,
DispatchKey::SparseCsrCUDA,
DispatchKey::Meta
});
constexpr DispatchKeySet autogradother_backends = DispatchKeySet(
{DispatchKey::HIP,
DispatchKey::FPGA,
DispatchKey::MSNPU,
DispatchKey::Vulkan,
DispatchKey::Metal,
DispatchKey::QuantizedCPU,
DispatchKey::QuantizedCUDA,
DispatchKey::CustomRNGKeyId,
DispatchKey::MkldnnCPU,
DispatchKey::SparseCPU,
DispatchKey::SparseCUDA,
DispatchKey::SparseHIP,
DispatchKey::SparseCsrCPU,
DispatchKey::SparseCsrCUDA,
DispatchKey::Meta});
// The set of dispatch keys that come after autograd
// n.b. this relies on the fact that AutogradOther is currently the lowest Autograd key
constexpr DispatchKeySet after_autograd_keyset = DispatchKeySet(
DispatchKeySet::FULL_AFTER,
c10::DispatchKey::AutogradOther
);
// n.b. this relies on the fact that AutogradOther is currently the lowest
// Autograd key
constexpr DispatchKeySet after_autograd_keyset =
DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::AutogradOther);
// The set of dispatch keys that come after InplaceOrView
constexpr DispatchKeySet after_InplaceOrView_keyset = DispatchKeySet(
DispatchKeySet::FULL_AFTER,
c10::DispatchKey::InplaceOrView
);
constexpr DispatchKeySet after_InplaceOrView_keyset =
DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::InplaceOrView);
// true if t is a backend dispatch key
C10_API bool isBackendDispatchKey(DispatchKey t);
@ -265,8 +275,8 @@ C10_API bool isBackendDispatchKey(DispatchKey t);
// Resolve alias dispatch key to DispatchKeySet if applicable
C10_API DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t);
// Returns a DispatchKeySet of all backend keys mapped to Autograd dispatch key t,
// DispatchKeySet is empty if t is not alias of DispatchKey::Autograd.
// Returns a DispatchKeySet of all backend keys mapped to Autograd dispatch key
// t, DispatchKeySet is empty if t is not alias of DispatchKey::Autograd.
C10_API DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t);
// Returns a DispatchKeySet of autograd related keys mapped to backend.
@ -291,27 +301,32 @@ static inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) {
// top of existing "backend" keys like CPU/CUDA, you need to add it
// here. At the moment, autograd keys and InplaceOrView key need this
// treatment;
return (s - autograd_dispatch_keyset_with_InplaceOrView - autocast_dispatch_keyset).highestPriorityTypeId();
return (s - autograd_dispatch_keyset_with_InplaceOrView -
autocast_dispatch_keyset)
.highestPriorityTypeId();
}
template<class T>
template <class T>
using is_not_DispatchKeySet = guts::negation<std::is_same<DispatchKeySet, T>>;
// Given a function type, constructs a function_traits type that drops the first parameter
// type if the first parameter is of type DispatchKeySet.
// NB: DispatchKeySet is currently explicitly hidden from JIT (mainly to avoid pushing unnecessary
// arguments on the stack - see Note [ Plumbing Keys Through the Dispatcher] for details).
// If at any point in the future we need to expose this type to JIT, revisit the usage of this type alias.
// Given a function type, constructs a function_traits type that drops the first
// parameter type if the first parameter is of type DispatchKeySet. NB:
// DispatchKeySet is currently explicitly hidden from JIT (mainly to avoid
// pushing unnecessary arguments on the stack - see Note [ Plumbing Keys Through
// the Dispatcher] for details). If at any point in the future we need to expose
// this type to JIT, revisit the usage of this type alias.
template <class FuncType>
using remove_DispatchKeySet_arg_from_func = guts::make_function_traits_t<
typename guts::infer_function_traits_t<FuncType>::return_type,
typename std::conditional_t<
std::is_same<
DispatchKeySet,
typename guts::typelist::head_with_default_t<void, typename guts::infer_function_traits_t<FuncType>::parameter_types>
>::value,
guts::typelist::drop_if_nonempty_t<typename guts::infer_function_traits_t<FuncType>::parameter_types, 1>,
typename guts::infer_function_traits_t<FuncType>::parameter_types
>
>;
}
typename guts::infer_function_traits_t<FuncType>::return_type,
typename std::conditional_t<
std::is_same<
DispatchKeySet,
typename guts::typelist::head_with_default_t<
void,
typename guts::infer_function_traits_t<
FuncType>::parameter_types>>::value,
guts::typelist::drop_if_nonempty_t<
typename guts::infer_function_traits_t<FuncType>::parameter_types,
1>,
typename guts::infer_function_traits_t<FuncType>::parameter_types>>;
} // namespace c10
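For illustration, a small sketch of the bitset semantics above (not part of this commit; `has()` is assumed from the parts of DispatchKeySet omitted in this hunk):

void dispatch_key_set_example() {
  // One bit per DispatchKey, built from the initializer-list constructor above.
  constexpr c10::DispatchKeySet ks(
      {c10::DispatchKey::CPU, c10::DispatchKey::AutogradCPU});
  TORCH_INTERNAL_ASSERT(ks.has(c10::DispatchKey::CPU));
  // Dispatch resolves to the highest-priority key, i.e. the largest enum value present.
  TORCH_INTERNAL_ASSERT(
      ks.highestPriorityTypeId() == c10::DispatchKey::AutogradCPU);
}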

View File

@ -41,18 +41,18 @@ struct Event final {
// Constructors
Event() = delete;
Event(
const DeviceType _device_type,
const EventFlag _flag = EventFlag::PYTORCH_DEFAULT)
: impl_{_device_type, _flag} { }
const DeviceType _device_type,
const EventFlag _flag = EventFlag::PYTORCH_DEFAULT)
: impl_{_device_type, _flag} {}
// Copy constructor and copy assignment operator (deleted)
Event(const Event&) = delete;
Event& operator=(const Event&) = delete;
// Move constructor and move assignment operator
Event(Event&& other) : impl_{std::move(other.impl_)} { }
Event(Event&& other) : impl_{std::move(other.impl_)} {}
Event& operator=(Event&& other) {
impl_.swap(std::move(other.impl_));
impl_.swap(std::move(other.impl_));
return *this;
}
@ -60,54 +60,62 @@ struct Event final {
~Event() = default;
// Getters
DeviceType device_type() const noexcept { return impl_.device_type(); }
DeviceIndex device_index() const noexcept { return impl_.device_index(); }
EventFlag flag() const noexcept { return impl_.flag(); }
bool was_marked_for_recording() const noexcept { return impl_.was_marked_for_recording(); }
DeviceType device_type() const noexcept {
return impl_.device_type();
}
DeviceIndex device_index() const noexcept {
return impl_.device_index();
}
EventFlag flag() const noexcept {
return impl_.flag();
}
bool was_marked_for_recording() const noexcept {
return impl_.was_marked_for_recording();
}
/**
* Calls record() if and only if record() has never been called for this event.
* Note: because Event is not thread-safe recordOnce() may call record()
* multiple times if called from multiple threads.
*/
/**
* Calls record() if and only if record() has never been called for this
* event. Note: because Event is not thread-safe recordOnce() may call
* record() multiple times if called from multiple threads.
*/
void recordOnce(const Stream& stream) {
impl_.recordOnce(stream);
}
/**
* Increments the event's version and enqueues a job with this version
* in the stream's work queue. When the stream process that job
* it nofifies all streams waiting on / blocked by that version of the
* event to continue and marks that version as recorded.
* */
/**
* Increments the event's version and enqueues a job with this version
* in the stream's work queue. When the stream process that job
* it nofifies all streams waiting on / blocked by that version of the
* event to continue and marks that version as recorded.
* */
void record(const Stream& stream) {
impl_.record(stream);
}
/**
* Does nothing if the event has not been scheduled to be recorded.
* If the event was previously enqueued to be recorded, a command
* to wait for the version of the event that exists at the time of this call
* is inserted in the stream's work queue.
* When the stream reaches this command it will stop processing
* additional commands until that version of the event is marked as recorded.
*/
/**
* Does nothing if the event has not been scheduled to be recorded.
* If the event was previously enqueued to be recorded, a command
* to wait for the version of the event that exists at the time of this call
* is inserted in the stream's work queue.
* When the stream reaches this command it will stop processing
* additional commands until that version of the event is marked as recorded.
*/
void block(const Stream& stream) const {
impl_.block(stream);
}
/**
* Returns true if (and only if)
* (1) the event has never been scheduled to be recorded
* (2) the current version is marked as recorded.
* Returns false otherwise.
*/
/**
* Returns true if (and only if)
* (1) the event has never been scheduled to be recorded
* (2) the current version is marked as recorded.
* Returns false otherwise.
*/
bool query() const {
return impl_.query();
}
private:
private:
impl::InlineEvent<impl::VirtualGuardImpl> impl_;
};
} // c10
} // namespace c10
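For illustration, a sketch of the record/block/query protocol documented above (not part of this commit; Stream::device_type() is assumed):

void event_example(const c10::Stream& producer, const c10::Stream& consumer) {
  c10::Event ev(producer.device_type());
  ev.record(producer);    // enqueue a new version of the event on `producer`
  ev.block(consumer);     // `consumer` stalls until that version is marked recorded
  bool done = ev.query(); // non-blocking: true once the recorded version is complete
  (void)done;
}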

View File

@ -13,7 +13,7 @@ namespace c10 {
* GeneratorImpl class implementation
*/
GeneratorImpl::GeneratorImpl(Device device_in, DispatchKeySet key_set)
: device_{device_in}, key_set_(key_set) {}
: device_{device_in}, key_set_(key_set) {}
/**
* Clone this generator. Note that clone() is the only
@ -40,14 +40,15 @@ namespace detail {
* FIXME: use std::random_device with entropy information
*/
#ifndef _WIN32
static uint64_t readURandomLong()
{
static uint64_t readURandomLong() {
int randDev = open("/dev/urandom", O_RDONLY);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint64_t randValue;
TORCH_CHECK(randDev >= 0, "Unable to open /dev/urandom");
ssize_t readBytes = read(randDev, &randValue, sizeof(randValue));
TORCH_CHECK(readBytes >= (ssize_t) sizeof(randValue), "Unable to read from /dev/urandom");
TORCH_CHECK(
readBytes >= (ssize_t)sizeof(randValue),
"Unable to read from /dev/urandom");
close(randDev);
return randValue;
}
@ -58,11 +59,11 @@ static uint64_t readURandomLong()
* /dev/urandom or the current time. For CUDA, gets random from
* std::random_device and adds a transformation on it.
*
* FIXME: The behavior in this function is from legacy code (THRandom_seed/THCRandom_seed)
* and is probably not the right thing to do, even though our tests pass.
* Figure out if tests get perturbed
* - when the same algorithm is used for all backends. Note that the current behavior is
* different for CPU, CUDA and Windows CPU.
* FIXME: The behavior in this function is from legacy code
* (THRandom_seed/THCRandom_seed) and is probably not the right thing to do,
* even though our tests pass. Figure out if tests get perturbed
* - when the same algorithm is used for all backends. Note that the current
* behavior is different for CPU, CUDA and Windows CPU.
* - when using C++11 std objects, such as std::random_device
* - when constructing a 64 bit seed properly, rather than static casting
* a 32 bit number to 64 bit.
@ -71,11 +72,13 @@ uint64_t getNonDeterministicRandom(bool is_cuda) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint64_t s;
if (!is_cuda) {
#ifdef _WIN32
s = (uint64_t)std::chrono::high_resolution_clock::now().time_since_epoch().count();
#else
s = readURandomLong();
#endif
#ifdef _WIN32
s = (uint64_t)std::chrono::high_resolution_clock::now()
.time_since_epoch()
.count();
#else
s = readURandomLong();
#endif
} else {
std::random_device rd;
// limit to 53 bits to ensure unique representation in double

View File

@ -1,52 +1,54 @@
#pragma once
#include <stdint.h>
#include <mutex>
#include <deque>
#include <atomic>
#include <deque>
#include <mutex>
#include <typeinfo>
#include <utility>
#include <c10/util/Exception.h>
#include <c10/util/C++17.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/core/Device.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/util/python_stub.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/C++17.h>
#include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/python_stub.h>
/**
* Note [Generator]
* ~~~~~~~~~~~~~~~~
* A Pseudo Random Number Generator (PRNG) is an engine that uses an algorithm to
* generate a seemingly random sequence of numbers, that may be later be used in creating
* a random distribution. Such an engine almost always maintains a state and requires a
* seed to start off the creation of random numbers. Often times, users have
* found it beneficial to be able to explicitly create, retain, and destroy
* PRNG states and also be able to have control over the seed value.
* A Pseudo Random Number Generator (PRNG) is an engine that uses an algorithm
* to generate a seemingly random sequence of numbers, that may be later be used
* in creating a random distribution. Such an engine almost always maintains a
* state and requires a seed to start off the creation of random numbers. Often
* times, users have found it beneficial to be able to explicitly create,
* retain, and destroy PRNG states and also be able to have control over the
* seed value.
*
* A Generator in ATen gives users the ability to read, write and modify a PRNG engine.
* For instance, it does so by letting users seed a PRNG engine, fork the state of the
* engine, etc.
* A Generator in ATen gives users the ability to read, write and modify a PRNG
* engine. For instance, it does so by letting users seed a PRNG engine, fork
* the state of the engine, etc.
*
* By default, there is one generator per device, and a device's generator is
* lazily created. A user can use the torch.Generator() api to create their own generator.
* Currently torch.Generator() can only create a CPUGeneratorImpl.
* lazily created. A user can use the torch.Generator() api to create their own
* generator. Currently torch.Generator() can only create a CPUGeneratorImpl.
*/
/**
* Note [Acquire lock when using random generators]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* Generator and its derived classes are NOT thread-safe. Please note that most of the
* places where we have inserted locking for generators are historically based, and we
* haven't actually checked that everything is truly thread safe (and it probably isn't).
* Please use the public mutex_ when using any methods from these classes, except for the
* read-only methods. You can learn about the usage by looking into the unittests
* (aten/src/ATen/cpu_generator_test.cpp) and other places where we have used lock_guard.
* Generator and its derived classes are NOT thread-safe. Please note that most
* of the places where we have inserted locking for generators are historically
* based, and we haven't actually checked that everything is truly thread safe
* (and it probably isn't). Please use the public mutex_ when using any methods
* from these classes, except for the read-only methods. You can learn about the
* usage by looking into the unittests (aten/src/ATen/cpu_generator_test.cpp)
* and other places where we have used lock_guard.
*
* TODO: Look into changing the threading semantics of Generators in ATen (e.g., making
* them non-thread safe and instead making the generator state splittable, to accommodate
* forks into other threads).
* TODO: Look into changing the threading semantics of Generators in ATen (e.g.,
* making them non-thread safe and instead making the generator state
* splittable, to accommodate forks into other threads).
*/
namespace c10 {
@ -79,7 +81,9 @@ struct C10_API GeneratorImpl : public c10::intrusive_ptr_target {
// See Note [Acquire lock when using random generators]
std::mutex mutex_;
DispatchKeySet key_set() const { return key_set_; }
DispatchKeySet key_set() const {
return key_set_;
}
inline void set_pyobj(PyObject* pyobj) noexcept {
pyobj_ = pyobj;
@ -89,12 +93,12 @@ struct C10_API GeneratorImpl : public c10::intrusive_ptr_target {
return pyobj_;
}
protected:
Device device_;
DispatchKeySet key_set_;
PyObject* pyobj_ = nullptr;
protected:
Device device_;
DispatchKeySet key_set_;
PyObject* pyobj_ = nullptr;
virtual GeneratorImpl* clone_impl() const = 0;
virtual GeneratorImpl* clone_impl() const = 0;
};
namespace detail {
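For illustration, the locking convention from Note [Acquire lock when using random generators] (not part of this commit; `set_current_seed` is assumed to be one of the mutating GeneratorImpl methods the note refers to):

void seed_under_lock(c10::GeneratorImpl& gen) {
  // mutex_ is the public member declared above; hold it around any mutating call.
  std::lock_guard<std::mutex> lock(gen.mutex_);
  gen.set_current_seed(42);
}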

View File

@ -27,4 +27,4 @@ struct TORCH_API NoGradGuard : public AutoGradMode {
NoGradGuard() : AutoGradMode(/*enabled=*/false) {}
};
}
} // namespace c10

View File

@ -6,7 +6,8 @@ namespace c10 {
thread_local bool InferenceMode_enabled = false;
// Invariant:
// is_enabled() == !c10::impl::tls_is_dispatch_key_included(DispatchKey::InplaceOrView);
// is_enabled() ==
// !c10::impl::tls_is_dispatch_key_included(DispatchKey::InplaceOrView);
// InferenceMode::is_enabled() is in perf critical path (TensorImpl constructor)
// so it worths a separate TLS to skip the DispatchKeySet check.
bool InferenceMode::is_enabled() {

View File

@ -1,8 +1,8 @@
#pragma once
#include <c10/core/GradMode.h>
#include <c10/macros/Macros.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/macros/Macros.h>
namespace c10 {
@ -10,13 +10,14 @@ namespace c10 {
// construction, and sets it back to the original value upon destruction.
struct TORCH_API InferenceMode {
// Note [Expected TLS state in InferenceMode]:
// InferenceMode: InplaceOrView not in raw_local_dispatch_key_set.included(),
// InferenceMode: InplaceOrView not in
// raw_local_dispatch_key_set.included(),
// Autograd in raw_local_dispatch_key_set.excluded()
// GradMode is disabled.
// NormalMode: InplaceOrView in raw_local_dispatch_key_set.included(),
// Autograd not in raw_local_dispatch_key_set.excluded()
// GradMode is enabled by default unless toggled manually through
// other APIs, e.g. NoGradGuard.
// GradMode is enabled by default unless toggled manually
// through other APIs, e.g. NoGradGuard.
//
// Invariant:
// - InplaceOrView is never in the excluded set
@ -25,8 +26,8 @@ struct TORCH_API InferenceMode {
//
// 1. Why do we put InplaceOrView in included set outside InferenceMode?
//
// Inplace update to inference tensor outside InferenceMode is not allowed.
// See Note [Inplace update inference tensor] for more details.
// Inplace update to inference tensor outside InferenceMode is not
// allowed. See Note [Inplace update inference tensor] for more details.
// Without going through InplaceOrView kernel, we cannot throw error
// for `inference_tensor.add_(1)` case.
//
@ -48,14 +49,17 @@ struct TORCH_API InferenceMode {
// version of NoGradGuard. All runtime checks using GradMode::is_enabled()
// are applicable to InferenceMode as well, e.g.
// `tensorTypeInCurrentExecutionContext` in interpreter.cpp.
InferenceMode(bool enabled=true): prev_mode(InferenceMode::is_enabled()),
prev_keyset(c10::impl::tls_local_dispatch_key_set()),
grad_mode(at::AutoGradMode(!enabled)) {
InferenceMode(bool enabled = true)
: prev_mode(InferenceMode::is_enabled()),
prev_keyset(c10::impl::tls_local_dispatch_key_set()),
grad_mode(at::AutoGradMode(!enabled)) {
set_enabled(enabled);
DispatchKeySet included = enabled ? prev_keyset.included_.remove(c10::DispatchKey::InplaceOrView)
: prev_keyset.included_.add(c10::DispatchKey::InplaceOrView);
DispatchKeySet excluded = enabled ? (prev_keyset.excluded_ | c10::autograd_dispatch_keyset)
: (prev_keyset.excluded_ - c10::autograd_dispatch_keyset);
DispatchKeySet included = enabled
? prev_keyset.included_.remove(c10::DispatchKey::InplaceOrView)
: prev_keyset.included_.add(c10::DispatchKey::InplaceOrView);
DispatchKeySet excluded = enabled
? (prev_keyset.excluded_ | c10::autograd_dispatch_keyset)
: (prev_keyset.excluded_ - c10::autograd_dispatch_keyset);
c10::impl::PODLocalDispatchKeySet cur_keyset;
cur_keyset.set_included(included);
cur_keyset.set_excluded(excluded);
@ -71,9 +75,9 @@ struct TORCH_API InferenceMode {
// ThreadLocalState.cpp.
static void set_enabled(bool enabled);
private:
bool prev_mode;
c10::impl::LocalDispatchKeySet prev_keyset;
at::AutoGradMode grad_mode;
private:
bool prev_mode;
c10::impl::LocalDispatchKeySet prev_keyset;
at::AutoGradMode grad_mode;
};
} // namespace c10
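For illustration, a usage sketch of the guard above (not part of this commit):

void inference_example() {
  // enabled = true: removes InplaceOrView from the included set, adds the
  // autograd keys to the excluded set, and disables GradMode, as described above.
  c10::InferenceMode guard;
  // ... run inference-only work here ...
} // the previous TLS state and GradMode are restored on destruction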

View File

@ -1,8 +1,8 @@
#pragma once
#include <c10/core/Backend.h>
#include <c10/util/Exception.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <iostream>
@ -17,14 +17,20 @@
// should be in channels_last format
//
// Contiguous:
// Regardless of input tensors format, the output should be contiguous Tensor.
// Regardless of input tensors format, the output should be contiguous
// Tensor.
//
// ChannelsLast:
// Regardless of input tensors format, the output should be in channels_last format.
// Regardless of input tensors format, the output should be in channels_last
// format.
namespace c10 {
enum class MemoryFormat : int8_t { Contiguous, Preserve, ChannelsLast, ChannelsLast3d };
enum class MemoryFormat : int8_t {
Contiguous,
Preserve,
ChannelsLast,
ChannelsLast3d
};
// If you are seeing this, it means that this call site was not checked if
// the memory format could be preserved, and it was switched to old default
@ -52,7 +58,8 @@ inline std::ostream& operator<<(
}
}
// Note: Hardcoded the channel last stride indices here to get better performance
// Note: Hardcoded the channel last stride indices here to get better
// performance
inline std::vector<int64_t> get_channels_last_strides_2d(IntArrayRef sizes) {
std::vector<int64_t> strides(sizes.size());
switch (sizes.size()) {
@ -68,7 +75,8 @@ inline std::vector<int64_t> get_channels_last_strides_2d(IntArrayRef sizes) {
strides[1] = strides[2] * sizes[2];
return strides;
default:
TORCH_INTERNAL_ASSERT(false, "ChannelsLast2d doesn't support size ", sizes.size());
TORCH_INTERNAL_ASSERT(
false, "ChannelsLast2d doesn't support size ", sizes.size());
}
}
@ -89,7 +97,8 @@ inline std::vector<int64_t> get_channels_last_strides_3d(IntArrayRef sizes) {
strides[1] = strides[2] * sizes[2];
return strides;
default:
TORCH_INTERNAL_ASSERT(false, "ChannelsLast3d doesn't support size ", sizes.size());
TORCH_INTERNAL_ASSERT(
false, "ChannelsLast3d doesn't support size ", sizes.size());
}
}
@ -100,12 +109,16 @@ inline std::vector<int64_t> get_channels_last_strides_3d(IntArrayRef sizes) {
// will be a constant array and we can access it using constant index number,
// the compiler will fully unroll the loop on strides indices to gain a better
// performance.
// 2. No error check in helper function, caller ensures the correctness of the input
// 3. All helper functions have similar comments, only 1st helper function is commented here.
inline bool is_channels_last_strides_2d_s4(const IntArrayRef sizes, const IntArrayRef strides) {
// 2. No error check in helper function, caller ensures the correctness of the
// input
// 3. All helper functions have similar comments, only 1st helper function is
// commented here.
inline bool is_channels_last_strides_2d_s4(
const IntArrayRef sizes,
const IntArrayRef strides) {
int64_t min = 0;
// special case for trivial C dimension. default to NCHW
if (strides[1]==0) {
if (strides[1] == 0) {
return false;
}
// loop strides indices
@ -121,8 +134,9 @@ inline bool is_channels_last_strides_2d_s4(const IntArrayRef sizes, const IntArr
// N111 tensor with identical strides for size 1 dimension;
// Two cases could lead us here:
// a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1])
// b. N11W contiguous Tensor sliced on the W-dimension. ([N,1,1,1]@[W,W,W,W])
if (d==0 && min==strides[1]) {
// b. N11W contiguous Tensor sliced on the W-dimension.
// ([N,1,1,1]@[W,W,W,W])
if (d == 0 && min == strides[1]) {
return false;
}
// This is necessary to:
@ -140,7 +154,9 @@ inline bool is_channels_last_strides_2d_s4(const IntArrayRef sizes, const IntArr
return true;
}
inline bool is_channels_last_strides_3d_s5(const IntArrayRef sizes, const IntArrayRef strides) {
inline bool is_channels_last_strides_3d_s5(
const IntArrayRef sizes,
const IntArrayRef strides) {
int64_t min = 0;
if (strides[1] == 0) {
return false;
@ -209,10 +225,13 @@ inline bool is_channels_last_strides_3d_s5(const IntArrayRef sizes, const IntArr
// issues in our tests.
//
// We use Channels Last 2d as an example above.
// This is a general problem for all the is_channels_last_strides_xd implementation.
// Please check the helper functions (is_channels_last_strides_*d_s*) for more details.
// This is a general problem for all the is_channels_last_strides_xd
// implementation. Please check the helper functions
// (is_channels_last_strides_*d_s*) for more details.
inline bool is_channels_last_strides_2d(const IntArrayRef sizes, const IntArrayRef strides) {
inline bool is_channels_last_strides_2d(
const IntArrayRef sizes,
const IntArrayRef strides) {
switch (sizes.size()) {
case 4:
return is_channels_last_strides_2d_s4(sizes, strides);
@ -224,7 +243,9 @@ inline bool is_channels_last_strides_2d(const IntArrayRef sizes, const IntArrayR
}
}
inline bool is_channels_last_strides_3d(const IntArrayRef sizes, const IntArrayRef strides) {
inline bool is_channels_last_strides_3d(
const IntArrayRef sizes,
const IntArrayRef strides) {
switch (sizes.size()) {
case 5:
return is_channels_last_strides_3d_s5(sizes, strides);
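A worked instance of get_channels_last_strides_2d above (illustrative arithmetic only):

// For sizes [N, C, H, W] = [2, 3, 4, 5]:
//   strides[1] = 1                  // C is the fastest-moving dimension
//   strides[3] = sizes[1]      = 3  // W
//   strides[2] = 3 * sizes[3]  = 15 // H
//   strides[0] = 15 * sizes[2] = 60 // N
// i.e. the NHWC layout written as NCHW-ordered strides: [60, 1, 15, 3].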

View File

@ -31,9 +31,7 @@ inline std::string toString(QEngine qengine) {
return "QNNPACK";
default:
TORCH_CHECK(
false,
"Unrecognized Quantized Engine: ",
static_cast<int>(qengine));
false, "Unrecognized Quantized Engine: ", static_cast<int>(qengine));
}
}

View File

@ -24,12 +24,13 @@ constexpr auto kPerTensorAffine = QScheme::PER_TENSOR_AFFINE;
constexpr auto kPerChannelAffine = QScheme::PER_CHANNEL_AFFINE;
constexpr auto kPerTensorSymmetric = QScheme::PER_TENSOR_SYMMETRIC;
constexpr auto kPerChannelSymmetric = QScheme::PER_CHANNEL_SYMMETRIC;
constexpr auto kPerChannelAffineFloatQParams = QScheme::PER_CHANNEL_AFFINE_FLOAT_QPARAMS;
constexpr auto kPerChannelAffineFloatQParams =
QScheme::PER_CHANNEL_AFFINE_FLOAT_QPARAMS;
constexpr int COMPILE_TIME_NUM_QSCHEMES =
static_cast<int>(QScheme::COMPILE_TIME_NUM_QSCHEMES);
static_cast<int>(QScheme::COMPILE_TIME_NUM_QSCHEMES);
inline std::string toString(QScheme qscheme) {
switch(qscheme) {
switch (qscheme) {
case kPerTensorAffine:
return "per_tensor_affine";
case kPerChannelAffine:

View File

@ -3,7 +3,9 @@
namespace c10 {
Scalar Scalar::operator-() const {
TORCH_CHECK(!isBoolean(), "torch boolean negative, the `-` operator, is not supported.");
TORCH_CHECK(
!isBoolean(),
"torch boolean negative, the `-` operator, is not supported.");
if (isFloatingPoint()) {
return Scalar(-v.d);
} else if (isComplex()) {
@ -31,4 +33,4 @@ Scalar Scalar::log() const {
}
}
} // namespace c10
} // namespace c10

View File

@ -4,8 +4,8 @@
#include <stdint.h>
#include <stdexcept>
#include <string>
#include <utility>
#include <type_traits>
#include <utility>
#include <c10/core/ScalarType.h>
#include <c10/macros/Macros.h>
@ -26,8 +26,8 @@ class C10_API Scalar {
public:
Scalar() : Scalar(int64_t(0)) {}
#define DEFINE_IMPLICIT_CTOR(type, name) \
Scalar(type vv) : Scalar(vv, true) { }
#define DEFINE_IMPLICIT_CTOR(type, name) \
Scalar(type vv) : Scalar(vv, true) {}
AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, DEFINE_IMPLICIT_CTOR)
AT_FORALL_COMPLEX_TYPES(DEFINE_IMPLICIT_CTOR)
@ -45,18 +45,18 @@ class C10_API Scalar {
v.i = convert<int64_t, bool>(vv);
}
#define DEFINE_ACCESSOR(type, name) \
type to##name() const { \
if (Tag::HAS_d == tag) { \
return checked_convert<type, double>(v.d, #type); \
} else if (Tag::HAS_z == tag) { \
return checked_convert<type, c10::complex<double>>( \
v.z, #type); \
} if (Tag::HAS_b == tag) { \
return checked_convert<type, bool>(v.i, #type); \
} else { \
return checked_convert<type, int64_t>(v.i, #type); \
} \
#define DEFINE_ACCESSOR(type, name) \
type to##name() const { \
if (Tag::HAS_d == tag) { \
return checked_convert<type, double>(v.d, #type); \
} else if (Tag::HAS_z == tag) { \
return checked_convert<type, c10::complex<double>>(v.z, #type); \
} \
if (Tag::HAS_b == tag) { \
return checked_convert<type, bool>(v.i, #type); \
} else { \
return checked_convert<type, int64_t>(v.i, #type); \
} \
}
// TODO: Support ComplexHalf accessor
@ -71,7 +71,8 @@ class C10_API Scalar {
return Tag::HAS_d == tag;
}
C10_DEPRECATED_MESSAGE("isIntegral is deprecated. Please use the overload with 'includeBool' parameter instead.")
C10_DEPRECATED_MESSAGE(
"isIntegral is deprecated. Please use the overload with 'includeBool' parameter instead.")
bool isIntegral() const {
return Tag::HAS_i == tag;
}
@ -90,7 +91,9 @@ class C10_API Scalar {
Scalar conj() const;
Scalar log() const;
template<typename T, typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
template <
typename T,
typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
bool equal(T num) const {
if (isComplex()) {
auto val = v.z;
@ -105,7 +108,9 @@ class C10_API Scalar {
}
}
template<typename T, typename std::enable_if<c10::is_complex<T>::value, int>::type = 0>
template <
typename T,
typename std::enable_if<c10::is_complex<T>::value, int>::type = 0>
bool equal(T num) const {
if (isComplex()) {
return v.z == num;
@ -142,26 +147,30 @@ class C10_API Scalar {
}
private:
template<typename T,
typename std::enable_if<std::is_integral<T>::value && ! std::is_same<T, bool>::value, bool>::type* =
nullptr>
Scalar(T vv, bool) : tag(Tag::HAS_i) {
v.i = convert<decltype(v.i), T>(vv);
}
template <
typename T,
typename std::enable_if<
std::is_integral<T>::value && !std::is_same<T, bool>::value,
bool>::type* = nullptr>
Scalar(T vv, bool) : tag(Tag::HAS_i) {
v.i = convert<decltype(v.i), T>(vv);
}
template<typename T,
typename std::enable_if<!std::is_integral<T>::value && !c10::is_complex<T>::value, bool>::type* =
nullptr>
Scalar(T vv, bool) : tag(Tag::HAS_d) {
v.d = convert<decltype(v.d), T>(vv);
}
template <
typename T,
typename std::enable_if<
!std::is_integral<T>::value && !c10::is_complex<T>::value,
bool>::type* = nullptr>
Scalar(T vv, bool) : tag(Tag::HAS_d) {
v.d = convert<decltype(v.d), T>(vv);
}
template<typename T,
typename std::enable_if<c10::is_complex<T>::value, bool>::type* =
nullptr>
Scalar(T vv, bool) : tag(Tag::HAS_z) {
v.z = convert<decltype(v.z), T>(vv);
}
template <
typename T,
typename std::enable_if<c10::is_complex<T>::value, bool>::type* = nullptr>
Scalar(T vv, bool) : tag(Tag::HAS_z) {
v.z = convert<decltype(v.z), T>(vv);
}
// We can't set v in the initializer list using the
// syntax v{ .member = ... } because it doesn't work on MSVC
@ -172,7 +181,7 @@ class C10_API Scalar {
double d;
int64_t i;
c10::complex<double> z;
v_t(){} // default constructor
v_t() {} // default constructor
} v;
};
@ -182,10 +191,10 @@ inline T Scalar::to() const {
throw std::runtime_error("to() cast to unexpected type.");
}
#define DEFINE_TO(T, name) \
template <> \
inline T Scalar::to<T>() const { \
return to##name(); \
#define DEFINE_TO(T, name) \
template <> \
inline T Scalar::to<T>() const { \
return to##name(); \
}
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_TO)
#undef DEFINE_TO
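A short usage sketch of the tagged-union Scalar defined above; to<T>() routes through the DEFINE_TO specializations and the named accessors convert via checked_convert:

void scalar_sketch() {
  c10::Scalar s(3.5);               // stored with tag HAS_d
  double d = s.toDouble();          // 3.5
  float f = s.to<float>();          // 3.5f, via checked_convert
  bool is_fp = s.isFloatingPoint(); // true
  c10::Scalar neg = -s;             // -3.5; boolean Scalars throw here
  (void)d; (void)f; (void)is_fp; (void)neg;
}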


@ -1,14 +1,14 @@
#pragma once
#include <c10/util/ArrayRef.h>
#include <c10/util/complex.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/util/Optional.h>
#include <c10/util/complex.h>
#include <c10/util/qint32.h>
#include <c10/util/qint8.h>
#include <c10/util/quint8.h>
#include <c10/util/BFloat16.h>
#include <c10/util/quint4x2.h>
#include <c10/util/Optional.h>
#include <c10/util/quint8.h>
#include <complex>
#include <cstdint>
@ -44,7 +44,6 @@ namespace c10 {
_(at::BFloat16, BFloat16) /* 15 */ \
_(c10::quint4x2, QUInt4x2) /* 16 */
// If you want to support ComplexHalf for real, add ComplexHalf
// into this macro (and change the name). But beware: convert()
// doesn't work for all the conversions you need...
@ -62,7 +61,6 @@ namespace c10 {
_(bool, Bool) \
_(at::BFloat16, BFloat16)
enum class ScalarType : int8_t {
#define DEFINE_ENUM(_1, n) n,
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ENUM)
@ -71,7 +69,8 @@ enum class ScalarType : int8_t {
NumOptions
};
constexpr uint16_t NumScalarTypes = static_cast<uint16_t>(ScalarType::NumOptions);
constexpr uint16_t NumScalarTypes =
static_cast<uint16_t>(ScalarType::NumOptions);
namespace impl {
@ -80,20 +79,20 @@ namespace impl {
template <c10::ScalarType N>
struct ScalarTypeToCPPType;
#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type) \
template<> \
struct ScalarTypeToCPPType<c10::ScalarType::scalar_type> { \
using type = cpp_type; \
\
/* This is a workaround for the CUDA bug which prevents */ \
/* ::detail::ScalarTypeToCType<T>::type being used directly due to */ \
/* ambiguous reference which can't be resolved. For some reason it */ \
/* can't pick between at::detail and at::cuda::detail. */ \
/* For repro example, please see: */ \
/* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */ \
/* TODO: remove once the bug is fixed. */ \
static type t; \
};
#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type) \
template <> \
struct ScalarTypeToCPPType<c10::ScalarType::scalar_type> { \
using type = cpp_type; \
\
/* This is a workaround for the CUDA bug which prevents */ \
/* ::detail::ScalarTypeToCType<T>::type being used directly due to */ \
/* ambiguous reference which can't be resolved. For some reason it */ \
/* can't pick between at::detail and at::cuda::detail. */ \
/* For repro example, please see: */ \
/* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */ \
/* TODO: remove once the bug is fixed. */ \
static type t; \
};
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType)
@ -104,12 +103,12 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType)
template <typename T>
struct CppTypeToScalarType;
#define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type) \
template<> \
struct CppTypeToScalarType<cpp_type>: \
std::integral_constant<c10::ScalarType, \
c10::ScalarType::scalar_type> \
{};
#define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type) \
template <> \
struct CppTypeToScalarType<cpp_type> \
: std:: \
integral_constant<c10::ScalarType, c10::ScalarType::scalar_type> { \
};
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
@ -131,47 +130,59 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
_(float, Float) \
_(double, Double)
#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE>::t), SCALARTYPE)
#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype( \
::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE>::t), \
SCALARTYPE)
#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE1>::t), SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE2>::t), SCALARTYPE2)
#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype( \
::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE1>::t), \
SCALARTYPE1) \
_(decltype( \
::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE2>::t), \
SCALARTYPE2)
#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE1>::t), SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE2>::t), SCALARTYPE2) \
_(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE3>::t), SCALARTYPE3)
#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype( \
::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE1>::t), \
SCALARTYPE1) \
_(decltype( \
::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE2>::t), \
SCALARTYPE2) \
_(decltype( \
::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE3>::t), \
SCALARTYPE3)
#define AT_FORALL_QINT_TYPES(_) \
_(c10::qint8, QInt8) \
_(c10::quint8, QUInt8) \
_(c10::qint32, QInt32) \
#define AT_FORALL_QINT_TYPES(_) \
_(c10::qint8, QInt8) \
_(c10::quint8, QUInt8) \
_(c10::qint32, QInt32) \
_(c10::quint4x2, QUInt4x2)
#define AT_FORALL_COMPLEX_TYPES(_) \
_(c10::complex<float>, ComplexFloat) \
#define AT_FORALL_COMPLEX_TYPES(_) \
_(c10::complex<float>, ComplexFloat) \
_(c10::complex<double>, ComplexDouble)
#define DEFINE_CONSTANT(_, name) \
@ -206,7 +217,8 @@ static inline size_t elementSize(ScalarType t) {
#undef CASE_ELEMENTSIZE_CASE
}
C10_DEPRECATED_MESSAGE("isIntegralType is deprecated. Please use the overload with 'includeBool' parameter instead.")
C10_DEPRECATED_MESSAGE(
"isIntegralType is deprecated. Please use the overload with 'includeBool' parameter instead.")
static inline bool isIntegralType(ScalarType t) {
return (
t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int ||
@ -214,9 +226,9 @@ static inline bool isIntegralType(ScalarType t) {
}
static inline bool isIntegralType(ScalarType t, bool includeBool) {
bool isIntegral = (
t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int ||
t == ScalarType::Long || t == ScalarType::Short);
bool isIntegral =
(t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int ||
t == ScalarType::Long || t == ScalarType::Short);
return includeBool ? isIntegral || (t == ScalarType::Bool) : isIntegral;
}
@ -235,7 +247,8 @@ static inline bool isComplexType(ScalarType t) {
static inline bool isQIntType(ScalarType t) {
// Don't forget to extend this when adding new QInt types
return t == ScalarType:: QInt8 || t == ScalarType::QUInt8 || t == ScalarType::QInt32 || t == ScalarType::QUInt4x2;
return t == ScalarType::QInt8 || t == ScalarType::QUInt8 ||
t == ScalarType::QInt32 || t == ScalarType::QUInt4x2;
}
static inline ScalarType toQIntType(ScalarType t) {
@ -268,20 +281,20 @@ static inline ScalarType toUnderlying(ScalarType t) {
static inline bool isSignedType(ScalarType t) {
TORCH_CHECK(!isQIntType(t), "isSignedType not supported for quantized types");
#define CASE_SIGNED(ctype, name) \
case ScalarType::name: \
return std::numeric_limits<ctype>::is_signed;
#define CASE_SIGNED(ctype, name) \
case ScalarType::name: \
return std::numeric_limits<ctype>::is_signed;
switch (t) {
case ScalarType::ComplexHalf:
case ScalarType::ComplexFloat:
case ScalarType::ComplexDouble:
return true;
AT_FORALL_SCALAR_TYPES_AND3(Half, Bool, BFloat16, CASE_SIGNED)
AT_FORALL_SCALAR_TYPES_AND3(Half, Bool, BFloat16, CASE_SIGNED)
default:
TORCH_CHECK(false, "Unknown ScalarType");
}
#undef CASE_SIGNED
#undef CASE_SIGNED
}
static inline bool isUnderlying(ScalarType type, ScalarType qtype) {
@ -323,7 +336,8 @@ static inline ScalarType toComplexType(ScalarType t) {
// see tensor_attributes.rst for detailed explanation and examples
// of casting rules.
static inline bool canCast(const ScalarType from, const ScalarType to) {
// We disallow complex -> non complex, e.g., float_tensor *= complex is disallowed.
// We disallow complex -> non complex, e.g., float_tensor *= complex is
// disallowed.
if (isComplexType(from) && !isComplexType(to)) {
return false;
}
@ -333,13 +347,14 @@ static inline bool canCast(const ScalarType from, const ScalarType to) {
}
// Treat bool as a distinct "category," to be consistent with type promotion
// rules (e.g. `bool_tensor + 5 -> int64_tensor`). If `5` was in the same category
// as `bool_tensor`, we would not promote.
// Differing categories implies `bool_tensor += 5` is disallowed.
// rules (e.g. `bool_tensor + 5 -> int64_tensor`). If `5` was in the same
// category as `bool_tensor`, we would not promote. Differing categories
// implies `bool_tensor += 5` is disallowed.
//
// NB: numpy distinguishes "unsigned" as a category to get the desired
// `bool_tensor + 5 -> int64_tensor` behavior. We don't, because:
// * We don't want the performance hit of checking the runtime sign of Scalars.
// * We don't want the performance hit of checking the runtime sign of
// Scalars.
// * `uint8_tensor + 5 -> int64_tensor` would be undesirable.
if (from != ScalarType::Bool && to == ScalarType::Bool) {
return false;
@ -373,7 +388,8 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
}
if (isQIntType(a) || isQIntType(b)) {
TORCH_CHECK(false,
TORCH_CHECK(
false,
"promoteTypes with quantized numbers is not handled yet; figure out what the correct rules should be, offending types: ",
toString(a),
" ",
@ -385,23 +401,23 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
// correct values for the type promotions in complex type cases.
static constexpr ScalarType _promoteTypesLookup[static_cast<int>(
ScalarType::NumOptions)][static_cast<int>(ScalarType::NumOptions)] = {
/* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 q1 q2 q3 bf*/
/* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, ud, c4, c8, u1, ud, ud, ud, bf},
/* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, ud, c4, c8, i1, ud, ud, ud, bf},
/* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, ud, c4, c8, i2, ud, ud, ud, bf},
/* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, ud, c4, c8, i4, ud, ud, ud, bf},
/* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, ud, c4, c8, i8, ud, ud, ud, bf},
/* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, ud, c4, c8, f2, ud, ud, ud, f4},
/* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, ud, c4, c8, f4, ud, ud, ud, f4},
/* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, ud, c8, c8, f8, ud, ud, ud, f8},
/* c2 */ {ud, ud, ud, ud, ud, ud, ud, ud, c2, c4, c8, ud, ud, ud, ud, ud},
/* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, ud, ud, ud, c4},
/* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, ud, ud, ud, c8},
/* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, ud, c4, c8, b1, ud, ud, ud, bf},
/* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* q2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, ud, c4, c8, bf, ud, ud, ud, bf},
/* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 q1 q2 q3 bf*/
/* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, ud, c4, c8, u1, ud, ud, ud, bf},
/* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, ud, c4, c8, i1, ud, ud, ud, bf},
/* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, ud, c4, c8, i2, ud, ud, ud, bf},
/* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, ud, c4, c8, i4, ud, ud, ud, bf},
/* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, ud, c4, c8, i8, ud, ud, ud, bf},
/* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, ud, c4, c8, f2, ud, ud, ud, f4},
/* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, ud, c4, c8, f4, ud, ud, ud, f4},
/* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, ud, c8, c8, f8, ud, ud, ud, f8},
/* c2 */ {ud, ud, ud, ud, ud, ud, ud, ud, c2, c4, c8, ud, ud, ud, ud, ud},
/* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, ud, ud, ud, c4},
/* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, ud, ud, ud, c8},
/* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, ud, c4, c8, b1, ud, ud, ud, bf},
/* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* q2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, ud, c4, c8, bf, ud, ud, ud, bf},
};
return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
}
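A small sketch of how the promotion table and canCast interact; the expected values follow the lookup table and the rules spelled out above:

void promotion_sketch() {
  using c10::ScalarType;
  auto a = c10::promoteTypes(ScalarType::Float, ScalarType::Long); // ScalarType::Float
  auto b = c10::promoteTypes(ScalarType::Bool, ScalarType::Long);  // ScalarType::Long (bool is its own category)
  bool c1 = c10::canCast(ScalarType::Float, ScalarType::Int); // false: float -> integral is disallowed
  bool c2 = c10::canCast(ScalarType::Bool, ScalarType::Int);  // true: bool may promote upward
  (void)a; (void)b; (void)c1; (void)c2;
}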


@ -26,7 +26,8 @@ static inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) {
/**
* typeMetaToScalarType(), lifted to optional
*/
static inline optional<at::ScalarType> optTypeMetaToScalarType(optional<caffe2::TypeMeta> type_meta) {
static inline optional<at::ScalarType> optTypeMetaToScalarType(
optional<caffe2::TypeMeta> type_meta) {
if (!type_meta.has_value()) {
return c10::nullopt;
}


@ -1,5 +1,3 @@
#include <c10/core/Storage.h>
namespace c10 {
} // namespace c10
namespace c10 {} // namespace c10


@ -9,7 +9,8 @@ struct C10_API Storage {
struct use_byte_size_t {};
Storage() {}
Storage(c10::intrusive_ptr<StorageImpl> ptr) : storage_impl_(std::move(ptr)) {}
Storage(c10::intrusive_ptr<StorageImpl> ptr)
: storage_impl_(std::move(ptr)) {}
// Allocates memory buffer using given allocator and creates a storage with it
Storage(
@ -53,10 +54,14 @@ struct C10_API Storage {
}
template <typename T>
T* data() const { return storage_impl_->data<T>(); }
T* data() const {
return storage_impl_->data<T>();
}
template <typename T>
T* unsafe_data() const { return storage_impl_->unsafe_data<T>(); }
T* unsafe_data() const {
return storage_impl_->unsafe_data<T>();
}
// TODO: remove later
void set_nbytes(size_t size_bytes) const {
@ -134,7 +139,8 @@ struct C10_API Storage {
size_t capacity,
DeleterFnPtr d = nullptr) {
if (!storage_impl_.unique()) {
TORCH_CHECK(false,
TORCH_CHECK(
false,
"UniqueStorageShareExternalPointer can only be called when use_count == 1");
}
storage_impl_->UniqueStorageShareExternalPointer(src, capacity, d);
@ -144,7 +150,8 @@ struct C10_API Storage {
at::DataPtr&& data_ptr,
size_t capacity) {
if (!storage_impl_.unique()) {
TORCH_CHECK(false,
TORCH_CHECK(
false,
"UniqueStorageShareExternalPointer can only be called when use_count == 1");
}
storage_impl_->UniqueStorageShareExternalPointer(


@ -20,13 +20,13 @@ namespace c10 {
// but a lot of things won't work correctly, including:
//
// - An ordinary deleter on such a storage is wrong, because normal deleters
// assume unique ownership, but if you have two storages at the same data, that
// implies there is some sort of shared ownership. So your deleter would have to
// actually be internally doing some sort of refcount thing
// assume unique ownership, but if you have two storages at the same data,
// that implies there is some sort of shared ownership. So your deleter would
// have to actually be internally doing some sort of refcount thing
// - Deepcopy in Python side relies on storage equality and not data pointer
// equality; so if there are two separate storages pointing to the same data,
// the data will actually get duplicated in that case (one data ptr before, two
// data ptrs after)
// the data will actually get duplicated in that case (one data ptr before,
// two data ptrs after)
// - Version counts won't work correctly, because we do all VC tracking at the
// level of storages (unless you explicitly disconnect the VC with detach);
// mutations because data pointers are the same are totally untracked


@ -55,10 +55,11 @@ using StreamId = int32_t;
* wrapper classes which provide this functionality, e.g., CUDAStream.
*/
class Stream final {
private:
private:
Device device_;
StreamId id_;
public:
public:
enum Unsafe { UNSAFE };
enum Default { DEFAULT };
@ -69,16 +70,13 @@ public:
/// we don't require backends to give any guarantees about non-zero
/// StreamIds; they are welcome to allocate in whatever way they like.
explicit Stream(Unsafe, Device device, StreamId id)
: device_(device)
, id_(id) {}
: device_(device), id_(id) {}
/// Construct the default stream of a Device. The default stream is
/// NOT the same as the current stream; default stream is a fixed stream
/// that never changes, whereas the current stream may be changed by
/// StreamGuard.
explicit Stream(Default, Device device)
: device_(device)
, id_(0) {}
explicit Stream(Default, Device device) : device_(device), id_(0) {}
bool operator==(const Stream& other) const noexcept {
return this->device_ == other.device_ && this->id_ == other.id_;
@ -87,10 +85,18 @@ public:
return !(*this == other);
}
Device device() const noexcept { return device_; }
DeviceType device_type() const noexcept { return device_.type(); }
DeviceIndex device_index() const noexcept { return device_.index(); }
StreamId id() const noexcept { return id_; }
Device device() const noexcept {
return device_;
}
DeviceType device_type() const noexcept {
return device_.type();
}
DeviceIndex device_index() const noexcept {
return device_.index();
}
StreamId id() const noexcept {
return id_;
}
// Enqueues a wait instruction in the stream's work queue.
// This instruction is a no-op unless the event is marked
@ -116,10 +122,10 @@ public:
static_assert(sizeof(StreamId) == 4, "DeviceIndex is not 32-bit");
// Concat these together into a 64-bit integer
// See Note [Hazard when concatenating signed integers]
uint64_t bits =
static_cast<uint64_t>(static_cast<uint8_t>(device_type())) << 48
| static_cast<uint64_t>(static_cast<uint8_t>(device_index())) << 32
| static_cast<uint64_t>(static_cast<uint32_t>(id()));
uint64_t bits = static_cast<uint64_t>(static_cast<uint8_t>(device_type()))
<< 48 |
static_cast<uint64_t>(static_cast<uint8_t>(device_index())) << 32 |
static_cast<uint64_t>(static_cast<uint32_t>(id()));
return bits;
}
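A sketch of undoing the packing above, assuming the same layout: bits [48, 56) hold the device type, [32, 40) the device index, and [0, 32) the stream id:

void unpack_sketch(uint64_t bits) {
  auto device_type  = static_cast<uint8_t>((bits >> 48) & 0xff);
  auto device_index = static_cast<uint8_t>((bits >> 32) & 0xff);
  auto stream_id    = static_cast<uint32_t>(bits & 0xffffffff);
  (void)device_type; (void)device_index; (void)stream_id;
}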
@ -145,10 +151,10 @@ C10_API std::ostream& operator<<(std::ostream& stream, const Stream& s);
} // namespace c10
namespace std {
template <>
struct hash<c10::Stream> {
size_t operator()(c10::Stream s) const noexcept {
return std::hash<uint64_t>{}(s.pack());
}
};
template <>
struct hash<c10::Stream> {
size_t operator()(c10::Stream s) const noexcept {
return std::hash<uint64_t>{}(s.pack());
}
};
} // namespace std


@ -47,7 +47,9 @@ struct StreamGuard {
/// WARNING: reset_stream does NOT preserve previously set streams on
/// different devices. If you need to set streams on multiple devices,
/// use MultiStreamGuard instead.
void reset_stream(Stream stream) { guard_.reset_stream(stream); }
void reset_stream(Stream stream) {
guard_.reset_stream(stream);
}
/// Returns the stream that was set at the time the guard was constructed.
Stream original_stream() const {
@ -62,13 +64,17 @@ struct StreamGuard {
/// Returns the most recent device that was set using this device guard,
/// either from construction, or via set_device/reset_device/set_index.
Device current_device() const { return guard_.current_device(); }
Device current_device() const {
return guard_.current_device();
}
/// Returns the device that was set at the most recent reset_stream(),
/// or otherwise the device at construction time.
Device original_device() const { return guard_.original_device(); }
Device original_device() const {
return guard_.original_device();
}
private:
private:
c10::impl::InlineStreamGuard<impl::VirtualGuardImpl> guard_;
};
@ -88,7 +94,8 @@ struct OptionalStreamGuard {
/// Set the current device to the device associated with the passed stream,
/// and set the current stream on that device to the passed stream,
/// if the passed stream is not nullopt.
explicit OptionalStreamGuard(optional<Stream> stream_opt) : guard_(stream_opt) {}
explicit OptionalStreamGuard(optional<Stream> stream_opt)
: guard_(stream_opt) {}
/// Copy is disallowed
OptionalStreamGuard(const OptionalStreamGuard&) = delete;
@ -105,21 +112,30 @@ struct OptionalStreamGuard {
/// set the current device to the device associated with the passed stream,
/// and set the current stream on that device to the passed stream.
/// Initializes the guard if it was not previously initialized.
void reset_stream(Stream stream) { guard_.reset_stream(stream); }
void reset_stream(Stream stream) {
guard_.reset_stream(stream);
}
/// Returns the stream that was set at the time the guard was most recently
/// initialized, or nullopt if the guard is uninitialized.
optional<Stream> original_stream() const { return guard_.original_stream(); }
optional<Stream> original_stream() const {
return guard_.original_stream();
}
/// Returns the most recent stream that was set using this stream guard,
/// either from construction, or via reset_stream, if the guard is initialized,
/// or nullopt if the guard is uninitialized.
optional<Stream> current_stream() const { return guard_.current_stream(); }
/// either from construction, or via reset_stream, if the guard is
/// initialized, or nullopt if the guard is uninitialized.
optional<Stream> current_stream() const {
return guard_.current_stream();
}
/// Restore the original device and stream, resetting this guard to uninitialized state.
void reset() { guard_.reset(); }
/// Restore the original device and stream, resetting this guard to
/// uninitialized state.
void reset() {
guard_.reset();
}
private:
private:
c10::impl::InlineOptionalStreamGuard<impl::VirtualGuardImpl> guard_;
};
@ -142,7 +158,7 @@ struct MultiStreamGuard {
// See Note [Move assignment for RAII guards is tricky]
MultiStreamGuard& operator=(MultiStreamGuard&& other) = delete;
private:
private:
c10::impl::InlineMultiStreamGuard<impl::VirtualGuardImpl> guard_;
};
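A usage sketch of the RAII behavior these guards provide; work_stream is assumed to come from a backend-specific API (e.g. a CUDA stream converted to a c10::Stream):

void run_on(c10::Stream work_stream) {
  c10::StreamGuard guard(work_stream); // switches the current device and stream
  // ... enqueue work on the now-current stream ...
} // original device and stream are restored when the guard is destroyed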


@ -1,10 +1,10 @@
#include <c10/core/TensorImpl.h>
#include <c10/core/Backend.h>
#include <c10/core/InferenceMode.h>
#include <c10/core/WrapDimMinimal.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/util/Optional.h>
#include <c10/core/InferenceMode.h>
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
C10_DEFINE_bool(
@ -21,7 +21,7 @@ C10_DEFINE_int64(
namespace c10 {
const char * const TensorImpl::err_msg_tensor_metadata_change_not_allowed =
const char* const TensorImpl::err_msg_tensor_metadata_change_not_allowed =
"is not allowed on a Tensor created from .data or .detach().\n"
"If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)\n"
"without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.\n"
@ -32,7 +32,8 @@ const char * const TensorImpl::err_msg_tensor_metadata_change_not_allowed =
" x.set_(y)";
at::Tensor& TensorImpl::mutable_grad() {
if (!autograd_meta_) autograd_meta_ = impl::GetAutogradMetaFactory()->make();
if (!autograd_meta_)
autograd_meta_ = impl::GetAutogradMetaFactory()->make();
return autograd_meta_->mutable_grad();
}
@ -43,18 +44,26 @@ const at::Tensor& TensorImpl::grad() const {
// is not so easy to fix right now because the mutable counterpart of
// this function must keep working so that "x.grad() = ..." keeps working
// (part of public API).
if (!autograd_meta_) return impl::GetAutogradMetaFactory()->undefined_tensor();
if (!autograd_meta_)
return impl::GetAutogradMetaFactory()->undefined_tensor();
return autograd_meta_->grad();
}
const at::Tensor& TensorImpl::_fw_grad(uint64_t level, const at::Tensor& self) const {
const at::Tensor& TensorImpl::_fw_grad(uint64_t level, const at::Tensor& self)
const {
// See TensorImpl::grad() above for explanation about the line below
if (!autograd_meta_) return impl::GetAutogradMetaFactory()->undefined_tensor();
if (!autograd_meta_)
return impl::GetAutogradMetaFactory()->undefined_tensor();
return autograd_meta_->fw_grad(level, self);
}
void TensorImpl::_set_fw_grad(const at::Tensor& new_grad, const at::Tensor& self, uint64_t level, bool is_inplace_op) {
if (!autograd_meta_) autograd_meta_ = impl::GetAutogradMetaFactory()->make();
void TensorImpl::_set_fw_grad(
const at::Tensor& new_grad,
const at::Tensor& self,
uint64_t level,
bool is_inplace_op) {
if (!autograd_meta_)
autograd_meta_ = impl::GetAutogradMetaFactory()->make();
autograd_meta_->set_fw_grad(new_grad, self, level, is_inplace_op);
}
@ -63,7 +72,11 @@ TensorImpl::TensorImpl(
DispatchKeySet key_set,
const caffe2::TypeMeta data_type)
// Use std::forward to suppress static analyzer false positive.
: TensorImpl(std::forward<Storage>(storage), key_set, data_type, storage.device()) {}
: TensorImpl(
std::forward<Storage>(storage),
key_set,
data_type,
storage.device()) {}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
TensorImpl::TensorImpl(
@ -84,23 +97,29 @@ TensorImpl::TensorImpl(
}
}
TensorImpl::TensorImpl(DispatchKeySet key_set, const caffe2::TypeMeta data_type, c10::optional<c10::Device> device_opt)
TensorImpl::TensorImpl(
DispatchKeySet key_set,
const caffe2::TypeMeta data_type,
c10::optional<c10::Device> device_opt)
// NOLINTNEXTLINE(performance-move-const-arg)
: TensorImpl({}, key_set, data_type, std::move(device_opt)) {}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2::TypeMeta data_type,
c10::optional<c10::Device> device_opt)
TensorImpl::TensorImpl(
Storage&& storage,
DispatchKeySet key_set,
const caffe2::TypeMeta data_type,
c10::optional<c10::Device> device_opt)
: storage_(std::move(storage)),
storage_offset_(0),
numel_(0),
data_type_(data_type),
device_opt_(device_opt) {
init_bitfields();
if (!key_set.empty()) {
TORCH_INTERNAL_ASSERT(data_type == ScalarType::Undefined || device_opt_.has_value());
TORCH_INTERNAL_ASSERT(
data_type == ScalarType::Undefined || device_opt_.has_value());
// UndefinedTensorImpl is a singleton, so we skip logging it
C10_LOG_API_USAGE_ONCE("tensor.create");
}
@ -115,12 +134,13 @@ TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2::
// Inference tensor doesn't have autograd related keys.
if (inference_mode) {
// See Note [Expected TLS state in InferenceMode] for why we exclude Autograd & InplaceOrView keys.
// Normally key_set only contains backend keys but we do the subtraction
// here to make sure.
// See Note [Expected TLS state in InferenceMode] for why we exclude
// Autograd & InplaceOrView keys. Normally key_set only contains backend
// keys but we do the subtraction here to make sure.
key_set_ = key_set - c10::autograd_dispatch_keyset_with_InplaceOrView;
} else {
// TODO: Ideally we only add AutogradBackend key when the tensor requires grad.
// TODO: Ideally we only add AutogradBackend key when the tensor requires
// grad.
// See Note [Dream: skip VariableType kernel when requires_grad=false]
key_set_ = key_set | getAutogradRelatedKeySetFromBackend(k);
}
@ -130,8 +150,8 @@ TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2::
version_counter_ = VariableVersion(/*version=*/0);
}
// we would also like to check that non-cpu devices have an index, but some Caffe2 operators create
// Storages with default devices.
// we would also like to check that non-cpu devices have an index, but some
// Caffe2 operators create Storages with default devices.
}
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
@ -151,15 +171,14 @@ void TensorImpl::HandleResize() {
if (reserved_) {
// If tensor is reserved then don't claim its memory unless nbytes()
// is smaller than new size
reset_tensor = storage_.nbytes() <
(storage_offset_ + numel_) * data_type_.itemsize();
reset_tensor =
storage_.nbytes() < (storage_offset_ + numel_) * data_type_.itemsize();
} else {
reset_tensor = storage_.nbytes() <
(storage_offset_ + numel_) * data_type_.itemsize() ||
!FLAGS_caffe2_keep_on_shrink ||
storage_.nbytes() -
(storage_offset_ + numel_) * data_type_.itemsize() >
static_cast<size_t>(FLAGS_caffe2_max_keep_on_shrink_memory);
(storage_offset_ + numel_) * data_type_.itemsize() ||
!FLAGS_caffe2_keep_on_shrink ||
storage_.nbytes() - (storage_offset_ + numel_) * data_type_.itemsize() >
static_cast<size_t>(FLAGS_caffe2_max_keep_on_shrink_memory);
}
if (reset_tensor && storage_initialized()) {
@ -190,20 +209,19 @@ bool TensorImpl::compute_channels_last_contiguous_2d() const {
// Please don't combine these code, constant array is used here to let
// compiler fully unroll the loop to get better performance
switch (sizes_and_strides_.size()) {
case 4:
{
int64_t expected = 1;
for (auto& d : {1, 3, 2, 0}) {
const auto size_d = sizes_and_strides_.size_at_unchecked(d);
if (size_d != 1) {
if (sizes_and_strides_.stride_at_unchecked(d) != expected) {
return false;
}
expected *= size_d;
case 4: {
int64_t expected = 1;
for (auto& d : {1, 3, 2, 0}) {
const auto size_d = sizes_and_strides_.size_at_unchecked(d);
if (size_d != 1) {
if (sizes_and_strides_.stride_at_unchecked(d) != expected) {
return false;
}
expected *= size_d;
}
return true;
}
return true;
}
// NOLINTNEXTLINE(bugprone-branch-clone)
case 3:
// TODO dim == 3 case will be enabled once it is fully tested
@ -218,20 +236,19 @@ bool TensorImpl::compute_channels_last_contiguous_3d() const {
// compiler fully unroll the loop to get better performance
switch (sizes_and_strides_.size()) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
case 5:
{
int64_t expected = 1;
for (auto& d : {1, 4, 3, 2, 0}) {
const auto size_d = sizes_and_strides_.size_at_unchecked(d);
if (size_d != 1) {
if (sizes_and_strides_.stride_at_unchecked(d) != expected) {
return false;
}
expected *= size_d;
case 5: {
int64_t expected = 1;
for (auto& d : {1, 4, 3, 2, 0}) {
const auto size_d = sizes_and_strides_.size_at_unchecked(d);
if (size_d != 1) {
if (sizes_and_strides_.stride_at_unchecked(d) != expected) {
return false;
}
expected *= size_d;
}
return true;
}
return true;
}
// NOLINTNEXTLINE(bugprone-branch-clone)
case 4:
// TODO dim == 4 case will be enabled once it is fully tested
@ -242,34 +259,38 @@ bool TensorImpl::compute_channels_last_contiguous_3d() const {
}
bool TensorImpl::compute_strides_like_channels_last_2d() const {
return is_channels_last_strides_2d(TensorImpl::sizes(), TensorImpl::strides());
return is_channels_last_strides_2d(
TensorImpl::sizes(), TensorImpl::strides());
}
bool TensorImpl::compute_strides_like_channels_last_3d() const {
return is_channels_last_strides_3d(TensorImpl::sizes(), TensorImpl::strides());
return is_channels_last_strides_3d(
TensorImpl::sizes(), TensorImpl::strides());
}
bool TensorImpl::compute_non_overlapping_and_dense() const {
if (dim() == 1) {
return sizes_and_strides_.size_at_unchecked(0) < 2 || sizes_and_strides_.stride_at_unchecked(0) == 1;
return sizes_and_strides_.size_at_unchecked(0) < 2 ||
sizes_and_strides_.stride_at_unchecked(0) == 1;
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
SmallVector<int64_t,5> perm;
SmallVector<int64_t, 5> perm;
perm.resize(dim());
for (int64_t i = 0; i < dim(); i ++) {
for (int64_t i = 0; i < dim(); i++) {
perm[i] = i;
}
// Sort by strides, leaving 0 and 1 sized dims at the end of the array
std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) {
if (sizes_and_strides_.size_at_unchecked(a) < 2) {
return false;
} else if (sizes_and_strides_.size_at_unchecked(b) < 2) {
return true;
}
return sizes_and_strides_.stride_at_unchecked(a) < sizes_and_strides_.stride_at_unchecked(b);
if (sizes_and_strides_.size_at_unchecked(a) < 2) {
return false;
} else if (sizes_and_strides_.size_at_unchecked(b) < 2) {
return true;
}
return sizes_and_strides_.stride_at_unchecked(a) <
sizes_and_strides_.stride_at_unchecked(b);
});
auto require_stride = 1;
for (int64_t i = 0; i < dim(); i ++) {
for (int64_t i = 0; i < dim(); i++) {
const auto size_perm_i = sizes_and_strides_.size_at_unchecked(perm[i]);
if (size_perm_i < 2) {
return true;
@ -312,16 +333,23 @@ bool TensorImpl::has_storage() const {
#endif
void TensorImpl::throw_storage_access_error() const {
TORCH_CHECK_NOT_IMPLEMENTED(false, "Cannot access storage of ", tensorimpl_type_name());
TORCH_CHECK_NOT_IMPLEMENTED(
false, "Cannot access storage of ", tensorimpl_type_name());
}
bool TensorImpl::is_contiguous_nondefault_policy_impl(at::MemoryFormat memory_format) const {
if (has_contiguity_ == static_cast<uint8_t>(HasContiguityPolicy::ContiguityNotSupported)) {
bool TensorImpl::is_contiguous_nondefault_policy_impl(
at::MemoryFormat memory_format) const {
if (has_contiguity_ ==
static_cast<uint8_t>(HasContiguityPolicy::ContiguityNotSupported)) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "Tensors of type ", tensorimpl_type_name(),
false,
"Tensors of type ",
tensorimpl_type_name(),
" do not have is_contiguous");
} else {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(has_contiguity_ == static_cast<uint8_t>(HasContiguityPolicy::CustomBehavior));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
has_contiguity_ ==
static_cast<uint8_t>(HasContiguityPolicy::CustomBehavior));
return is_contiguous_custom(memory_format);
}
}
@ -343,10 +371,11 @@ at::DataPtr PlacementDeleteContext::makeDataPtr(
size_t size,
at::Device device) {
auto* ptr = data_ptr.get();
return {ptr,
new PlacementDeleteContext(std::move(data_ptr), placement_dtor, size),
&deletePlacementDeleteContext,
device};
return {
ptr,
new PlacementDeleteContext(std::move(data_ptr), placement_dtor, size),
&deletePlacementDeleteContext,
device};
}
// NOLINTNEXTLINE(modernize-use-equals-default)
@ -359,10 +388,14 @@ AutogradMetaInterface::~AutogradMetaInterface() {}
// used in C++ frontend. Forbidding it inside InferenceMode will force users
// to delete these setter code in their code which is not ideal.
void TensorImpl::set_requires_grad(bool requires_grad) {
TORCH_CHECK(!(requires_grad && is_inference_tensor() && !c10::InferenceMode::is_enabled()),
"Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.");
if (!requires_grad && !autograd_meta_) return;
if (!autograd_meta_) autograd_meta_ = impl::GetAutogradMetaFactory()->make();
TORCH_CHECK(
!(requires_grad && is_inference_tensor() &&
!c10::InferenceMode::is_enabled()),
"Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.");
if (!requires_grad && !autograd_meta_)
return;
if (!autograd_meta_)
autograd_meta_ = impl::GetAutogradMetaFactory()->make();
// NB: In principle, setting requires_grad to false could result in
// the AutogradMeta becoming equal to a default constructed state,
// in which case we could apply the nullptr AutogradMeta optimization
@ -376,11 +409,13 @@ void TensorImpl::set_requires_grad(bool requires_grad) {
}
bool TensorImpl::requires_grad() const {
if (!autograd_meta_) return false;
if (!autograd_meta_)
return false;
return autograd_meta_->requires_grad();
}
void TensorImpl::set_autograd_meta(std::unique_ptr<c10::AutogradMetaInterface> autograd_meta) {
void TensorImpl::set_autograd_meta(
std::unique_ptr<c10::AutogradMetaInterface> autograd_meta) {
// NB: autograd_meta may be null! That just means it's the default
// constructor
autograd_meta_ = std::move(autograd_meta);
@ -396,7 +431,9 @@ c10::intrusive_ptr<TensorImpl> TensorImpl::shallow_copy_and_detach(
bool allow_tensor_metadata_change) const {
auto impl = c10::make_intrusive<TensorImpl>(
// No need to populate Storage; copy_tensor_metadata will do it for us.
key_set_, data_type_, device_opt_);
key_set_,
data_type_,
device_opt_);
copy_tensor_metadata(
/*src_impl=*/this,
/*dest_impl=*/impl.get(),
@ -412,7 +449,9 @@ c10::intrusive_ptr<TensorImpl> TensorImpl::shallow_copy_and_detach(
bool allow_tensor_metadata_change) const {
auto impl = c10::make_intrusive<TensorImpl>(
// No need to populate Storage; copy_tensor_metadata will do it for us.
key_set_, data_type_, device_opt_);
key_set_,
data_type_,
device_opt_);
copy_tensor_metadata(
/*src_impl=*/this,
/*dest_impl=*/impl.get(),
@ -435,15 +474,19 @@ void TensorImpl::copy_tensor_metadata_except_version_counter(
dest_impl->key_set_ = src_impl->key_set_;
dest_impl->is_contiguous_ = src_impl->is_contiguous_;
dest_impl->has_contiguity_ = src_impl->has_contiguity_;
dest_impl->is_channels_last_contiguous_ = src_impl->is_channels_last_contiguous_;
dest_impl->is_channels_last_3d_contiguous_ = src_impl->is_channels_last_3d_contiguous_;
dest_impl->is_channels_last_contiguous_ =
src_impl->is_channels_last_contiguous_;
dest_impl->is_channels_last_3d_contiguous_ =
src_impl->is_channels_last_3d_contiguous_;
dest_impl->is_channels_last_ = src_impl->is_channels_last_;
dest_impl->is_channels_last_3d_ = src_impl->is_channels_last_3d_;
dest_impl->is_non_overlapping_and_dense_ = src_impl->is_non_overlapping_and_dense_;
dest_impl->is_non_overlapping_and_dense_ =
src_impl->is_non_overlapping_and_dense_;
dest_impl->is_wrapped_number_ = src_impl->is_wrapped_number_;
dest_impl->reserved_ = src_impl->reserved_;
dest_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
dest_impl->storage_access_should_throw_ = src_impl->storage_access_should_throw_;
dest_impl->storage_access_should_throw_ =
src_impl->storage_access_should_throw_;
if (src_impl->named_tensor_meta_ != nullptr) {
dest_impl->named_tensor_meta_ = src_impl->named_tensor_meta_->clone();
}
@ -454,9 +497,11 @@ void TensorImpl::copy_tensor_metadata(
TensorImpl* dest_impl,
const c10::VariableVersion& version_counter,
bool allow_tensor_metadata_change) {
copy_tensor_metadata_except_version_counter(src_impl, dest_impl, allow_tensor_metadata_change);
copy_tensor_metadata_except_version_counter(
src_impl, dest_impl, allow_tensor_metadata_change);
// TODO: In the ideal end state, it's okay to set disabled version_counter
// on inference tensor since it's a no-op. This requires refactor on call sites.
// on inference tensor since it's a no-op. This requires refactor on call
// sites.
if (!dest_impl->is_inference_tensor()) {
dest_impl->set_version_counter(version_counter);
}
@ -467,7 +512,8 @@ void TensorImpl::copy_tensor_metadata(
TensorImpl* dest_impl,
c10::VariableVersion&& version_counter,
bool allow_tensor_metadata_change) {
copy_tensor_metadata_except_version_counter(src_impl, dest_impl, allow_tensor_metadata_change);
copy_tensor_metadata_except_version_counter(
src_impl, dest_impl, allow_tensor_metadata_change);
if (!dest_impl->is_inference_tensor()) {
dest_impl->set_version_counter(std::move(version_counter));
}
@ -478,13 +524,15 @@ namespace impl {
namespace {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
AutogradMetaFactory* meta_factory = nullptr;
}
} // namespace
void SetAutogradMetaFactory(AutogradMetaFactory* factory) {
meta_factory = factory;
}
AutogradMetaFactory* GetAutogradMetaFactory() {
TORCH_CHECK(meta_factory, "Support for autograd has not been loaded; have you linked against libtorch.so?")
TORCH_CHECK(
meta_factory,
"Support for autograd has not been loaded; have you linked against libtorch.so?")
return meta_factory;
}
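A standalone sketch of the "non-overlapping and dense" check performed by compute_non_overlapping_and_dense() earlier in this file: sort dimensions by stride, then require each stride to equal the product of the sizes already visited (is_dense_sketch is illustrative, not c10 API):

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

bool is_dense_sketch(std::vector<int64_t> sizes, std::vector<int64_t> strides) {
  std::vector<size_t> perm(sizes.size());
  std::iota(perm.begin(), perm.end(), 0);
  // Sort dims by stride, pushing size-0/1 dims to the end, as in the original.
  std::sort(perm.begin(), perm.end(), [&](size_t a, size_t b) {
    if (sizes[a] < 2) return false;
    if (sizes[b] < 2) return true;
    return strides[a] < strides[b];
  });
  int64_t required = 1;
  for (auto d : perm) {
    if (sizes[d] < 2) return true;
    if (strides[d] != required) return false;
    required *= sizes[d];
  }
  return true;
}
// is_dense_sketch({2, 3}, {3, 1}) -> true  (row-major)
// is_dense_sketch({2, 3}, {1, 2}) -> true  (transposed, still dense)
// is_dense_sketch({2, 3}, {4, 1}) -> false (padding between rows)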

File diff suppressed because it is too large


@ -17,18 +17,20 @@ namespace c10 {
// internal state and what its getters will return.
std::ostream& operator<<(std::ostream& stream, const TensorOptions& options) {
auto print = [&](const char *label, auto prop, bool has_prop) {
auto print = [&](const char* label, auto prop, bool has_prop) {
stream << label << std::boolalpha << prop << (has_prop ? "" : " (default)");
};
print("TensorOptions(dtype=", options.dtype(), options.has_dtype());
print(", device=", options.device(), options.has_device());
print(", layout=", options.layout(), options.has_layout());
print(", requires_grad=", options.requires_grad(), options.has_requires_grad());
print(", pinned_memory=", options.pinned_memory(), options.has_pinned_memory());
print(
", requires_grad=", options.requires_grad(), options.has_requires_grad());
print(
", pinned_memory=", options.pinned_memory(), options.has_pinned_memory());
// note: default-supplying memory_format() getter not provided; no canonical default
// note: default-supplying memory_format() getter not provided; no canonical
// default
stream << ", memory_format=";
if (options.has_memory_format()) {
stream << *options.memory_format_opt();


@ -1,17 +1,17 @@
#pragma once
#include <c10/core/DefaultDtype.h>
#include <c10/core/Backend.h>
#include <c10/core/DefaultDtype.h>
#include <c10/core/Device.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/core/Layout.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/ScalarType.h>
#include <c10/core/ScalarTypeToTypeMeta.h>
#include <c10/core/Device.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/util/Optional.h>
#include <c10/util/C++17.h>
#include <c10/macros/Macros.h>
#include <c10/util/C++17.h>
#include <c10/util/Optional.h>
#include <cstddef>
#include <iosfwd>
@ -19,14 +19,18 @@
namespace c10 {
DispatchKey computeDispatchKey(c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device);
DispatchKey computeDispatchKey(
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device);
inline ScalarType dtype_or_default(c10::optional<ScalarType> dtype) {
return value_or_else(dtype, [] {return get_default_dtype_as_scalartype();});
return value_or_else(dtype, [] { return get_default_dtype_as_scalartype(); });
}
inline caffe2::TypeMeta dtype_or_default(c10::optional<caffe2::TypeMeta> dtype) {
return value_or_else(dtype, [] {return get_default_dtype();});
inline caffe2::TypeMeta dtype_or_default(
c10::optional<caffe2::TypeMeta> dtype) {
return value_or_else(dtype, [] { return get_default_dtype(); });
}
inline Layout layout_or_default(c10::optional<Layout> layout) {
@ -34,7 +38,7 @@ inline Layout layout_or_default(c10::optional<Layout> layout) {
}
inline Device device_or_default(c10::optional<Device> device) {
return value_or_else(device, [] {return Device(kCPU);});
return value_or_else(device, [] { return Device(kCPU); });
}
inline bool pinned_memory_or_default(c10::optional<bool> pinned_memory) {
@ -65,7 +69,8 @@ inline bool pinned_memory_or_default(c10::optional<bool> pinned_memory) {
/// at::dtype(at::kInt)
///
/// Additionally, anywhere a TensorOptions is expected, you can directly
/// pass at::kCUDA / at::kInt, and it will implicitly convert to a TensorOptions.
/// pass at::kCUDA / at::kInt, and it will implicitly convert to a
/// TensorOptions.
///
/// Here are some recommended ways to create a 2x2 tensor of zeros
/// with certain properties. These all *implicitly* make use of
@ -108,7 +113,8 @@ inline bool pinned_memory_or_default(c10::optional<bool> pinned_memory) {
/// }
///
/// template <typename... Args,
/// typename = std::enable_if_t<std::is_constructible<Device, Args&&...>::value>>
/// typename = std::enable_if_t<std::is_constructible<Device,
/// Args&&...>::value>>
/// /* implicit */ TensorOptions(Args&&... args)
/// : TensorOptions(Device(std::forward<Args>(args)...)) {}
///
@ -121,20 +127,21 @@ inline bool pinned_memory_or_default(c10::optional<bool> pinned_memory) {
/// To get around this, we templatize the `Device` constructor. Since overload
/// resolution is done before template resolution, our problem is solved.
DispatchKey computeDispatchKey(optional<ScalarType> dtype, optional<Layout> layout, optional<Device> device);
DispatchKey computeDispatchKey(
optional<ScalarType> dtype,
optional<Layout> layout,
optional<Device> device);
struct C10_API TensorOptions {
TensorOptions()
: requires_grad_(false)
, pinned_memory_(false)
, has_device_(false)
, has_dtype_(false)
, has_layout_(false)
, has_requires_grad_(false)
, has_pinned_memory_(false)
, has_memory_format_(false)
{}
: requires_grad_(false),
pinned_memory_(false),
has_device_(false),
has_dtype_(false),
has_layout_(false),
has_requires_grad_(false),
has_pinned_memory_(false),
has_memory_format_(false) {}
/// Constructs a `TensorOptions` object with the given layout.
/* implicit */ TensorOptions(Layout layout) : TensorOptions() {
@ -143,8 +150,9 @@ struct C10_API TensorOptions {
/// Constructs a `TensorOptions` object with the given device.
/// See NOTE [ TensorOptions Constructors ] on why this is templatized.
template<typename T,
typename = std::enable_if_t<std::is_same<std::decay_t<T>, Device>::value>>
template <
typename T,
typename = std::enable_if_t<std::is_same<std::decay_t<T>, Device>::value>>
/* implicit */ TensorOptions(T&& device) : TensorOptions() {
this->set_device(std::forward<T>(device));
}
@ -157,10 +165,12 @@ struct C10_API TensorOptions {
/// NB: Ideally we only allow implicit constructors here. But there is no easy
/// way to detect them. So we have this one that allows explicit
/// constructors too.
template <typename... Args,
typename = std::enable_if_t<std::is_constructible<Device, Args&&...>::value>>
/* implicit */ TensorOptions(Args&&... args)
: TensorOptions(Device(std::forward<Args>(args)...)) {}
template <
typename... Args,
typename =
std::enable_if_t<std::is_constructible<Device, Args&&...>::value>>
/* implicit */ TensorOptions(Args&&... args)
: TensorOptions(Device(std::forward<Args>(args)...)) {}
/// Constructs a `TensorOptions` object with the given dtype.
/* implicit */ TensorOptions(caffe2::TypeMeta dtype) : TensorOptions() {
@ -179,7 +189,8 @@ struct C10_API TensorOptions {
/// Return a copy of `TensorOptions` with `device` set to the given one, or
/// cleared if `device` is `nullopt`.
C10_NODISCARD TensorOptions device(c10::optional<Device> device) const noexcept {
C10_NODISCARD TensorOptions
device(c10::optional<Device> device) const noexcept {
TensorOptions r = *this;
r.set_device(device);
return r;
@ -188,9 +199,10 @@ struct C10_API TensorOptions {
/// Return a copy of `TensorOptions` with `device` set to the given one.
/// (This overload ensures that variadic template c10::optional constructor
/// for Device work correctly.)
template<typename ... Args>
template <typename... Args>
C10_NODISCARD TensorOptions device(Args&&... args) const noexcept {
return device(c10::optional<Device>(c10::in_place, std::forward<Args>(args)...));
return device(
c10::optional<Device>(c10::in_place, std::forward<Args>(args)...));
}
/// Return a copy of `TensorOptions`, but with device set to CUDA, and the
@ -198,19 +210,22 @@ struct C10_API TensorOptions {
///
/// TODO: This function encourages bad behavior (assuming CUDA is
/// the only device that matters). Get rid of it / rename it.
C10_NODISCARD TensorOptions device_index(int16_t device_index) const noexcept {
C10_NODISCARD TensorOptions
device_index(int16_t device_index) const noexcept {
return device(Device::Type::CUDA, device_index);
}
/// Return a copy of `TensorOptions` with `dtype` set to the given one.
C10_NODISCARD TensorOptions dtype(c10::optional<caffe2::TypeMeta> dtype) const noexcept {
C10_NODISCARD TensorOptions
dtype(c10::optional<caffe2::TypeMeta> dtype) const noexcept {
TensorOptions r = *this;
r.set_dtype(dtype);
return r;
}
// legacy function to support ScalarType
C10_NODISCARD TensorOptions dtype(c10::optional<ScalarType> dtype) const noexcept {
C10_NODISCARD TensorOptions
dtype(c10::optional<ScalarType> dtype) const noexcept {
TensorOptions r = *this;
r.set_dtype(dtype);
return r;
@ -225,28 +240,32 @@ struct C10_API TensorOptions {
}
/// Sets the layout of the `TensorOptions`.
C10_NODISCARD TensorOptions layout(c10::optional<Layout> layout) const noexcept {
C10_NODISCARD TensorOptions
layout(c10::optional<Layout> layout) const noexcept {
TensorOptions r = *this;
r.set_layout(layout);
return r;
}
/// Sets the `requires_grad` property of the `TensorOptions`.
C10_NODISCARD TensorOptions requires_grad(c10::optional<bool> requires_grad) const noexcept {
C10_NODISCARD TensorOptions
requires_grad(c10::optional<bool> requires_grad) const noexcept {
TensorOptions r = *this;
r.set_requires_grad(requires_grad);
return r;
}
/// Sets the `pinned_memory` property on the `TensorOptions`.
C10_NODISCARD TensorOptions pinned_memory(c10::optional<bool> pinned_memory) const noexcept {
C10_NODISCARD TensorOptions
pinned_memory(c10::optional<bool> pinned_memory) const noexcept {
TensorOptions r = *this;
r.set_pinned_memory(pinned_memory);
return r;
}
/// Sets the `memory_format` property on `TensorOptions`.
C10_NODISCARD TensorOptions memory_format(c10::optional<MemoryFormat> memory_format) const noexcept {
C10_NODISCARD TensorOptions
memory_format(c10::optional<MemoryFormat> memory_format) const noexcept {
TensorOptions r = *this;
r.set_memory_format(memory_format);
return r;
@ -343,13 +362,15 @@ struct C10_API TensorOptions {
// For compatibility with legacy tensor.type() comparisons
bool type_equal(const TensorOptions& other) const {
return computeDispatchKey() == other.computeDispatchKey() && typeMetaToScalarType(dtype_) == typeMetaToScalarType(other.dtype());
return computeDispatchKey() == other.computeDispatchKey() &&
typeMetaToScalarType(dtype_) == typeMetaToScalarType(other.dtype());
}
/// Returns the `pinned_memory` property of the `TensorOptions`, or
/// `c10::nullopt` if `pinned_memory` is not specified.
c10::optional<bool> pinned_memory_opt() const noexcept {
return has_pinned_memory_ ? c10::make_optional(pinned_memory_) : c10::nullopt;
return has_pinned_memory_ ? c10::make_optional(pinned_memory_)
: c10::nullopt;
}
/// Returns whether the `memory_layout` is specified
@ -363,7 +384,8 @@ struct C10_API TensorOptions {
/// Returns the `memory_layout` property of `TensorOptions, or
/// `c10::nullopt` if `memory_format` is not specified.
c10::optional<MemoryFormat> memory_format_opt() const noexcept {
return has_memory_format_ ? c10::make_optional(memory_format_) : c10::nullopt;
return has_memory_format_ ? c10::make_optional(memory_format_)
: c10::nullopt;
}
// Resolves the ATen backend specified by the current construction axes.
@ -384,18 +406,25 @@ struct C10_API TensorOptions {
///
TensorOptions merge_in(TensorOptions options) const noexcept {
TensorOptions merged = *this;
if (options.has_device()) merged.set_device(options.device_opt());
if (options.has_dtype()) merged.set_dtype(options.dtype_opt());
if (options.has_layout()) merged.set_layout(options.layout_opt());
if (options.has_device())
merged.set_device(options.device_opt());
if (options.has_dtype())
merged.set_dtype(options.dtype_opt());
if (options.has_layout())
merged.set_layout(options.layout_opt());
// NB: requires grad is right biased; not a logical AND/OR!
if (options.has_requires_grad()) merged.set_requires_grad(options.requires_grad_opt());
if (options.has_pinned_memory()) merged.set_pinned_memory(options.pinned_memory_opt());
if (options.has_memory_format()) merged.set_memory_format(options.memory_format_opt());
if (options.has_requires_grad())
merged.set_requires_grad(options.requires_grad_opt());
if (options.has_pinned_memory())
merged.set_pinned_memory(options.pinned_memory_opt());
if (options.has_memory_format())
merged.set_memory_format(options.memory_format_opt());
return merged;
}
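A brief usage sketch of merge_in: fields the overlay leaves unspecified keep the base's values and flags:

void merge_sketch() {
  auto base = c10::TensorOptions().dtype(c10::kFloat).device(c10::kCPU);
  auto overlay = c10::TensorOptions().requires_grad(true);
  auto merged = base.merge_in(overlay);
  // merged: dtype=float, device=cpu, requires_grad=true; layout, pinned_memory
  // and memory_format remain unset.
  (void)merged;
}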
// TODO remove after TensorOptions rationalization
TensorOptions merge_memory_format(c10::optional<MemoryFormat> optional_memory_format) const noexcept {
TensorOptions merge_memory_format(
c10::optional<MemoryFormat> optional_memory_format) const noexcept {
TensorOptions merged = *this;
if (optional_memory_format.has_value()) {
merged.set_memory_format(*optional_memory_format);
@ -408,11 +437,11 @@ struct C10_API TensorOptions {
// the most part, this just means that this function never returns an
// Autograd key)
DispatchKey computeDispatchKey() const {
return c10::computeDispatchKey(optTypeMetaToScalarType(dtype_opt()), layout_opt(), device_opt());
return c10::computeDispatchKey(
optTypeMetaToScalarType(dtype_opt()), layout_opt(), device_opt());
}
private:
// These methods are currently private because I'm not sure if it's wise
// to actually publish them. They are methods because I need them in
// the constructor and the functional API implementation.
@ -515,13 +544,12 @@ struct C10_API TensorOptions {
// Bitmask required here to get this to fit inside 32 bits (or even 64 bits,
// for that matter)
bool requires_grad_ : 1;
bool pinned_memory_ : 1;
bool requires_grad_ : 1;
bool pinned_memory_ : 1;
bool has_device_ : 1;
bool has_dtype_ : 1;
bool has_layout_ : 1;
bool has_device_ : 1;
bool has_dtype_ : 1;
bool has_layout_ : 1;
bool has_requires_grad_ : 1;
bool has_pinned_memory_ : 1;
bool has_memory_format_ : 1;
@ -530,8 +558,9 @@ struct C10_API TensorOptions {
// We should aspire to fit in one machine-size word; but a size greater than two
// words is too much. (We are doing terribly on 32-bit archs, where we require
// three machine size words to store tensor options. Eek!)
static_assert( sizeof(TensorOptions) <= sizeof(int64_t) * 2,
"TensorOptions must fit in 128-bits" );
static_assert(
sizeof(TensorOptions) <= sizeof(int64_t) * 2,
"TensorOptions must fit in 128-bits");
/// Convenience function that returns a `TensorOptions` object with the `dtype`
/// set to the given one.
@ -591,88 +620,106 @@ inline std::string toString(const TensorOptions options) {
// This is intended to be a centralized location by which we can determine
// what an appropriate DispatchKey for a tensor is.
inline DispatchKey computeDispatchKey(c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device) {
inline DispatchKey computeDispatchKey(
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device) {
const auto layout_ = layout_or_default(layout);
const auto device_ = device_or_default(device);
switch (layout_) {
case Layout::Strided: {
const auto dtype_ = dtype_or_default(dtype);
switch (device_.type()) {
case DeviceType::CPU: {
if (isQIntType(dtype_)) {
return DispatchKey::QuantizedCPU;
}
return DispatchKey::CPU;
case Layout::Strided: {
const auto dtype_ = dtype_or_default(dtype);
switch (device_.type()) {
case DeviceType::CPU: {
if (isQIntType(dtype_)) {
return DispatchKey::QuantizedCPU;
}
case DeviceType::CUDA: {
if (isQIntType(dtype_)) {
return DispatchKey::QuantizedCUDA;
}
return DispatchKey::CUDA;
}
case DeviceType::XPU: {
if (isQIntType(dtype_)) {
return DispatchKey::QuantizedXPU;
}
return DispatchKey::XPU;
}
case DeviceType::MKLDNN:
case DeviceType::OPENGL:
case DeviceType::OPENCL:
case DeviceType::IDEEP:
TORCH_INTERNAL_ASSERT(0, "This is a grandfathered Caffe2 device type ", device_.type(), ", it shouldn't ever convert to a DispatchKey. File a bug describing what you were doing if you think this is in error.");
case DeviceType::HIP:
return DispatchKey::HIP;
case DeviceType::FPGA:
return DispatchKey::FPGA;
case DeviceType::MSNPU:
return DispatchKey::MSNPU;
case DeviceType::XLA:
return DispatchKey::XLA;
case DeviceType::MLC:
return DispatchKey::MLC;
case DeviceType::Vulkan:
return DispatchKey::Vulkan;
case DeviceType::Metal:
return DispatchKey::Metal;
case DeviceType::Meta:
return DispatchKey::Meta;
default:
TORCH_CHECK_NOT_IMPLEMENTED(false, "Unsupported device type for dense layout: ", device_.type());
return DispatchKey::CPU;
}
case DeviceType::CUDA: {
if (isQIntType(dtype_)) {
return DispatchKey::QuantizedCUDA;
}
return DispatchKey::CUDA;
}
case DeviceType::XPU: {
if (isQIntType(dtype_)) {
return DispatchKey::QuantizedXPU;
}
return DispatchKey::XPU;
}
case DeviceType::MKLDNN:
case DeviceType::OPENGL:
case DeviceType::OPENCL:
case DeviceType::IDEEP:
TORCH_INTERNAL_ASSERT(
0,
"This is a grandfathered Caffe2 device type ",
device_.type(),
", it shouldn't ever convert to a DispatchKey. File a bug describing what you were doing if you think this is in error.");
case DeviceType::HIP:
return DispatchKey::HIP;
case DeviceType::FPGA:
return DispatchKey::FPGA;
case DeviceType::MSNPU:
return DispatchKey::MSNPU;
case DeviceType::XLA:
return DispatchKey::XLA;
case DeviceType::MLC:
return DispatchKey::MLC;
case DeviceType::Vulkan:
return DispatchKey::Vulkan;
case DeviceType::Metal:
return DispatchKey::Metal;
case DeviceType::Meta:
return DispatchKey::Meta;
default:
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"Unsupported device type for dense layout: ",
device_.type());
}
case Layout::Sparse:
switch (device_.type()) {
case DeviceType::CPU:
return DispatchKey::SparseCPU;
case DeviceType::CUDA:
return DispatchKey::SparseCUDA;
case DeviceType::HIP:
return DispatchKey::SparseHIP;
case DeviceType::XPU:
return DispatchKey::SparseXPU;
default:
TORCH_CHECK_NOT_IMPLEMENTED(false, "Unsupported device type for sparse layout: ", device_.type());
}
case Layout::Mkldnn:
switch (device_.type()) {
case DeviceType::CPU:
return DispatchKey::MkldnnCPU;
default:
TORCH_CHECK_NOT_IMPLEMENTED(false, "Unsupported device type for mkldnn layout: ", device_.type());
}
case Layout::SparseCsr:
switch(device_.type()) {
case DeviceType::CPU:
return DispatchKey::SparseCsrCPU;
case DeviceType::CUDA:
return DispatchKey::SparseCsrCUDA;
default:
AT_ERROR("Unsupported device type for sparse CSR layout: ", device_.type());
}
default:
TORCH_CHECK(false, "Unsupported layout: ", layout_);
}
case Layout::Sparse:
switch (device_.type()) {
case DeviceType::CPU:
return DispatchKey::SparseCPU;
case DeviceType::CUDA:
return DispatchKey::SparseCUDA;
case DeviceType::HIP:
return DispatchKey::SparseHIP;
case DeviceType::XPU:
return DispatchKey::SparseXPU;
default:
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"Unsupported device type for sparse layout: ",
device_.type());
}
case Layout::Mkldnn:
switch (device_.type()) {
case DeviceType::CPU:
return DispatchKey::MkldnnCPU;
default:
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"Unsupported device type for mkldnn layout: ",
device_.type());
}
case Layout::SparseCsr:
switch (device_.type()) {
case DeviceType::CPU:
return DispatchKey::SparseCsrCPU;
case DeviceType::CUDA:
return DispatchKey::SparseCsrCUDA;
default:
AT_ERROR(
"Unsupported device type for sparse CSR layout: ",
device_.type());
}
default:
TORCH_CHECK(false, "Unsupported layout: ", layout_);
}
}
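
A hedged usage sketch of the mapping above (assumes this header is reachable as c10/core/TensorOptions.h and the build links against c10):

#include <c10/core/TensorOptions.h>
#include <c10/util/Exception.h>

void compute_dispatch_key_example() {
  using namespace c10;
  // A dense float tensor on CPU falls through the Strided/CPU branch.
  auto dense = computeDispatchKey(
      ScalarType::Float, Layout::Strided, Device(DeviceType::CPU));
  // A quantized dtype on the same device takes the QuantizedCPU branch.
  auto quant = computeDispatchKey(
      ScalarType::QInt8, Layout::Strided, Device(DeviceType::CPU));
  TORCH_INTERNAL_ASSERT(dense == DispatchKey::CPU);
  TORCH_INTERNAL_ASSERT(quant == DispatchKey::QuantizedCPU);
}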
inline Layout dispatchKeyToLayout(DispatchKey dispatch_key) {
@ -692,7 +739,7 @@ inline Layout dispatchKeyToLayout(DispatchKey dispatch_key) {
}
inline DeviceType dispatchKeyToDeviceType(DispatchKey dispatch_key) {
switch(dispatch_key) {
switch (dispatch_key) {
// stuff that's real
case DispatchKey::CPU:
case DispatchKey::SparseCPU:
@ -730,14 +777,18 @@ inline DeviceType dispatchKeyToDeviceType(DispatchKey dispatch_key) {
case DispatchKey::MSNPU:
return DeviceType::MSNPU;
default:
TORCH_CHECK(false, "DispatchKey ", dispatch_key, " doesn't correspond to a device");
TORCH_CHECK(
false,
"DispatchKey ",
dispatch_key,
" doesn't correspond to a device");
}
}
inline TensorOptions dispatchKeyToTensorOptions(DispatchKey dispatch_key) {
return TensorOptions()
.layout(dispatchKeyToLayout(dispatch_key))
.device(dispatchKeyToDeviceType(dispatch_key));
.layout(dispatchKeyToLayout(dispatch_key))
.device(dispatchKeyToDeviceType(dispatch_key));
}
} // namespace c10

View File

@ -5,7 +5,7 @@ namespace c10 {
// should this use the globalContext? Can it get a context passed in somehow?
UndefinedTensorImpl::UndefinedTensorImpl()
: TensorImpl(DispatchKey::Undefined, caffe2::TypeMeta(), c10::nullopt) {
: TensorImpl(DispatchKey::Undefined, caffe2::TypeMeta(), c10::nullopt) {
set_storage_access_should_throw();
}
@ -19,7 +19,8 @@ int64_t UndefinedTensorImpl::stride(int64_t d) const {
#ifdef DEBUG
bool UndefinedTensorImpl::has_storage() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!storage_, "UndefinedTensorImpl assumes that storage_ is never set");
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!storage_, "UndefinedTensorImpl assumes that storage_ is never set");
return false;
}
#endif
@ -39,4 +40,4 @@ const char* UndefinedTensorImpl::tensorimpl_type_name() const {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
UndefinedTensorImpl UndefinedTensorImpl::_singleton;
}
} // namespace c10

View File

@ -7,13 +7,14 @@ namespace c10 {
struct C10_API UndefinedTensorImpl final : public TensorImpl {
public:
// Without this, we get:
// error: identifier "at::UndefinedTensorImpl::_singleton" is undefined in device code
// error: identifier "at::UndefinedTensorImpl::_singleton" is undefined in
// device code
// (ostensibly because the constexpr tricks MSVC into trying to compile this
// function for device as well).
#ifdef _WIN32
static inline TensorImpl * singleton() {
static inline TensorImpl* singleton() {
#else
static constexpr inline TensorImpl * singleton() {
static constexpr inline TensorImpl* singleton() {
#endif
return &_singleton;
}
@ -24,7 +25,8 @@ struct C10_API UndefinedTensorImpl final : public TensorImpl {
bool has_storage() const override;
#endif
void set_storage_offset(int64_t offset) override;
private:
private:
UndefinedTensorImpl();
static UndefinedTensorImpl _singleton;
const char* tensorimpl_type_name() const override;

View File

@ -4,10 +4,17 @@
namespace c10 {
static inline int64_t maybe_wrap_dim(int64_t dim, int64_t dim_post_expr, bool wrap_scalar=true) {
static inline int64_t maybe_wrap_dim(
int64_t dim,
int64_t dim_post_expr,
bool wrap_scalar = true) {
if (dim_post_expr <= 0) {
if (!wrap_scalar) {
TORCH_CHECK_INDEX(false, "dimension specified as ", dim, " but tensor has no dimensions");
TORCH_CHECK_INDEX(
false,
"dimension specified as ",
dim,
" but tensor has no dimensions");
}
dim_post_expr = 1; // this will make range [-1, 0]
}
@ -15,12 +22,19 @@ static inline int64_t maybe_wrap_dim(int64_t dim, int64_t dim_post_expr, bool wr
int64_t min = -dim_post_expr;
int64_t max = dim_post_expr - 1;
if (dim < min || dim > max) {
TORCH_CHECK_INDEX(false,
"Dimension out of range (expected to be in range of [",
min, ", ", max, "], but got ", dim, ")");
TORCH_CHECK_INDEX(
false,
"Dimension out of range (expected to be in range of [",
min,
", ",
max,
"], but got ",
dim,
")");
}
if (dim < 0) dim += dim_post_expr;
if (dim < 0)
dim += dim_post_expr;
return dim;
}
}
} // namespace c10
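
A hedged usage sketch of the wrapping rule above (the header path c10/core/WrapDimMinimal.h is an assumption; only the behavior visible in the function body is relied on):

#include <c10/core/WrapDimMinimal.h>
#include <c10/util/Exception.h>

void wrap_dim_example() {
  // Negative dims count from the end: -1 on a 4-d tensor wraps to 3.
  TORCH_INTERNAL_ASSERT(c10::maybe_wrap_dim(-1, /*dim_post_expr=*/4) == 3);
  // In-range dims are returned unchanged.
  TORCH_INTERNAL_ASSERT(c10::maybe_wrap_dim(2, 4) == 2);
  // With wrap_scalar=true (the default), a 0-d tensor behaves as 1-d,
  // so dim -1 wraps to 0 instead of raising the IndexError shown above.
  TORCH_INTERNAL_ASSERT(c10::maybe_wrap_dim(-1, 0) == 0);
}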

View File

@ -5,11 +5,15 @@ namespace impl {
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
std::atomic<const DeviceGuardImplInterface*>
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
device_guard_impl_registry[static_cast<size_t>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
device_guard_impl_registry[static_cast<size_t>(
DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)];
DeviceGuardImplRegistrar::DeviceGuardImplRegistrar(DeviceType type, const DeviceGuardImplInterface* impl) {
DeviceGuardImplRegistrar::DeviceGuardImplRegistrar(
DeviceType type,
const DeviceGuardImplInterface* impl) {
device_guard_impl_registry[static_cast<size_t>(type)].store(impl);
}
}} // namespace c10::impl
} // namespace impl
} // namespace c10

View File

@ -30,16 +30,16 @@ class DataPtr;
* should map one-to-one with actual event flags for those backends.
*/
enum class EventFlag {
PYTORCH_DEFAULT,
BACKEND_DEFAULT,
// CUDA flags
CUDA_EVENT_DEFAULT,
CUDA_EVENT_DISABLE_TIMING, // PyTorch-default for CUDA
// HIP flags
HIP_EVENT_DEFAULT,
HIP_EVENT_DISABLE_TIMING, // PyTorch-default for HIP
// FOR TESTING ONLY
INVALID
PYTORCH_DEFAULT,
BACKEND_DEFAULT,
// CUDA flags
CUDA_EVENT_DEFAULT,
CUDA_EVENT_DISABLE_TIMING, // PyTorch-default for CUDA
// HIP flags
HIP_EVENT_DEFAULT,
HIP_EVENT_DISABLE_TIMING, // PyTorch-default for HIP
// FOR TESTING ONLY
INVALID
};
namespace impl {
@ -126,47 +126,44 @@ struct C10_API DeviceGuardImplInterface {
*/
virtual Stream exchangeStream(Stream) const noexcept = 0;
/**
* Destroys the given event.
*/
virtual void destroyEvent (
void* event,
const DeviceIndex device_index) const noexcept { }
/**
* Destroys the given event.
*/
virtual void destroyEvent(void* event, const DeviceIndex device_index)
const noexcept {}
/**
* Increments the event's version and enqueues a job with this version
   * in the stream's work queue. When the stream processes that job
* it notifies all streams waiting on / blocked by that version of the
* event to continue and marks that version as recorded.
* */
/**
* Increments the event's version and enqueues a job with this version
   * in the stream's work queue. When the stream processes that job
* it notifies all streams waiting on / blocked by that version of the
* event to continue and marks that version as recorded.
* */
virtual void record(
void** event,
const Stream& stream,
const DeviceIndex device_index,
const c10::EventFlag flag) const {
void** event,
const Stream& stream,
const DeviceIndex device_index,
const c10::EventFlag flag) const {
TORCH_CHECK(false, "Backend doesn't support events.");
}
/**
* Does nothing if the event has not been scheduled to be recorded.
* If the event was previously enqueued to be recorded, a command
* to wait for the version of the event that exists at the time of this call
* is inserted in the stream's work queue.
* When the stream reaches this command it will stop processing
* additional commands until that version of the event is marked as recorded.
*/
virtual void block(
void* event,
const Stream& stream) const {
/**
* Does nothing if the event has not been scheduled to be recorded.
* If the event was previously enqueued to be recorded, a command
* to wait for the version of the event that exists at the time of this call
* is inserted in the stream's work queue.
* When the stream reaches this command it will stop processing
* additional commands until that version of the event is marked as recorded.
*/
virtual void block(void* event, const Stream& stream) const {
TORCH_CHECK(false, "Backend doesn't support events.");
}
/**
* Returns true if (and only if)
* (1) the event has never been scheduled to be recorded
* (2) the current version is marked as recorded.
* Returns false otherwise.
*/
/**
* Returns true if (and only if)
* (1) the event has never been scheduled to be recorded
* (2) the current version is marked as recorded.
* Returns false otherwise.
*/
virtual bool queryEvent(void* event) const {
TORCH_CHECK(false, "Backend doesn't support events.");
}
@ -178,13 +175,13 @@ struct C10_API DeviceGuardImplInterface {
*/
virtual DeviceIndex deviceCount() const noexcept = 0;
/**
* Ensure the caching allocator (if any) is aware that the given DataPtr is
* being used on the given stream, and that it should thus avoid recycling the
* DataPtr until all work on that stream is done.
*/
virtual void recordDataPtrOnStream(const c10::DataPtr&, const Stream&) const { }
virtual void recordDataPtrOnStream(const c10::DataPtr&, const Stream&) const {
}
/**
* Intended use of this class is to leak the DeviceGuardImpl at program end.
@ -228,23 +225,21 @@ struct NoOpDeviceGuardImpl final : public DeviceGuardImplInterface {
}
// Event-related functions
void record(void** event,
const Stream& stream,
const DeviceIndex device_index,
const EventFlag flag) const override {
void record(
void** event,
const Stream& stream,
const DeviceIndex device_index,
const EventFlag flag) const override {
TORCH_CHECK(false, D, " backend doesn't support events.");
}
void block(
void* event,
const Stream& stream) const override {
void block(void* event, const Stream& stream) const override {
TORCH_CHECK(false, D, " backend doesn't support events.")
}
bool queryEvent(void* event) const override {
TORCH_CHECK(false, D, " backend doesn't support events.")
}
void destroyEvent(
void* event,
const DeviceIndex device_index) const noexcept override { }
void destroyEvent(void* event, const DeviceIndex device_index)
const noexcept override {}
};
// The registry is NON-owning. Each stored pointer is std::atomic so
@ -262,7 +257,8 @@ struct NoOpDeviceGuardImpl final : public DeviceGuardImplInterface {
// putting them in the registry. This is done by deleting the destructor
// on DeviceGuardImplInterface.
extern C10_API std::atomic<const DeviceGuardImplInterface*>
device_guard_impl_registry[static_cast<size_t>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)];
device_guard_impl_registry[static_cast<size_t>(
DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)];
// I can't conveniently use c10/util/Registry.h for the following reason:
// c10/util/Registry.h gives me a slow way of Create'ing an object of some
@ -273,12 +269,13 @@ device_guard_impl_registry[static_cast<size_t>(DeviceType::COMPILE_TIME_MAX_DEVI
// into device_guard_impl_registry.
class C10_API DeviceGuardImplRegistrar {
public:
public:
DeviceGuardImplRegistrar(DeviceType, const DeviceGuardImplInterface*);
};
#define C10_REGISTER_GUARD_IMPL(DevType, DeviceGuardImpl) \
static ::c10::impl::DeviceGuardImplRegistrar C10_ANONYMOUS_VARIABLE(g_##DeviceType)(::c10::DeviceType::DevType, new DeviceGuardImpl());
#define C10_REGISTER_GUARD_IMPL(DevType, DeviceGuardImpl) \
static ::c10::impl::DeviceGuardImplRegistrar C10_ANONYMOUS_VARIABLE( \
g_##DeviceType)(::c10::DeviceType::DevType, new DeviceGuardImpl());
inline const DeviceGuardImplInterface* getDeviceGuardImpl(DeviceType type) {
// Two adjacent int16_t fields DeviceType and DeviceIndex has field access
@ -301,4 +298,5 @@ inline bool hasDeviceGuardImpl(DeviceType type) {
return device_guard_impl_registry[static_cast<size_t>(type)].load();
}
}} // namespace c10::impl
} // namespace impl
} // namespace c10
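
A standalone sketch of the record/block/queryEvent contract documented above (a plain version counter stands in for a backend event; none of these toy names are real c10 API):

#include <cassert>

struct ToyEvent {
  int recorded_version = 0;   // bumped by each record()
  int completed_version = 0;  // bumped when the "stream" finishes that work
  bool marked = false;        // has the event ever been scheduled to record?

  void record() {
    ++recorded_version;
    marked = true;
  }
  void complete_pending() {
    completed_version = recorded_version;
  }
  // query(): true iff never recorded, or the recorded version has completed.
  bool query() const {
    return !marked || completed_version >= recorded_version;
  }
};

int main() {
  ToyEvent ev;
  assert(ev.query());   // (1) never scheduled to be recorded
  ev.record();
  assert(!ev.query());  // version 1 enqueued but not yet marked recorded
  ev.complete_pending();
  assert(ev.query());   // (2) the current version is marked as recorded
}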

View File

@ -60,17 +60,16 @@ struct FakeGuardImpl final : public DeviceGuardImplInterface {
// Event-related functions
void record(
void** event,
const Stream& stream,
const DeviceIndex device_index,
const EventFlag flag) const override { }
void block(
void* event,
const Stream& stream) const override { }
bool queryEvent(void* event) const override { return true; }
void destroyEvent(
void* event,
const DeviceIndex device_index) const noexcept override { }
void** event,
const Stream& stream,
const DeviceIndex device_index,
const EventFlag flag) const override {}
void block(void* event, const Stream& stream) const override {}
bool queryEvent(void* event) const override {
return true;
}
void destroyEvent(void* event, const DeviceIndex device_index)
const noexcept override {}
// Convenience methods for testing
static DeviceIndex getDeviceIndex() {
@ -87,9 +86,11 @@ struct FakeGuardImpl final : public DeviceGuardImplInterface {
static void resetStreams() {
current_streams_.fill(0);
}
private:
private:
thread_local static DeviceIndex current_device_;
thread_local static std::array<StreamId, kFakeGuardImplMaxDevices> current_streams_;
thread_local static std::array<StreamId, kFakeGuardImplMaxDevices>
current_streams_;
};
template <DeviceType T>
@ -99,7 +100,8 @@ template <DeviceType T>
constexpr DeviceType FakeGuardImpl<T>::static_type;
template <DeviceType T>
thread_local std::array<StreamId, kFakeGuardImplMaxDevices> FakeGuardImpl<T>::current_streams_ = {0,0,0,0,0,0,0,0};
thread_local std::array<StreamId, kFakeGuardImplMaxDevices>
FakeGuardImpl<T>::current_streams_ = {0, 0, 0, 0, 0, 0, 0, 0};
}} // namespace c10::impl
} // namespace impl
} // namespace c10

View File

@ -1,18 +1,17 @@
#pragma once
// This file provides implementations of InlineDeviceGuard and InlineOptionalDeviceGuard.
// This file provides implementations of InlineDeviceGuard and
// InlineOptionalDeviceGuard.
#include <c10/core/Device.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/VirtualGuardImpl.h>
#include <c10/util/Optional.h>
#include <c10/util/C++17.h>
#include <c10/util/Optional.h>
namespace c10 {
namespace impl {
/**
* A DeviceGuard is an RAII class that sets a device to some value
* on construction, and resets the device to its original value on
@ -54,7 +53,7 @@ namespace impl {
*/
template <typename T>
class InlineDeviceGuard {
public:
public:
// Note [Omitted default constructor from RAII]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// In principle, we could add a default constructor to
@ -69,25 +68,36 @@ public:
/// Set the current device to the passed Device.
explicit InlineDeviceGuard(Device device)
: impl_(device.type())
, original_device_(device.index() == -1 ? impl_.getDevice() : impl_.exchangeDevice(device))
, current_device_(device.index() == -1 ? original_device_ : device)
{}
: impl_(device.type()),
original_device_(
device.index() == -1 ? impl_.getDevice()
: impl_.exchangeDevice(device)),
current_device_(device.index() == -1 ? original_device_ : device) {}
/// Set the current device index to the passed DeviceIndex. (The
/// device type is inferred from the template parameter T).
template <typename U=T, typename=typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value>::type>
template <
typename U = T,
typename = typename std::enable_if<
!std::is_same<U, VirtualGuardImpl>::value>::type>
explicit InlineDeviceGuard(DeviceIndex device_index)
: InlineDeviceGuard(Device(U::static_type, device_index)) {}
: InlineDeviceGuard(Device(U::static_type, device_index)) {}
/// Construct an InlineDeviceGuard using VirtualGuardImpl with an explicit
/// DeviceGuardImplInterface pointer.
template <typename U=T, typename=typename std::enable_if<std::is_same<U, VirtualGuardImpl>::value>::type>
explicit InlineDeviceGuard(Device device, const DeviceGuardImplInterface* impl)
: impl_(VirtualGuardImpl(impl ? impl : getDeviceGuardImpl(device.type())))
, original_device_(device.index() == -1 ? impl_.getDevice() : impl_.exchangeDevice(device))
, current_device_(device.index() == -1 ? original_device_ : device)
{}
template <
typename U = T,
typename = typename std::enable_if<
std::is_same<U, VirtualGuardImpl>::value>::type>
explicit InlineDeviceGuard(
Device device,
const DeviceGuardImplInterface* impl)
: impl_(
VirtualGuardImpl(impl ? impl : getDeviceGuardImpl(device.type()))),
original_device_(
device.index() == -1 ? impl_.getDevice()
: impl_.exchangeDevice(device)),
current_device_(device.index() == -1 ? original_device_ : device) {}
/// Copy is disallowed
InlineDeviceGuard(const InlineDeviceGuard<T>&) = delete;
@ -103,12 +113,17 @@ public:
}
/// Sets the device to the given one.
template <typename U=T, typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value, int>::type = 0>
template <
typename U = T,
typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value, int>::
type = 0>
void set_device(at::Device device) {
AT_ASSERT((U::static_type == DeviceType::HIP && device.is_cuda()) ||
device.type() == U::static_type);
AT_ASSERT(
(U::static_type == DeviceType::HIP && device.is_cuda()) ||
device.type() == U::static_type);
auto index = device.index();
if (index == -1) return;
if (index == -1)
return;
impl_.setDevice(device);
current_device_ = device;
}
@ -116,8 +131,8 @@ public:
/// Resets the currently set device to its original device, and then sets the
/// current device to the passed device. This is effectively equivalent to
/// set_device when a guard supports only a single device type.
template <typename U=T>
typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value >::type
template <typename U = T>
typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value>::type
reset_device(at::Device device) {
set_device(device);
}
@ -139,11 +154,14 @@ public:
/// that it is unnecessary.
///
/// Optional argument is for testing only.
template <typename U=T>
typename std::enable_if<std::is_same<U, VirtualGuardImpl>::value >::type
reset_device(at::Device device, const impl::DeviceGuardImplInterface* impl = nullptr) {
template <typename U = T>
typename std::enable_if<std::is_same<U, VirtualGuardImpl>::value>::type
reset_device(
at::Device device,
const impl::DeviceGuardImplInterface* impl = nullptr) {
auto index = device.index();
if (index == -1) return;
if (index == -1)
return;
if (device.type() == original_device_.type()) {
AT_ASSERT(impl == nullptr || impl->type() == device.type());
impl_.setDevice(device);
@ -175,10 +193,10 @@ public:
return current_device_;
}
protected:
protected:
T impl_;
private:
private:
Device original_device_;
Device current_device_;
};
@ -193,29 +211,33 @@ private:
*/
template <typename T>
class InlineOptionalDeviceGuard {
public:
public:
// Note [Explicit initialization of optional fields]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Explicit initialization of optional fields
// required to workaround an nvcc bug; see https://github.com/pytorch/pytorch/issues/12117
// required to workaround an nvcc bug; see
// https://github.com/pytorch/pytorch/issues/12117
/// Creates an uninitialized OptionalDeviceGuard.
explicit InlineOptionalDeviceGuard()
: guard_() // See Note [Explicit initialization of optional fields]
{}
: guard_() // See Note [Explicit initialization of optional fields]
{}
/// Set the current device to the passed Device, if it is not nullopt.
explicit InlineOptionalDeviceGuard(optional<Device> device_opt)
: guard_() { // See Note [Explicit initialization of optional fields]
: guard_() { // See Note [Explicit initialization of optional fields]
if (device_opt.has_value()) {
guard_.emplace(device_opt.value());
}
}
/// Set the current device to the passed DeviceIndex, if it is not nullopt.
template <typename U=T, typename=typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value>::type>
template <
typename U = T,
typename = typename std::enable_if<
!std::is_same<U, VirtualGuardImpl>::value>::type>
explicit InlineOptionalDeviceGuard(optional<DeviceIndex> device_index_opt)
: guard_() { // See Note [Explicit initialization of optional fields]
: guard_() { // See Note [Explicit initialization of optional fields]
if (device_index_opt.has_value()) {
guard_.emplace(device_index_opt.value());
}
@ -225,7 +247,7 @@ public:
/// and result in initialized OptionalDeviceGuard.
template <typename... Args>
explicit InlineOptionalDeviceGuard(Args&&... args)
: guard_(in_place, std::forward<Args>(args)...) {}
: guard_(in_place, std::forward<Args>(args)...) {}
// TODO: Consider readding Tensor and TensorList constructors here, when
// Tensor moves to c10. (These are only valid on OptionalDeviceGuard,
@ -313,11 +335,15 @@ public:
//
// We could solve this with an extra thread-local variable. But no one is
// actually using move-assignment. So just get rid of it.
InlineOptionalDeviceGuard& operator=(InlineOptionalDeviceGuard&& other) = delete;
InlineOptionalDeviceGuard& operator=(InlineOptionalDeviceGuard&& other) =
delete;
/// Sets the device to the given one. Initializes OptionalDeviceGuard if it
/// is not already initialized.
template <typename U=T, typename=typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value>::type>
template <
typename U = T,
typename = typename std::enable_if<
!std::is_same<U, VirtualGuardImpl>::value>::type>
void set_device(at::Device device) {
if (!guard_.has_value()) {
guard_.emplace(device);
@ -333,8 +359,13 @@ public:
/// See notes on why this is called reset_device on InlineDeviceGuard.
///
/// Optional argument is for testing only.
template <typename U=T, typename=typename std::enable_if<std::is_same<U, VirtualGuardImpl>::value>::type>
void reset_device(at::Device device, const DeviceGuardImplInterface* impl = nullptr) {
template <
typename U = T,
typename = typename std::enable_if<
std::is_same<U, VirtualGuardImpl>::value>::type>
void reset_device(
at::Device device,
const DeviceGuardImplInterface* impl = nullptr) {
if (!guard_.has_value()) {
guard_.emplace(device, impl);
} else {
@ -346,7 +377,10 @@ public:
/// current device to the passed device. Initializes the guard if it is
/// not already initialized. This is effectively equivalent to set_device
/// when a guard supports only a single device type.
template <typename U=T, typename=typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value>::type>
template <
typename U = T,
typename = typename std::enable_if<
!std::is_same<U, VirtualGuardImpl>::value>::type>
void reset_device(at::Device device) {
if (!guard_.has_value()) {
guard_.emplace(device);
@ -357,7 +391,10 @@ public:
/// Sets the device index to the given one. The device type is statically
/// known.
template <typename U=T, typename=typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value >::type>
template <
typename U = T,
typename = typename std::enable_if<
!std::is_same<U, VirtualGuardImpl>::value>::type>
void set_index(DeviceIndex index) {
if (!guard_.has_value()) {
guard_.emplace(index);
@ -366,17 +403,19 @@ public:
}
}
/// Returns the device that was set immediately prior to initialization of the
/// guard, or nullopt if the guard is uninitialized.
/// Returns the device that was set immediately prior to initialization of
/// the guard, or nullopt if the guard is uninitialized.
optional<Device> original_device() const {
return guard_.has_value() ? make_optional(guard_->original_device()) : nullopt;
return guard_.has_value() ? make_optional(guard_->original_device())
: nullopt;
}
/// Returns the most recent device that was set using this device guard,
/// either from construction, or via set_device, if the guard is initialized,
/// or nullopt if the guard is uninitialized.
optional<Device> current_device() const {
return guard_.has_value() ? make_optional(guard_->current_device()) : nullopt;
return guard_.has_value() ? make_optional(guard_->current_device())
: nullopt;
}
/// Restore the original device, resetting this guard to uninitialized state.
@ -384,8 +423,9 @@ public:
guard_.reset();
}
private:
private:
optional<InlineDeviceGuard<T>> guard_;
};
}} // namespace c10::impl
} // namespace impl
} // namespace c10
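
For orientation, a minimal standalone sketch of the RAII save/restore pattern InlineDeviceGuard implements (a thread_local int stands in for the per-backend device state; these toy names are not part of c10):

#include <cassert>

thread_local int g_current_device = 0;

class ToyDeviceGuard {
 public:
  explicit ToyDeviceGuard(int device)
      : original_(g_current_device), current_(device) {
    g_current_device = device;  // set on construction
  }
  ~ToyDeviceGuard() {
    g_current_device = original_;  // restore on destruction
  }
  ToyDeviceGuard(const ToyDeviceGuard&) = delete;
  ToyDeviceGuard& operator=(const ToyDeviceGuard&) = delete;
  int original_device() const {
    return original_;
  }
  int current_device() const {
    return current_;
  }

 private:
  int original_;
  int current_;
};

int main() {
  {
    ToyDeviceGuard guard(3);
    assert(g_current_device == 3);
    assert(guard.original_device() == 0);
  }
  assert(g_current_device == 0);  // restored when the guard leaves scope
}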

View File

@ -2,22 +2,19 @@
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#include <c10/util/Exception.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/Exception.h>
namespace c10 {
namespace impl {
template <typename T>
struct InlineEvent final {
InlineEvent() = delete;
InlineEvent(
const DeviceType _device_type,
const EventFlag _flag = EventFlag::PYTORCH_DEFAULT)
: backend_{_device_type},
device_type_{_device_type},
flag_{_flag} { }
const DeviceType _device_type,
const EventFlag _flag = EventFlag::PYTORCH_DEFAULT)
: backend_{_device_type}, device_type_{_device_type}, flag_{_flag} {}
// Copy constructor and copy assignment operator (deleted)
InlineEvent(const InlineEvent&) = delete;
@ -25,7 +22,7 @@ struct InlineEvent final {
// Move constructor and move assignment operator
InlineEvent(InlineEvent&& other)
: InlineEvent(other.device_type_, other.flag_) {
: InlineEvent(other.device_type_, other.flag_) {
swap(std::move(other));
}
InlineEvent& operator=(InlineEvent&& other) {
@ -43,27 +40,36 @@ struct InlineEvent final {
}
~InlineEvent() noexcept {
if (event_) backend_.destroyEvent(event_, device_index_);
if (event_)
backend_.destroyEvent(event_, device_index_);
}
DeviceType device_type() const noexcept { return device_type_; }
DeviceIndex device_index() const noexcept { return device_index_; }
EventFlag flag() const noexcept { return flag_; }
bool was_marked_for_recording() const noexcept { return was_marked_for_recording_; }
DeviceType device_type() const noexcept {
return device_type_;
}
DeviceIndex device_index() const noexcept {
return device_index_;
}
EventFlag flag() const noexcept {
return flag_;
}
bool was_marked_for_recording() const noexcept {
return was_marked_for_recording_;
}
void recordOnce(const Stream& stream) {
if (!was_marked_for_recording_) record(stream);
if (!was_marked_for_recording_)
record(stream);
}
void record(const Stream& stream) {
TORCH_CHECK(
stream.device_type() == device_type_,
"Event device type ",
DeviceTypeName(device_type_),
" does not match recording stream's device type ",
DeviceTypeName(stream.device_type()),
".");
stream.device_type() == device_type_,
"Event device type ",
DeviceTypeName(device_type_),
" does not match recording stream's device type ",
DeviceTypeName(stream.device_type()),
".");
backend_.record(&event_, stream, device_index_, flag_);
was_marked_for_recording_ = true;
@ -71,25 +77,27 @@ struct InlineEvent final {
}
void block(const Stream& stream) const {
if (!was_marked_for_recording_) return;
if (!was_marked_for_recording_)
return;
TORCH_CHECK(
stream.device_type() == device_type_,
"Event device type ",
DeviceTypeName(device_type_),
" does not match blocking stream's device type ",
DeviceTypeName(stream.device_type()),
".");
stream.device_type() == device_type_,
"Event device type ",
DeviceTypeName(device_type_),
" does not match blocking stream's device type ",
DeviceTypeName(stream.device_type()),
".");
backend_.block(event_, stream);
}
bool query() const {
if (!was_marked_for_recording_) return true;
if (!was_marked_for_recording_)
return true;
return backend_.queryEvent(event_);
}
private:
private:
void* event_ = nullptr;
T backend_;
DeviceType device_type_;
@ -98,5 +106,5 @@ private:
bool was_marked_for_recording_ = false;
};
} // impl
} // c10
} // namespace impl
} // namespace c10

View File

@ -16,27 +16,34 @@ namespace impl {
*/
template <typename T>
class InlineStreamGuard : private InlineDeviceGuard<T> {
public:
public:
/// No default constructor, see Note [Omitted default constructor from RAII]
explicit InlineStreamGuard() = delete;
/// Set the current device to the device associated with the passed stream,
/// and set the current stream on that device to the passed stream.
explicit InlineStreamGuard(Stream stream)
: InlineDeviceGuard<T>(stream.device())
, original_stream_of_original_device_(this->impl_.getStream(original_device()))
, original_stream_of_current_device_(this->impl_.exchangeStream(stream))
, current_stream_(stream)
{}
: InlineDeviceGuard<T>(stream.device()),
original_stream_of_original_device_(
this->impl_.getStream(original_device())),
original_stream_of_current_device_(this->impl_.exchangeStream(stream)),
current_stream_(stream) {}
/// This constructor exists purely for testing
template <typename U=T, typename=typename std::enable_if<std::is_same<U, VirtualGuardImpl>::value>::type>
explicit InlineStreamGuard(Stream stream, const DeviceGuardImplInterface* impl)
: InlineDeviceGuard<T>(stream.device(), impl ? impl : getDeviceGuardImpl(stream.device_type()))
, original_stream_of_original_device_(this->impl_.getStream(original_device()))
, original_stream_of_current_device_(this->impl_.exchangeStream(stream))
, current_stream_(stream)
{}
template <
typename U = T,
typename = typename std::enable_if<
std::is_same<U, VirtualGuardImpl>::value>::type>
explicit InlineStreamGuard(
Stream stream,
const DeviceGuardImplInterface* impl)
: InlineDeviceGuard<T>(
stream.device(),
impl ? impl : getDeviceGuardImpl(stream.device_type())),
original_stream_of_original_device_(
this->impl_.getStream(original_device())),
original_stream_of_current_device_(this->impl_.exchangeStream(stream)),
current_stream_(stream) {}
/// Copy is disallowed
InlineStreamGuard(const InlineStreamGuard<T>&) = delete;
@ -110,8 +117,9 @@ public:
return InlineDeviceGuard<T>::original_device();
}
private:
Stream original_stream_of_original_device_; // what the user probably cares about
private:
Stream
original_stream_of_original_device_; // what the user probably cares about
Stream original_stream_of_current_device_; // what we need to restore
Stream current_stream_;
};
@ -123,17 +131,16 @@ private:
*/
template <typename T>
class InlineOptionalStreamGuard {
public:
public:
/// Creates an uninitialized stream guard.
explicit InlineOptionalStreamGuard()
: guard_() // See Note [Explicit initialization of optional fields]
{}
: guard_() // See Note [Explicit initialization of optional fields]
{}
/// Set the current device to the device associated with the passed stream,
/// and set the current stream on that device to the passed stream,
/// if the passed stream is not nullopt.
explicit InlineOptionalStreamGuard(optional<Stream> stream_opt)
: guard_() {
explicit InlineOptionalStreamGuard(optional<Stream> stream_opt) : guard_() {
if (stream_opt.has_value()) {
guard_.emplace(stream_opt.value());
}
@ -142,13 +149,14 @@ public:
/// All constructors of StreamGuard are valid for OptionalStreamGuard
template <typename... Args>
explicit InlineOptionalStreamGuard(Args&&... args)
: guard_(in_place, std::forward<Args>(args)...) {}
: guard_(in_place, std::forward<Args>(args)...) {}
// See Note [Move construction for RAII guards is tricky]
InlineOptionalStreamGuard(InlineOptionalStreamGuard<T>&& other) = delete;
// See Note [Move assignment for RAII guards is tricky]
InlineOptionalStreamGuard& operator=(InlineOptionalStreamGuard&& other) = delete;
InlineOptionalStreamGuard& operator=(InlineOptionalStreamGuard&& other) =
delete;
/// Resets the currently set stream to the original stream and
/// the currently set device to the original device. Then,
@ -166,28 +174,31 @@ public:
/// Returns the stream that was set at the time the guard was most recently
/// initialized, or nullopt if the guard is uninitialized.
optional<Stream> original_stream() const {
return guard_.has_value() ? make_optional(guard_->original_stream()) : nullopt;
return guard_.has_value() ? make_optional(guard_->original_stream())
: nullopt;
}
/// Returns the most recent stream that was set using this stream guard,
/// either from construction, or via reset_stream, if the guard is initialized,
/// or nullopt if the guard is uninitialized.
/// either from construction, or via reset_stream, if the guard is
/// initialized, or nullopt if the guard is uninitialized.
optional<Stream> current_stream() const {
return guard_.has_value() ? make_optional(guard_->current_stream()) : nullopt;
return guard_.has_value() ? make_optional(guard_->current_stream())
: nullopt;
}
/// Restore the original device and stream, resetting this guard to uninitialized state.
/// Restore the original device and stream, resetting this guard to
/// uninitialized state.
void reset() {
guard_.reset();
}
private:
private:
optional<InlineStreamGuard<T>> guard_;
};
template <typename T>
class InlineMultiStreamGuard {
public:
public:
/// Calls `set_stream` on each of the streams in the list.
/// This may be useful if you need to set different streams
/// for different devices.
@ -216,10 +227,10 @@ public:
}
}
protected:
protected:
optional<T> impl_;
private:
private:
/// The original streams that were active on all devices.
std::vector<Stream> original_streams_;
@ -228,13 +239,17 @@ private:
DeviceType type = streams[0].device_type();
for (size_t idx = 1; idx < streams.size(); idx++) {
TORCH_CHECK_VALUE(
streams[idx].device_type() == type,
"Streams have a mix of device types: stream 0 is on ",
streams[0].device(), " while stream ", idx, " is on device ",
streams[idx].device());
streams[idx].device_type() == type,
"Streams have a mix of device types: stream 0 is on ",
streams[0].device(),
" while stream ",
idx,
" is on device ",
streams[idx].device());
}
return type;
}
};
}} // namespace c10::impl
} // namespace impl
} // namespace c10

View File

@ -8,12 +8,12 @@ namespace impl {
// NB: POD, must be zero initialized!
// Note [TLS Initialization]
// We wanted raw_local_dispatch_key_set to be initialized with non-zero state
// e.g. BackendSelect and InplaceOrView in included set. But certain Windows compilers (e.g. the one
// used in ARVR tests) only allow TLS to be zero-initialized.
// To preserve the invariant that raw TLS storage of the default state is zero,
// we obtain the actual include keyset by XORing raw_local_dispatch_key_set.included_
// with c10::default_included_set. This logic is encapsulated in struct
// PODLocalDispatchKeySet.
// e.g. BackendSelect and InplaceOrView in included set. But certain Windows
// compilers (e.g. the one used in ARVR tests) only allow TLS to be
// zero-initialized. To preserve the invariant that raw TLS storage of the
// default state is zero, we obtain the actual include keyset by XORing
// raw_local_dispatch_key_set.included_ with c10::default_included_set. This
// logic is encapsulated in struct PODLocalDispatchKeySet.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set;
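
A standalone sketch of the XOR encoding described in the note above (plain uint64_t masks stand in for DispatchKeySet, and kDefaultIncluded is a made-up value, not the real c10::default_included_set):

#include <cassert>
#include <cstdint>

constexpr uint64_t kDefaultIncluded = 0b1100;  // pretend BackendSelect | InplaceOrView

struct ToyPodTls {
  uint64_t included_raw;  // must stay zero-initializable for TLS
  uint64_t included() const {
    return included_raw ^ kDefaultIncluded;
  }
  void set_included(uint64_t x) {
    included_raw = x ^ kDefaultIncluded;
  }
};

int main() {
  ToyPodTls tls{};  // zero-initialized, as the TLS restriction demands
  assert(tls.included() == kDefaultIncluded);  // zero storage decodes to the default set
  tls.set_included(0b0010);
  assert(tls.included() == 0b0010);  // round-trips through the XOR encoding
}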
@ -28,26 +28,29 @@ void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set) {
raw_local_dispatch_key_set.set_excluded(key_set.excluded_);
}
// An RAII guard could snapshot and restore the entire state (entire DispatchKeySet) as
// opposed to only snapshotting and restoring the state of its assigned DispatchKeySet.
// I'm not sure which is better. If only the RAII API is used, the two choices are
// not distinguishable.
// An RAII guard could snapshot and restore the entire state (entire
// DispatchKeySet) as opposed to only snapshotting and restoring the state of
// its assigned DispatchKeySet. I'm not sure which is better. If only the RAII
// API is used, the two choices are not distinguishable.
//
// However, if the guard chooses to snapshot and restore the entire DispatchKeySet,
// the interaction with the non-RAII API changes. Consider this sequence of events:
// - An RAII guard is declared for a particular DispatchKeySet, but snapshots the entire
// However, if the guard chooses to snapshot and restore the entire
// DispatchKeySet, the interaction with the non-RAII API changes. Consider this
// sequence of events:
// - An RAII guard is declared for a particular DispatchKeySet, but snapshots
// the entire
// current DispatchKeySet.
// - A call to the non-RAII API changes the state for DispatchKeys outside the assigned
// - A call to the non-RAII API changes the state for DispatchKeys outside the
// assigned
// set.
// - The RAII guard goes out of scope, restoring the entire DispatchKeySet it snapshotted
// (which restores the state for its own assigned DispatchKey and wipes out the state
// for the other DispatchKeys set by the non-RAII API).
// - The RAII guard goes out of scope, restoring the entire DispatchKeySet it
// snapshotted
// (which restores the state for its own assigned DispatchKey and wipes out
// the state for the other DispatchKeys set by the non-RAII API).
// RAII API
IncludeDispatchKeyGuard::IncludeDispatchKeyGuard(DispatchKeySet include)
: tls_(&raw_local_dispatch_key_set)
, include_(include - tls_->included()) {
: tls_(&raw_local_dispatch_key_set), include_(include - tls_->included()) {
if (!include_.empty()) {
tls_->set_included(tls_->included() | include_);
}
@ -60,8 +63,7 @@ IncludeDispatchKeyGuard::~IncludeDispatchKeyGuard() {
}
ExcludeDispatchKeyGuard::ExcludeDispatchKeyGuard(DispatchKeySet exclude)
: tls_(&raw_local_dispatch_key_set)
, exclude_(exclude - tls_->excluded()) {
: tls_(&raw_local_dispatch_key_set), exclude_(exclude - tls_->excluded()) {
if (!exclude_.empty()) {
tls_->set_excluded(tls_->excluded() | exclude_);
}
@ -74,7 +76,8 @@ ExcludeDispatchKeyGuard::~ExcludeDispatchKeyGuard() {
}
// Non-RAII API
// Please prefer using the RAII API. See declarations in LocalDispatchKeySet.h for details.
// Please prefer using the RAII API. See declarations in LocalDispatchKeySet.h
// for details.
bool tls_is_dispatch_key_excluded(DispatchKey x) {
return raw_local_dispatch_key_set.excluded().has(x);
@ -115,4 +118,5 @@ bool tls_is_dispatch_keyset_excluded(DispatchKeySet ks) {
bool tls_is_dispatch_keyset_included(DispatchKeySet ks) {
return raw_local_dispatch_key_set.included().isSupersetOf(ks);
}
}} // namespace c10::impl
} // namespace impl
} // namespace c10

View File

@ -36,10 +36,12 @@ struct C10_API PODLocalDispatchKeySet {
// See Note [TLS Initialization]
DispatchKeySet included() const {
return DispatchKeySet(DispatchKeySet::RAW, included_) ^ c10::default_included_set;
return DispatchKeySet(DispatchKeySet::RAW, included_) ^
c10::default_included_set;
}
DispatchKeySet excluded() const {
return DispatchKeySet(DispatchKeySet::RAW, excluded_) ^ c10::default_excluded_set;
return DispatchKeySet(DispatchKeySet::RAW, excluded_) ^
c10::default_excluded_set;
}
void set_included(DispatchKeySet x) {
@ -49,11 +51,13 @@ struct C10_API PODLocalDispatchKeySet {
excluded_ = (x ^ c10::default_excluded_set).raw_repr();
}
};
static_assert(std::is_pod<PODLocalDispatchKeySet>::value, "PODLocalDispatchKeySet must be a POD type.");
static_assert(
std::is_pod<PODLocalDispatchKeySet>::value,
"PODLocalDispatchKeySet must be a POD type.");
struct C10_API LocalDispatchKeySet {
/* implicit */ LocalDispatchKeySet(PODLocalDispatchKeySet x)
: included_(x.included()), excluded_(x.excluded()) {}
: included_(x.included()), excluded_(x.excluded()) {}
DispatchKeySet included_;
DispatchKeySet excluded_;
};
@ -63,7 +67,7 @@ struct C10_API LocalDispatchKeySet {
#if defined(_MSC_VER) || defined(C10_ANDROID)
C10_API LocalDispatchKeySet tls_local_dispatch_key_set();
#else // defined(_MSC_VER) || defined(C10_ANDROID)
extern C10_API thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set;
extern C10_API thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set;
inline C10_API LocalDispatchKeySet tls_local_dispatch_key_set() {
// Don't let people fiddle with the thread_local directly just
@ -78,15 +82,17 @@ C10_API void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set);
// RAII API for manipulating the thread-local dispatch state.
class C10_API IncludeDispatchKeyGuard {
public:
public:
IncludeDispatchKeyGuard(DispatchKeySet);
IncludeDispatchKeyGuard(DispatchKey k) : IncludeDispatchKeyGuard(DispatchKeySet(k)) {}
IncludeDispatchKeyGuard(DispatchKey k)
: IncludeDispatchKeyGuard(DispatchKeySet(k)) {}
IncludeDispatchKeyGuard(const IncludeDispatchKeyGuard&) = delete;
IncludeDispatchKeyGuard operator=(const IncludeDispatchKeyGuard&) = delete;
IncludeDispatchKeyGuard(IncludeDispatchKeyGuard&&) = delete;
IncludeDispatchKeyGuard operator=(IncludeDispatchKeyGuard&&) = delete;
~IncludeDispatchKeyGuard();
private:
private:
// A little micro-optimization to save us from tls_get_addr call
// on destruction
PODLocalDispatchKeySet* tls_;
@ -94,15 +100,17 @@ private:
};
class C10_API ExcludeDispatchKeyGuard {
public:
public:
ExcludeDispatchKeyGuard(DispatchKeySet);
ExcludeDispatchKeyGuard(DispatchKey k) : ExcludeDispatchKeyGuard(DispatchKeySet(k)) {}
ExcludeDispatchKeyGuard(DispatchKey k)
: ExcludeDispatchKeyGuard(DispatchKeySet(k)) {}
ExcludeDispatchKeyGuard(const ExcludeDispatchKeyGuard&) = delete;
ExcludeDispatchKeyGuard operator=(const ExcludeDispatchKeyGuard&) = delete;
ExcludeDispatchKeyGuard(ExcludeDispatchKeyGuard&&) = delete;
ExcludeDispatchKeyGuard operator=(ExcludeDispatchKeyGuard&&) = delete;
~ExcludeDispatchKeyGuard();
private:
private:
// A little micro-optimization to save us from tls_get_addr call
// on destruction
PODLocalDispatchKeySet* tls_;
@ -120,7 +128,8 @@ private:
// through that DispatchKey's registered overrides.
//
// The non-RAII API is less efficient than the RAII guards because both the
// getter and setter will do a tls_getaddr lookup (the RAII struct only needs one!)
// getter and setter will do a tls_getaddr lookup (the RAII struct only needs
// one!)
C10_API bool tls_is_dispatch_key_excluded(DispatchKey x);
C10_API void tls_set_dispatch_key_excluded(DispatchKey x, bool desired_state);
@ -129,4 +138,5 @@ C10_API void tls_set_dispatch_key_included(DispatchKey x, bool desired_state);
C10_API bool tls_is_dispatch_keyset_excluded(DispatchKeySet ks);
C10_API bool tls_is_dispatch_keyset_included(DispatchKeySet ks);
}} // namespace c10::impl
} // namespace impl
} // namespace c10

View File

@ -3,9 +3,13 @@
namespace c10 {
namespace impl {
void SizesAndStrides::resizeSlowPath(const size_t newSize, const size_t oldSize) {
void SizesAndStrides::resizeSlowPath(
const size_t newSize,
const size_t oldSize) {
if (newSize <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!isInline(), "resizeSlowPath called when fast path should have been hit!");
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!isInline(),
"resizeSlowPath called when fast path should have been hit!");
int64_t* tempStorage = outOfLineStorage_;
memcpy(
&inlineStorage_[0],
@ -24,15 +28,23 @@ void SizesAndStrides::resizeSlowPath(const size_t newSize, const size_t oldSize)
// CANNOT USE allocateOutOfLineStorage(newSize) HERE! WOULD
// OVERWRITE inlineStorage_!
// NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
int64_t* tempStorage = static_cast<int64_t *>(malloc(storageBytes(newSize)));
TORCH_CHECK(tempStorage, "Could not allocate memory to change Tensor SizesAndStrides!");
int64_t* tempStorage =
static_cast<int64_t*>(malloc(storageBytes(newSize)));
TORCH_CHECK(
tempStorage,
"Could not allocate memory to change Tensor SizesAndStrides!");
const auto bytesToCopy = oldSize * sizeof(inlineStorage_[0]);
const auto bytesToZero = (newSize > oldSize) ? (newSize - oldSize) * sizeof(tempStorage[0]) : 0;
const auto bytesToZero = (newSize > oldSize)
? (newSize - oldSize) * sizeof(tempStorage[0])
: 0;
memcpy(&tempStorage[0], &inlineStorage_[0], bytesToCopy);
if (bytesToZero) {
memset(&tempStorage[oldSize], 0, bytesToZero);
}
memcpy(&tempStorage[newSize], &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE], bytesToCopy);
memcpy(
&tempStorage[newSize],
&inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE],
bytesToCopy);
if (bytesToZero) {
memset(&tempStorage[newSize + oldSize], 0, bytesToZero);
}
@ -55,7 +67,8 @@ void SizesAndStrides::resizeSlowPath(const size_t newSize, const size_t oldSize)
resizeOutOfLineStorage(newSize);
} else {
// Zero the end of the sizes portion.
const auto bytesToZero = (newSize - oldSize) * sizeof(outOfLineStorage_[0]);
const auto bytesToZero =
(newSize - oldSize) * sizeof(outOfLineStorage_[0]);
memset(&outOfLineStorage_[oldSize], 0, bytesToZero);
memset(&outOfLineStorage_[newSize + oldSize], 0, bytesToZero);
}

View File

@ -19,7 +19,8 @@ namespace impl {
// the number of strides. The memory layout is as follows:
//
// 1 size_t for the size
// 5 eightbytes of inline sizes and 5 eightbytes of inline strides, OR pointer to out-of-line array
// 5 eightbytes of inline sizes and 5 eightbytes of inline strides, OR pointer
// to out-of-line array
class C10_API SizesAndStrides {
public:
// TODO: different iterator types for sizes & strides to prevent
@ -238,11 +239,16 @@ class C10_API SizesAndStrides {
if (newSize == oldSize) {
return;
}
if (C10_LIKELY(newSize <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE && isInline())) {
if (C10_LIKELY(
newSize <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE && isInline())) {
if (oldSize < newSize) {
const auto bytesToZero = (newSize - oldSize) * sizeof(inlineStorage_[0]);
const auto bytesToZero =
(newSize - oldSize) * sizeof(inlineStorage_[0]);
memset(&inlineStorage_[oldSize], 0, bytesToZero);
memset(&inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE + oldSize], 0, bytesToZero);
memset(
&inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE + oldSize],
0,
bytesToZero);
}
size_ = newSize;
} else {
@ -267,14 +273,19 @@ class C10_API SizesAndStrides {
}
void allocateOutOfLineStorage(size_t size) {
outOfLineStorage_ = static_cast<int64_t *>(malloc(storageBytes(size)));
TORCH_CHECK(outOfLineStorage_, "Could not allocate memory for Tensor SizesAndStrides!");
outOfLineStorage_ = static_cast<int64_t*>(malloc(storageBytes(size)));
TORCH_CHECK(
outOfLineStorage_,
"Could not allocate memory for Tensor SizesAndStrides!");
}
void resizeOutOfLineStorage(size_t newSize) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!isInline());
outOfLineStorage_ = static_cast<int64_t *>(realloc(outOfLineStorage_, storageBytes(newSize)));
TORCH_CHECK(outOfLineStorage_, "Could not allocate memory for Tensor SizesAndStrides!");
outOfLineStorage_ = static_cast<int64_t*>(
realloc(outOfLineStorage_, storageBytes(newSize)));
TORCH_CHECK(
outOfLineStorage_,
"Could not allocate memory for Tensor SizesAndStrides!");
}
void copyDataOutline(const SizesAndStrides& rhs) noexcept {
@ -283,10 +294,9 @@ class C10_API SizesAndStrides {
size_t size_;
union {
int64_t *outOfLineStorage_;
int64_t* outOfLineStorage_;
int64_t inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * 2]{};
};
};
} // namespace impl
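
A standalone sketch of the small-buffer layout described above (toy names, not the real class): sizes and strides share one inline array while the rank fits, and spill to a single contiguous heap block otherwise.

#include <cstddef>
#include <cstdint>

struct ToySizesAndStrides {
  static constexpr size_t kMaxInline = 5;
  size_t size = 1;  // current rank
  union {
    int64_t* outOfLine;                       // heap: [sizes..., strides...]
    int64_t inlineStorage[kMaxInline * 2]{};  // sizes in [0,5), strides in [5,10)
  };
  bool isInline() const {
    return size <= kMaxInline;
  }
  int64_t* sizes_data() {
    return isInline() ? &inlineStorage[0] : &outOfLine[0];
  }
  int64_t* strides_data() {
    return isInline() ? &inlineStorage[kMaxInline] : &outOfLine[size];
  }
};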

View File

@ -10,12 +10,11 @@ namespace impl {
* to virtual dispatch on the DeviceGuardImpl registry.
*/
class VirtualGuardImpl final : public DeviceGuardImplInterface {
public:
public:
VirtualGuardImpl(DeviceType device_type)
: impl_(getDeviceGuardImpl(device_type)) {}
: impl_(getDeviceGuardImpl(device_type)) {}
// This constructor exists purely for testing
VirtualGuardImpl(const DeviceGuardImplInterface* impl)
: impl_(impl) {}
VirtualGuardImpl(const DeviceGuardImplInterface* impl) : impl_(impl) {}
// Copying and moving is OK!
@ -51,34 +50,32 @@ public:
}
// Event functions
void record(void** event,
const Stream& stream,
const DeviceIndex device_index,
const EventFlag flag) const override {
void record(
void** event,
const Stream& stream,
const DeviceIndex device_index,
const EventFlag flag) const override {
impl_->record(event, stream, device_index, flag);
}
void block(
void* event,
const Stream& stream) const override {
void block(void* event, const Stream& stream) const override {
impl_->block(event, stream);
}
bool queryEvent(void* event) const override {
return impl_->queryEvent(event);
}
void destroyEvent(
void* event,
const DeviceIndex device_index) const noexcept override {
void destroyEvent(void* event, const DeviceIndex device_index)
const noexcept override {
impl_->destroyEvent(event, device_index);
}
void recordDataPtrOnStream(
const c10::DataPtr& data_ptr,
const Stream& stream) const override {
void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
const override {
impl_->recordDataPtrOnStream(data_ptr, stream);
}
private:
private:
const DeviceGuardImplInterface* impl_ = nullptr;
};
}} // namespace c10::impl
} // namespace impl
} // namespace c10

View File

@ -3,9 +3,9 @@
namespace c10 {
ThreadPool::ThreadPool(
int pool_size,
int numa_node_id,
std::function<void()> init_thread)
int pool_size,
int numa_node_id,
std::function<void()> init_thread)
: threads_(pool_size < 0 ? defaultNumThreads() : pool_size),
running_(true),
complete_(true),
@ -13,7 +13,7 @@ ThreadPool::ThreadPool(
total_(threads_.size()),
numa_node_id_(numa_node_id) {
for (std::size_t i = 0; i < threads_.size(); ++i) {
threads_[i] = std::thread([this, i, init_thread](){
threads_[i] = std::thread([this, i, init_thread]() {
if (init_thread) {
init_thread();
}

View File

@ -50,9 +50,9 @@ class C10_API ThreadPool : public c10::TaskThreadPoolBase {
const std::function<void(std::size_t)> with_id;
explicit task_element_t(std::function<void()> f)
: run_with_id(false), no_id(std::move(f)), with_id(nullptr) {}
: run_with_id(false), no_id(std::move(f)), with_id(nullptr) {}
explicit task_element_t(std::function<void(std::size_t)> f)
: run_with_id(true), no_id(nullptr), with_id(std::move(f)) {}
: run_with_id(true), no_id(nullptr), with_id(std::move(f)) {}
};
std::queue<task_element_t> tasks_;
@ -105,13 +105,11 @@ class C10_API ThreadPool : public c10::TaskThreadPoolBase {
class C10_API TaskThreadPool : public c10::ThreadPool {
public:
explicit TaskThreadPool(
std::size_t pool_size,
int numa_node_id = -1)
: ThreadPool(pool_size, numa_node_id, [numa_node_id](){
setThreadName("CaffeTaskThread");
NUMABind(numa_node_id);
}) {}
explicit TaskThreadPool(std::size_t pool_size, int numa_node_id = -1)
: ThreadPool(pool_size, numa_node_id, [numa_node_id]() {
setThreadName("CaffeTaskThread");
NUMABind(numa_node_id);
}) {}
};
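
A hedged usage sketch of the pool above (run() and waitWorkComplete() are assumed to be inherited from ThreadPool/TaskThreadPoolBase, and the header is assumed to be c10/core/thread_pool.h):

#include <atomic>
#include <c10/core/thread_pool.h>

void task_thread_pool_example() {
  std::atomic<int> counter{0};
  c10::TaskThreadPool pool(/*pool_size=*/2);  // also names threads and NUMA-binds them
  for (int i = 0; i < 8; ++i) {
    pool.run([&counter] { ++counter; });
  }
  pool.waitWorkComplete();  // block until every queued task has finished
  // counter == 8 here.
}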
C10_DECLARE_SHARED_REGISTRY(

File diff suppressed because it is too large

View File

@ -1,9 +1,9 @@
#ifndef THC_DEVICE_ALLOCATOR_INC
#define THC_DEVICE_ALLOCATOR_INC
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/core/Allocator.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/Registry.h>
#include <array>
@ -19,7 +19,7 @@ class C10_CUDA_API CUDAOutOfMemoryError : public c10::Error {
// block inside of already allocated area.
class C10_CUDA_API FreeMemoryCallback {
public:
virtual ~FreeMemoryCallback() {};
virtual ~FreeMemoryCallback(){};
virtual bool Execute() = 0;
};
@ -55,7 +55,7 @@ enum struct StatType : uint64_t {
AGGREGATE = 0,
SMALL_POOL = 1,
LARGE_POOL = 2,
NUM_TYPES = 3 // remember to update this whenever a new stat type is added
NUM_TYPES = 3 // remember to update this whenever a new stat type is added
};
typedef std::array<Stat, static_cast<size_t>(StatType::NUM_TYPES)> StatArray;
@ -68,7 +68,8 @@ struct DeviceStats {
StatArray segment;
// COUNT: number of active memory blocks (allocated or used by stream)
StatArray active;
// COUNT: number of inactive, split memory blocks (unallocated but can't be released via cudaFree)
// COUNT: number of inactive, split memory blocks (unallocated but can't be
// released via cudaFree)
StatArray inactive_split;
// SUM: bytes requested by client code
@ -80,14 +81,16 @@ struct DeviceStats {
// SUM: bytes within inactive, split memory blocks
StatArray inactive_split_bytes;
// COUNT: total number of failed calls to CUDA malloc necessitating cache flushes.
// COUNT: total number of failed calls to CUDA malloc necessitating cache
// flushes.
int64_t num_alloc_retries = 0;
// COUNT: total number of OOMs (i.e. failed calls to CUDA after cache flush)
int64_t num_ooms = 0;
};
// Struct containing info of an allocation block (i.e. a fractional part of a cudaMalloc).
// Struct containing info of an allocation block (i.e. a fractional part of a
// cudaMalloc).
struct BlockInfo {
int64_t size = 0;
bool allocated = false;
@ -113,8 +116,11 @@ C10_CUDA_API Allocator* get();
C10_CUDA_API void init(int device_count);
C10_CUDA_API void setMemoryFraction(double fraction, int device);
C10_CUDA_API void emptyCache();
C10_CUDA_API void cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock);
C10_CUDA_API void* getBaseAllocation(void *ptr, size_t *size);
C10_CUDA_API void cacheInfo(
int dev_id,
size_t* cachedAndFree,
size_t* largestBlock);
C10_CUDA_API void* getBaseAllocation(void* ptr, size_t* size);
C10_CUDA_API void recordStream(const DataPtr&, CUDAStream stream);
C10_CUDA_API DeviceStats getDeviceStats(int device);
C10_CUDA_API void resetAccumulatedStats(int device);
@ -122,9 +128,10 @@ C10_CUDA_API void resetPeakStats(int device);
C10_CUDA_API std::vector<SegmentInfo> snapshot();
// CUDAGraph interactions
C10_CUDA_API void notifyCaptureBegin(int device,
CaptureId_t graph_id,
MempoolId_t mempool_id);
C10_CUDA_API void notifyCaptureBegin(
int device,
CaptureId_t graph_id,
MempoolId_t mempool_id);
C10_CUDA_API void notifyCaptureEnd(int device, CaptureId_t graph_id);
C10_CUDA_API void notifyCaptureDestroy(int device, MempoolId_t mempool_id);
@ -133,6 +140,7 @@ C10_CUDA_API std::mutex* getFreeMutex();
C10_CUDA_API std::shared_ptr<void> getIpcDevPtr(std::string handle);
} // namespace CUDACachingAllocator
}} // namespace c10::cuda
} // namespace cuda
} // namespace c10
#endif
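
A hedged sketch of driving the interface declared above (requires a CUDA build; the allocated_bytes member and the Stat::current field are assumptions about DeviceStats, the rest follows the declarations in this header):

#include <c10/cuda/CUDACachingAllocator.h>

void allocator_stats_example(int device) {
  namespace alloc = c10::cuda::CUDACachingAllocator;
  alloc::DeviceStats stats = alloc::getDeviceStats(device);
  // Index 0 is the AGGREGATE bucket of the StatType enum above.
  auto current_allocated = stats.allocated_bytes[0].current;
  (void)current_allocated;
  alloc::resetPeakStats(device);  // zero the peak counters for this device
  alloc::emptyCache();            // hand cached, unused blocks back to cudaFree
}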

View File

@ -1,7 +1,7 @@
#pragma once
#include <c10/util/Exception.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <cuda.h>
// Note [CHECK macro]
@ -21,7 +21,7 @@
} \
} while (0)
#define C10_CUDA_CHECK_WARN(EXPR) \
#define C10_CUDA_CHECK_WARN(EXPR) \
do { \
cudaError_t __err = EXPR; \
if (__err != cudaSuccess) { \

View File

@ -101,7 +101,9 @@ DeviceIndex device_count() noexcept {
static int count = []() {
try {
auto result = device_count_impl(/*fail_if_no_driver=*/false);
TORCH_INTERNAL_ASSERT(result <= std::numeric_limits<DeviceIndex>::max(), "Too many CUDA devices, DeviceIndex overflowed");
TORCH_INTERNAL_ASSERT(
result <= std::numeric_limits<DeviceIndex>::max(),
"Too many CUDA devices, DeviceIndex overflowed");
return result;
} catch (const c10::Error& ex) {
// We don't want to fail, but still log the warning

View File

@ -27,7 +27,7 @@ struct C10_CUDA_API CUDAStreamCaptureModeGuard {
C10_CUDA_CHECK_WARN(cudaThreadExchangeStreamCaptureMode(&strictness_));
}
private:
private:
cudaStreamCaptureMode strictness_;
};
#endif
@ -35,15 +35,18 @@ struct C10_CUDA_API CUDAStreamCaptureModeGuard {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
// Protects against enum cudaStreamCaptureStatus implementation changes.
// Some compilers seem not to like static_assert without the messages.
static_assert(int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) == 0,
"unexpected int(cudaStreamCaptureStatusNone) value");
static_assert(int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive) == 1,
"unexpected int(cudaStreamCaptureStatusActive) value");
static_assert(int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) == 2,
"unexpected int(cudaStreamCaptureStatusInvalidated) value");
static_assert(
int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) == 0,
"unexpected int(cudaStreamCaptureStatusNone) value");
static_assert(
int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive) == 1,
"unexpected int(cudaStreamCaptureStatusActive) value");
static_assert(
int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) == 2,
"unexpected int(cudaStreamCaptureStatusInvalidated) value");
#endif
enum class CaptureStatus: int {
enum class CaptureStatus : int {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
None = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone),
Active = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive),
@ -54,7 +57,7 @@ enum class CaptureStatus: int {
};
inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) {
switch(status) {
switch (status) {
case CaptureStatus::None:
os << "cudaStreamCaptureStatusNone";
break;
@ -67,9 +70,8 @@ inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) {
break;
#endif
default:
TORCH_INTERNAL_ASSERT(false,
"Unknown CUDA graph CaptureStatus",
int(status));
TORCH_INTERNAL_ASSERT(
false, "Unknown CUDA graph CaptureStatus", int(status));
}
return os;
}
@ -78,13 +80,13 @@ inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) {
inline CaptureStatus currentStreamCaptureStatusMayInitCtx() {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
cudaStreamCaptureStatus is_capturing;
C10_CUDA_CHECK(cudaStreamIsCapturing(c10::cuda::getCurrentCUDAStream(),
&is_capturing));
C10_CUDA_CHECK(
cudaStreamIsCapturing(c10::cuda::getCurrentCUDAStream(), &is_capturing));
return CaptureStatus(is_capturing);
#else
return CaptureStatus::None;
#endif
}
} // namespace c10
} // namespace cuda
} // namespace c10
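For reference, a small sketch of guarding work that must not run during CUDA graph capture, using only what is declared in this header (the include path is assumed):

#include <c10/cuda/CUDAGraphsC10Utils.h> // assumed path for this header
#include <iostream>

void warn_if_capturing() {
  auto status = c10::cuda::currentStreamCaptureStatusMayInitCtx();
  if (status != c10::cuda::CaptureStatus::None) {
    // operator<< above prints the capture status name
    std::cout << "current stream is capturing: " << status << std::endl;
  }
}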

View File

@ -1,16 +1,18 @@
#pragma once
#include <c10/cuda/impl/CUDAGuardImpl.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/core/DeviceType.h>
#include <c10/core/impl/InlineDeviceGuard.h>
#include <c10/core/impl/InlineStreamGuard.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/cuda/impl/CUDAGuardImpl.h>
#include <cstddef>
namespace c10 { namespace cuda {
namespace c10 {
namespace cuda {
// This code is kind of boilerplatey. See Note [Whither the DeviceGuard boilerplate]
// This code is kind of boilerplatey. See Note [Whither the DeviceGuard
// boilerplate]
/// A variant of DeviceGuard that is specialized for CUDA. It accepts
/// integer indices (interpreting them as CUDA devices) and is a little
@ -38,22 +40,32 @@ struct CUDAGuard {
/// Sets the CUDA device to the given device. Errors if the given device
/// is not a CUDA device.
void set_device(Device device) { guard_.set_device(device); }
void set_device(Device device) {
guard_.set_device(device);
}
/// Sets the CUDA device to the given device. Errors if the given device
/// is not a CUDA device. (This method is provided for uniformity with
/// DeviceGuard).
void reset_device(Device device) { guard_.reset_device(device); }
void reset_device(Device device) {
guard_.reset_device(device);
}
/// Sets the CUDA device to the given device index.
void set_index(DeviceIndex device_index) { guard_.set_index(device_index); }
void set_index(DeviceIndex device_index) {
guard_.set_index(device_index);
}
/// Returns the device that was set upon construction of the guard
Device original_device() const { return guard_.original_device(); }
Device original_device() const {
return guard_.original_device();
}
/// Returns the last device that was set via `set_device`, if any, otherwise the
/// device passed during construction.
Device current_device() const { return guard_.current_device(); }
/// Returns the last device that was set via `set_device`, if any, otherwise
/// the device passed during construction.
Device current_device() const {
return guard_.current_device();
}
private:
/// The guard for the current device.
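A short sketch of the RAII behavior described above; the DeviceIndex constructor used here is not visible in this hunk and is assumed:

#include <c10/cuda/CUDAGuard.h>

void run_on_device_1() {
  c10::cuda::CUDAGuard guard(1); // assumed index constructor; switches to CUDA device 1
  // ... enqueue work on device 1; current_device() now reports it ...
} // destructor restores the original device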
@ -67,11 +79,13 @@ struct OptionalCUDAGuard {
explicit OptionalCUDAGuard() : guard_() {}
/// Set the current CUDA device to the passed Device, if it is not nullopt.
explicit OptionalCUDAGuard(optional<Device> device_opt) : guard_(device_opt) {}
explicit OptionalCUDAGuard(optional<Device> device_opt)
: guard_(device_opt) {}
/// Set the current CUDA device to the passed device index, if it is not
/// nullopt
explicit OptionalCUDAGuard(optional<DeviceIndex> device_index_opt) : guard_(device_index_opt) {}
explicit OptionalCUDAGuard(optional<DeviceIndex> device_index_opt)
: guard_(device_index_opt) {}
// Copy is not allowed
OptionalCUDAGuard(const OptionalCUDAGuard&) = delete;
@ -84,31 +98,45 @@ struct OptionalCUDAGuard {
OptionalCUDAGuard& operator=(OptionalCUDAGuard&& other) = delete;
/// Sets the CUDA device to the given device, initializing the guard if it
/// is not already initialized. Errors if the given device is not a CUDA device.
void set_device(Device device) { guard_.set_device(device); }
/// is not already initialized. Errors if the given device is not a CUDA
/// device.
void set_device(Device device) {
guard_.set_device(device);
}
/// Sets the CUDA device to the given device, initializing the guard if it is
/// not already initialized. Errors if the given device is not a CUDA device.
/// (This method is provided for uniformity with OptionalDeviceGuard).
void reset_device(Device device) { guard_.reset_device(device); }
void reset_device(Device device) {
guard_.reset_device(device);
}
/// Sets the CUDA device to the given device index, initializing the guard if
/// it is not already initialized.
void set_index(DeviceIndex device_index) { guard_.set_index(device_index); }
void set_index(DeviceIndex device_index) {
guard_.set_index(device_index);
}
/// Returns the device that was set immediately prior to initialization of the
/// guard, or nullopt if the guard is uninitialized.
optional<Device> original_device() const { return guard_.original_device(); }
optional<Device> original_device() const {
return guard_.original_device();
}
/// Returns the most recent device that was set using this device guard,
/// either from construction, or via set_device, if the guard is initialized,
/// or nullopt if the guard is uninitialized.
optional<Device> current_device() const { return guard_.current_device(); }
optional<Device> current_device() const {
return guard_.current_device();
}
/// Restore the original CUDA device, resetting this guard to uninitialized state.
void reset() { guard_.reset(); }
/// Restore the original CUDA device, resetting this guard to uninitialized
/// state.
void reset() {
guard_.reset();
}
private:
private:
c10::impl::InlineOptionalDeviceGuard<impl::CUDAGuardImpl> guard_;
};
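And a sketch of the optional variant, which defers any device switch until a device is actually known:

#include <c10/core/Device.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/Optional.h>

void maybe_switch(c10::optional<c10::Device> maybe_device) {
  c10::cuda::OptionalCUDAGuard guard; // uninitialized: no device change yet
  if (maybe_device.has_value()) {
    guard.set_device(*maybe_device); // initializes the guard and switches
  }
  // runs on *maybe_device if it was set, otherwise on the caller's device
} // original device restored only if the guard was initialized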
@ -118,17 +146,17 @@ struct CUDAStreamGuard {
/// No default constructor, see Note [Omitted default constructor from RAII]
explicit CUDAStreamGuard() = delete;
/// Set the current CUDA device to the device associated with the passed stream,
/// and set the current CUDA stream on that device to the passed stream.
/// Errors if the Stream is not a CUDA stream.
/// Set the current CUDA device to the device associated with the passed
/// stream, and set the current CUDA stream on that device to the passed
/// stream. Errors if the Stream is not a CUDA stream.
explicit CUDAStreamGuard(Stream stream) : guard_(stream) {}
/// Copy is disallowed
CUDAStreamGuard(const CUDAStreamGuard&) = delete;
CUDAStreamGuard& operator=(const CUDAStreamGuard&) = delete;
/// Move is disallowed, as CUDAStreamGuard does not have an uninitialized state,
/// which is required for moves on types with nontrivial destructors.
/// Move is disallowed, as CUDAStreamGuard does not have an uninitialized
/// state, which is required for moves on types with nontrivial destructors.
CUDAStreamGuard(CUDAStreamGuard&& other) = delete;
CUDAStreamGuard& operator=(CUDAStreamGuard&& other) = delete;
@ -144,9 +172,12 @@ struct CUDAStreamGuard {
/// WARNING: reset_stream does NOT preserve previously set streams on
/// different devices. If you need to set streams on multiple devices
/// on CUDA, use CUDAMultiStreamGuard instead.
void reset_stream(Stream stream) { guard_.reset_stream(stream); }
void reset_stream(Stream stream) {
guard_.reset_stream(stream);
}
/// Returns the CUDA stream that was set at the time the guard was constructed.
/// Returns the CUDA stream that was set at the time the guard was
/// constructed.
CUDAStream original_stream() const {
return CUDAStream(CUDAStream::UNCHECKED, guard_.original_stream());
}
@ -159,31 +190,36 @@ struct CUDAStreamGuard {
/// Returns the most recent CUDA device that was set using this device guard,
/// either from construction, or via set_device/reset_device/set_index.
Device current_device() const { return guard_.current_device(); }
Device current_device() const {
return guard_.current_device();
}
/// Returns the CUDA device that was set at the most recent reset_stream(),
/// or otherwise the device at construction time.
Device original_device() const { return guard_.original_device(); }
Device original_device() const {
return guard_.original_device();
}
private:
private:
c10::impl::InlineStreamGuard<impl::CUDAGuardImpl> guard_;
};
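A sketch of CUDAStreamGuard in use; getStreamFromPool() is assumed here as the stream-pool accessor and is not part of this hunk:

#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

void launch_on_side_stream() {
  // Assumed pool accessor; any CUDAStream works since it converts to Stream.
  c10::cuda::CUDAStream side_stream = c10::cuda::getStreamFromPool();
  c10::cuda::CUDAStreamGuard guard(side_stream);
  // kernels launched here target side_stream on its device
} // original stream and device restored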
/// A variant of OptionalStreamGuard that is specialized for CUDA. See CUDAGuard
/// for when you can use this.
/// A variant of OptionalStreamGuard that is specialized for CUDA. See
/// CUDAGuard for when you can use this.
struct OptionalCUDAStreamGuard {
/// Create an uninitialized guard.
explicit OptionalCUDAStreamGuard() : guard_() {}
/// Set the current CUDA device to the device associated with the passed stream,
/// and set the current CUDA stream on that device to the passed stream.
/// Errors if the Stream is not a CUDA stream.
/// Set the current CUDA device to the device associated with the passed
/// stream, and set the current CUDA stream on that device to the passed
/// stream. Errors if the Stream is not a CUDA stream.
explicit OptionalCUDAStreamGuard(Stream stream) : guard_(stream) {}
/// Set the current device to the device associated with the passed stream,
/// and set the current stream on that device to the passed stream,
/// if the passed stream is not nullopt.
explicit OptionalCUDAStreamGuard(optional<Stream> stream_opt) : guard_(stream_opt) {}
explicit OptionalCUDAStreamGuard(optional<Stream> stream_opt)
: guard_(stream_opt) {}
/// Copy is disallowed
OptionalCUDAStreamGuard(const OptionalCUDAStreamGuard&) = delete;
@ -200,10 +236,12 @@ struct OptionalCUDAStreamGuard {
/// set the current device to the device associated with the passed stream,
/// and set the current stream on that device to the passed stream.
/// Initializes the guard if it was not previously initialized.
void reset_stream(Stream stream) { guard_.reset_stream(stream); }
void reset_stream(Stream stream) {
guard_.reset_stream(stream);
}
/// Returns the CUDA stream that was set at the time the guard was most recently
/// initialized, or nullopt if the guard is uninitialized.
/// Returns the CUDA stream that was set at the time the guard was most
/// recently initialized, or nullopt if the guard is uninitialized.
optional<CUDAStream> original_stream() const {
auto r = guard_.original_stream();
if (r.has_value()) {
@ -214,8 +252,8 @@ struct OptionalCUDAStreamGuard {
}
/// Returns the most recent CUDA stream that was set using this stream guard,
/// either from construction, or via reset_stream, if the guard is initialized,
/// or nullopt if the guard is uninitialized.
/// either from construction, or via reset_stream, if the guard is
/// initialized, or nullopt if the guard is uninitialized.
optional<CUDAStream> current_stream() const {
auto r = guard_.current_stream();
if (r.has_value()) {
@ -225,17 +263,20 @@ struct OptionalCUDAStreamGuard {
}
}
/// Restore the original CUDA device and stream, resetting this guard to uninitialized state.
void reset() { guard_.reset(); }
/// Restore the original CUDA device and stream, resetting this guard to
/// uninitialized state.
void reset() {
guard_.reset();
}
private:
private:
c10::impl::InlineOptionalStreamGuard<impl::CUDAGuardImpl> guard_;
};
/// A variant of MultiStreamGuard that is specialized for CUDA.
struct CUDAMultiStreamGuard {
explicit CUDAMultiStreamGuard(ArrayRef<CUDAStream> streams)
: guard_(unwrapStreams(streams)) {}
: guard_(unwrapStreams(streams)) {}
/// Copy is disallowed
CUDAMultiStreamGuard(const CUDAMultiStreamGuard&) = delete;
@ -247,7 +288,7 @@ struct CUDAMultiStreamGuard {
// See Note [Move assignment for RAII guards is tricky]
CUDAMultiStreamGuard& operator=(CUDAMultiStreamGuard&& other) = delete;
private:
private:
c10::impl::InlineMultiStreamGuard<impl::CUDAGuardImpl> guard_;
static std::vector<Stream> unwrapStreams(ArrayRef<CUDAStream> cudaStreams) {

View File

@ -1,6 +1,6 @@
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
@ -204,7 +204,7 @@ static void initGlobalStreamState() {
"). Increase that and recompile.");
// Initializes default streams
for (const auto i: c10::irange(num_gpus)) {
for (const auto i : c10::irange(num_gpus)) {
default_streams[i].device_index = i;
low_priority_counters[i] = 0;
high_priority_counters[i] = 0;
@ -218,7 +218,7 @@ static void initDeviceStreamState(DeviceIndex device_index) {
// with it.
CUDAGuard device_guard{device_index};
for (const auto i: c10::irange(kStreamsPerPool)) {
for (const auto i : c10::irange(kStreamsPerPool)) {
auto& lowpri_stream = low_priority_streams[device_index][i];
auto& hipri_stream = high_priority_streams[device_index][i];
@ -244,7 +244,7 @@ static void initCUDAStreamsOnce() {
// Inits current streams (thread local) to default streams
current_streams =
(LeakyStreamInternals**)malloc(num_gpus * sizeof(LeakyStreamInternals*));
for (const auto i: c10::irange(num_gpus)) {
for (const auto i : c10::irange(num_gpus)) {
current_streams[i] = &default_streams[i];
}
}

View File

@ -5,53 +5,53 @@
#include <cuda_runtime_api.h>
#include <c10/core/DeviceGuard.h>
#include <c10/core/Stream.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/core/DeviceGuard.h>
#include <c10/util/Exception.h>
#include <c10/core/Stream.h>
/*
* Stream pool note.
*
* A CUDAStream is an abstraction of an actual cuStream on the GPU. CUDAStreams
* are backed by cuStreams, but they use several pools to minimize the costs
* associated with creating, retaining, and destroying cuStreams.
*
* There are three pools per device, and a device's pools are lazily created.
*
* The first pool contains only the default stream. When the default stream
* is requested it's returned.
*
* The second pool is the "low priority" or "default priority" streams. In
* HIP builds there is no distinction between streams in this pool and streams
* in the third pool (below). There are 32 of these streams per device, and
* when a stream is requested one of these streams is returned round-robin.
* That is, the first stream requested is at index 0, the second at index 1...
* to index 31, then index 0 again.
*
* This means that if 33 low priority streams are requested, the first and
* last streams requested are actually the same stream (under the covers)
* and kernels enqueued on them cannot run concurrently.
*
* The third pool is the "high priority" streams. The third pool acts like
* the second pool except the streams are created with a higher priority.
*
* These pools suggest that stream users should prefer many short-lived streams,
* as the cost of acquiring and releasing streams is effectively zero. If
* many longer-lived streams are required in performance critical scenarios
* then the functionality here may need to be extended to allow, for example,
* "reserving" a subset of the pool so that other streams do not accidentally
* overlap the performance critical streams.
*
* Note: although the notion of "current stream for device" is thread local
* (every OS thread has a separate current stream, as one might expect),
* the stream pool is global across all threads; stream 0 is always stream 0
* no matter which thread you use it on. Multiple threads can synchronize
* on the same stream. Although the CUDA documentation is not very clear
* on the matter, streams are thread safe; e.g., it is safe to enqueue
* a kernel on the same stream from two different threads.
*/
* Stream pool note.
*
* A CUDAStream is an abstraction of an actual cuStream on the GPU. CUDAStreams
* are backed by cuStreams, but they use several pools to minimize the costs
* associated with creating, retaining, and destroying cuStreams.
*
* There are three pools per device, and a device's pools are lazily created.
*
* The first pool contains only the default stream. When the default stream
* is requested it's returned.
*
* The second pool is the "low priority" or "default priority" streams. In
* HIP builds there is no distinction between streams in this pool and streams
* in the third pool (below). There are 32 of these streams per device, and
* when a stream is requested one of these streams is returned round-robin.
* That is, the first stream requested is at index 0, the second at index 1...
* to index 31, then index 0 again.
*
* This means that if 33 low priority streams are requested, the first and
* last streams requested are actually the same stream (under the covers)
* and kernels enqueued on them cannot run concurrently.
*
* The third pool is the "high priority" streams. The third pool acts like
* the second pool except the streams are created with a higher priority.
*
* These pools suggest that stream users should prefer many short-lived streams,
* as the cost of acquiring and releasing streams is effectively zero. If
* many longer-lived streams are required in performance critical scenarios
* then the functionality here may need to be extended to allow, for example,
* "reserving" a subset of the pool so that other streams do not accidentally
* overlap the performance critical streams.
*
* Note: although the notion of "current stream for device" is thread local
* (every OS thread has a separate current stream, as one might expect),
* the stream pool is global across all threads; stream 0 is always stream 0
* no matter which thread you use it on. Multiple threads can synchronize
* on the same stream. Although the CUDA documentation is not very clear
* on the matter, streams are thread safe; e.g., it is safe to enqueue
* a kernel on the same stream from two different threads.
*/
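A sketch of the pool behavior described above; getStreamFromPool() is assumed to be the pool accessor declared elsewhere in this header, and setCurrentCUDAStream() (declared below) makes one of the returned streams the thread-local current stream:

#include <c10/cuda/CUDAStream.h>

void use_pool_streams() {
  // Two consecutive requests from the low-priority pool land on indices 0 and 1.
  auto s0 = c10::cuda::getStreamFromPool();
  auto s1 = c10::cuda::getStreamFromPool();
  // Request a stream from the high-priority pool instead.
  auto hi = c10::cuda::getStreamFromPool(/*isHighPriority=*/true);
  // Make s0 the current stream for this thread on its device.
  c10::cuda::setCurrentCUDAStream(s0);
  (void)s1;
  (void)hi;
}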
namespace c10 {
namespace cuda {
@ -61,8 +61,7 @@ namespace cuda {
// functionality (conversion to cudaStream_t), and a guarantee that
// the wrapped c10::Stream really is a CUDA stream.
class C10_CUDA_API CUDAStream {
public:
public:
enum Unchecked { UNCHECKED };
/// Construct a CUDAStream from a Stream. This construction is checked,
@ -85,21 +84,31 @@ public:
}
/// Implicit conversion to cudaStream_t.
operator cudaStream_t() const { return stream(); }
operator cudaStream_t() const {
return stream();
}
/// Implicit conversion to Stream (a.k.a., forget that the stream is a
/// CUDA stream).
operator Stream() const { return unwrap(); }
operator Stream() const {
return unwrap();
}
/// Get the CUDA device index that this stream is associated with.
DeviceIndex device_index() const { return stream_.device_index(); }
DeviceIndex device_index() const {
return stream_.device_index();
}
/// Get the full Device that this stream is associated with. The Device
/// is guaranteed to be a CUDA device.
Device device() const { return Device(DeviceType::CUDA, device_index()); }
Device device() const {
return Device(DeviceType::CUDA, device_index());
}
/// Return the stream ID corresponding to this particular stream.
StreamId id() const { return stream_.id(); }
StreamId id() const {
return stream_.id();
}
bool query() const {
DeviceGuard guard{stream_.device()};
@ -120,17 +129,19 @@ public:
}
int priority() const {
DeviceGuard guard{stream_.device()};
int priority = 0;
C10_CUDA_CHECK(cudaStreamGetPriority(stream(), &priority));
return priority;
DeviceGuard guard{stream_.device()};
int priority = 0;
C10_CUDA_CHECK(cudaStreamGetPriority(stream(), &priority));
return priority;
}
/// Explicit conversion to cudaStream_t.
cudaStream_t stream() const;
/// Explicit conversion to Stream.
Stream unwrap() const { return stream_; }
Stream unwrap() const {
return stream_;
}
/// Reversibly pack a CUDAStream into a uint64_t representation. This may
/// be helpful when storing a CUDAStream in a C struct, where you cannot
@ -150,22 +161,24 @@ public:
}
static std::tuple<int, int> priority_range() {
// Note: this returns the range of priority **supported by PyTorch**, not
// the range of priority **supported by CUDA**. The former is a subset of
// the latter. Currently PyTorch only supports 0 and -1, which are "low" and
// "high" priority.
int least_priority, greatest_priority;
C10_CUDA_CHECK(
// Note: this returns the range of priority **supported by PyTorch**, not
// the range of priority **supported by CUDA**. The former is a subset of
// the latter. Currently PyTorch only supports 0 and -1, which are "low" and
// "high" priority.
int least_priority, greatest_priority;
C10_CUDA_CHECK(
cudaDeviceGetStreamPriorityRange(&least_priority, &greatest_priority));
TORCH_INTERNAL_ASSERT(least_priority >= 0, "Unexpected CUDA stream priority range");
TORCH_INTERNAL_ASSERT(greatest_priority <= -1, "Unexpected CUDA stream priority range");
return std::make_tuple(0, -1);
TORCH_INTERNAL_ASSERT(
least_priority >= 0, "Unexpected CUDA stream priority range");
TORCH_INTERNAL_ASSERT(
greatest_priority <= -1, "Unexpected CUDA stream priority range");
return std::make_tuple(0, -1);
}
// Deleted for now; use CUDAEvent::block instead
// void synchronize_with(const CUDAEvent& event) const;
private:
private:
Stream stream_;
};
@ -214,13 +227,13 @@ TORCH_API void setCurrentCUDAStream(CUDAStream stream);
C10_API std::ostream& operator<<(std::ostream& stream, const CUDAStream& s);
} // namespace cuda
} // namespace at
} // namespace c10
namespace std {
template <>
struct hash<c10::cuda::CUDAStream> {
size_t operator()(c10::cuda::CUDAStream s) const noexcept {
return std::hash<c10::Stream>{}(s.unwrap());
}
};
template <>
struct hash<c10::cuda::CUDAStream> {
size_t operator()(c10::cuda::CUDAStream s) const noexcept {
return std::hash<c10::Stream>{}(s.unwrap());
}
};
} // namespace std
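The specialization above means a CUDAStream can be hashed directly (and, together with its equality operator, used as a key in unordered containers); a minimal sketch:

#include <c10/cuda/CUDAStream.h>
#include <cstddef>
#include <functional>

size_t hash_current_stream() {
  c10::cuda::CUDAStream s = c10::cuda::getCurrentCUDAStream();
  return std::hash<c10::cuda::CUDAStream>{}(s); // hashes the wrapped c10::Stream
}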

View File

@ -8,4 +8,6 @@ constexpr DeviceType CUDAGuardImpl::static_type;
C10_REGISTER_GUARD_IMPL(CUDA, CUDAGuardImpl);
}}} // namespace c10::cuda::detail
} // namespace impl
} // namespace cuda
} // namespace c10

View File

@ -4,10 +4,10 @@
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime_api.h>
@ -82,9 +82,7 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
}
// Event-related functions
void createEvent(
cudaEvent_t* cuda_event,
const EventFlag flag) const {
void createEvent(cudaEvent_t* cuda_event, const EventFlag flag) const {
// Maps PyTorch's Event::Flag to CUDA flag
auto cuda_flag = cudaEventDefault;
switch (flag) {
@ -103,10 +101,10 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
C10_CUDA_CHECK(cudaEventCreateWithFlags(cuda_event, cuda_flag));
}
void destroyEvent(
void* event,
const DeviceIndex device_index) const noexcept override {
if (!event) return;
void destroyEvent(void* event, const DeviceIndex device_index)
const noexcept override {
if (!event)
return;
auto cuda_event = static_cast<cudaEvent_t>(event);
int orig_device;
C10_CUDA_CHECK_WARN(cudaGetDevice(&orig_device));
@ -116,16 +114,17 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
}
void record(
void** event,
const Stream& stream,
const DeviceIndex device_index,
const EventFlag flag) const override {
TORCH_CHECK(device_index == -1 || device_index == stream.device_index(),
"Event device index ",
device_index,
" does not match recording stream's device index ",
stream.device_index(),
".");
void** event,
const Stream& stream,
const DeviceIndex device_index,
const EventFlag flag) const override {
TORCH_CHECK(
device_index == -1 || device_index == stream.device_index(),
"Event device index ",
device_index,
" does not match recording stream's device index ",
stream.device_index(),
".");
cudaEvent_t cuda_event = static_cast<cudaEvent_t>(*event);
CUDAStream cuda_stream{stream};
@ -135,7 +134,8 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
setDevice(stream.device());
// Creates the event (lazily)
if (!cuda_event) createEvent(&cuda_event, flag);
if (!cuda_event)
createEvent(&cuda_event, flag);
C10_CUDA_CHECK(cudaEventRecord(cuda_event, cuda_stream));
// Makes the void* point to the (possibly just allocated) CUDA event
*event = cuda_event;
@ -144,24 +144,24 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
setDevice(orig_device);
}
void block(
void* event,
const Stream& stream) const override {
if (!event) return;
void block(void* event, const Stream& stream) const override {
if (!event)
return;
cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
CUDAStream cuda_stream{stream};
const auto orig_device = getDevice();
setDevice(stream.device());
C10_CUDA_CHECK(cudaStreamWaitEvent(
cuda_stream,
cuda_event,
/*flags (must be zero)=*/ 0));
cuda_stream,
cuda_event,
/*flags (must be zero)=*/0));
setDevice(orig_device);
}
// May be called from any device
bool queryEvent(void* event) const override {
if (!event) return true;
if (!event)
return true;
cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
const cudaError_t err = cudaEventQuery(cuda_event);
if (err != cudaErrorNotReady) {
@ -170,12 +170,13 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
return (err == cudaSuccess);
}
void recordDataPtrOnStream(
const c10::DataPtr& data_ptr,
const Stream& stream) const override {
void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
const override {
CUDAStream cuda_stream{stream};
CUDACachingAllocator::recordStream(data_ptr, cuda_stream);
}
};
}}} // namespace c10::cuda::impl
} // namespace impl
} // namespace cuda
} // namespace c10

View File

@ -29,4 +29,6 @@ int c10_cuda_private_test() {
return 2;
}
}}} // namespace c10::cuda::impl
} // namespace impl
} // namespace cuda
} // namespace c10

View File

@ -8,4 +8,6 @@ namespace impl {
C10_CUDA_API int c10_cuda_test();
}}} /// namespace c10::cuda::impl
}
} // namespace cuda
} // namespace c10

View File

@ -102,17 +102,17 @@
// two pieces with confusing names?
// Once upon a time, there _was_ only TORCH_CUDA_API. All was happy until we
// tried to compile PyTorch for CUDA 11.1, which ran into relocation marker
// issues when linking big binaries. (https://github.com/pytorch/pytorch/issues/39968)
// We had two choices:
// issues when linking big binaries.
// (https://github.com/pytorch/pytorch/issues/39968) We had two choices:
// (1) Stop supporting so many GPU architectures
// (2) Do something else
// We chose #2 and decided to split the behemoth that was torch_cuda into two
// smaller libraries, one with most of the core kernel functions (torch_cuda_cu)
// and the other that had..well..everything else (torch_cuda_cpp). The idea was this:
// instead of linking our static libraries (like the hefty libcudnn_static.a) with
// another huge library, torch_cuda, and run into pesky relocation marker issues,
// we could link our static libraries to a smaller part of torch_cuda (torch_cuda_cpp)
// and avoid the issues.
// and the other that had..well..everything else (torch_cuda_cpp). The idea was
// this: instead of linking our static libraries (like the hefty
// libcudnn_static.a) with another huge library, torch_cuda, and run into pesky
// relocation marker issues, we could link our static libraries to a smaller
// part of torch_cuda (torch_cuda_cpp) and avoid the issues.
// libtorch_cuda_cu.so
#ifdef TORCH_CUDA_CU_BUILD_MAIN_LIB
@ -128,7 +128,8 @@
#define TORCH_CUDA_CPP_API C10_IMPORT
#endif
// libtorch_cuda.so (where torch_cuda_cu and torch_cuda_cpp are a part of the same api)
// libtorch_cuda.so (where torch_cuda_cu and torch_cuda_cpp are a part of the
// same api)
#ifdef TORCH_CUDA_BUILD_MAIN_LIB
#define TORCH_CUDA_CPP_API C10_EXPORT
#define TORCH_CUDA_CU_API C10_EXPORT

View File

@ -24,16 +24,17 @@
#include <c10/macros/Export.h>
#if defined(__clang__)
#define __ubsan_ignore_float_divide_by_zero__ __attribute__((no_sanitize("float-divide-by-zero")))
#define __ubsan_ignore_undefined__ __attribute__((no_sanitize("undefined")))
#define __ubsan_ignore_signed_int_overflow__ __attribute__((no_sanitize("signed-integer-overflow")))
#define __ubsan_ignore_float_divide_by_zero__ \
__attribute__((no_sanitize("float-divide-by-zero")))
#define __ubsan_ignore_undefined__ __attribute__((no_sanitize("undefined")))
#define __ubsan_ignore_signed_int_overflow__ \
__attribute__((no_sanitize("signed-integer-overflow")))
#else
#define __ubsan_ignore_float_divide_by_zero__
#define __ubsan_ignore_undefined__
#define __ubsan_ignore_signed_int_overflow__
#define __ubsan_ignore_float_divide_by_zero__
#define __ubsan_ignore_undefined__
#define __ubsan_ignore_signed_int_overflow__
#endif
// Detect address sanitizer as some stuff doesn't work with it
#undef C10_ASAN_ENABLED
@ -57,7 +58,6 @@
#define C10_ASAN_ENABLED 0
#endif
// Disable the copy and assignment operator for a class. Note that this will
// disable the usage of the class in std containers.
#define C10_DISABLE_COPY_AND_ASSIGN(classname) \
@ -84,7 +84,6 @@
#define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __LINE__)
#endif
/// C10_NODISCARD - Warn if a type or return value is discarded.
// Technically, we should check if __cplusplus > 201402L here, because
@ -108,24 +107,25 @@
// - gcc 8.3: https://godbolt.org/z/4tLMQS (always advertises support)
#define C10_NODISCARD
#if defined(__has_cpp_attribute)
# if __has_cpp_attribute(nodiscard)
# undef C10_NODISCARD
# define C10_NODISCARD [[nodiscard]]
# endif
#if __has_cpp_attribute(nodiscard)
#undef C10_NODISCARD
#define C10_NODISCARD [[nodiscard]]
#endif
// Workaround for llvm.org/PR23435, since clang 3.6 and below emit a spurious
// error when __has_cpp_attribute is given a scoped attribute in C mode.
#elif __cplusplus && defined(__has_cpp_attribute)
# if __has_cpp_attribute(clang::warn_unused_result)
// TODO: It's possible this is still triggering https://github.com/pytorch/pytorch/issues/13118
// on Windows; if it is, better fix it.
# undef C10_NODISCARD
# define C10_NODISCARD [[clang::warn_unused_result]]
# endif
#if __has_cpp_attribute(clang::warn_unused_result)
// TODO: It's possible this is still triggering
// https://github.com/pytorch/pytorch/issues/13118 on Windows; if it is, better
// fix it.
#undef C10_NODISCARD
#define C10_NODISCARD [[clang::warn_unused_result]]
#endif
#endif
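For illustration, a hypothetical declaration using the macro; on compilers that support one of the attributes above, discarding the result produces a warning:

#include <c10/macros/Macros.h>
#include <cstddef>

C10_NODISCARD inline bool try_reserve(size_t bytes) { // hypothetical helper
  return bytes <= (1u << 20);
}

void caller() {
  try_reserve(64);           // may warn: return value discarded
  bool ok = try_reserve(64); // fine
  (void)ok;
}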
// suppress an unused variable.
#if defined(_MSC_VER) && !defined(__clang__)
#define C10_UNUSED __pragma(warning(suppress: 4100 4101))
#define C10_UNUSED __pragma(warning(suppress : 4100 4101))
#else
#define C10_UNUSED __attribute__((__unused__))
#endif //_MSC_VER
@ -135,16 +135,28 @@
// Simply define the namespace, in case a dependent library want to refer to
// the c10 namespace but not any nontrivial files.
namespace c10 {} // namespace c10
namespace c10 { namespace cuda {} }
namespace c10 { namespace hip {} }
namespace c10 {
namespace cuda {}
} // namespace c10
namespace c10 {
namespace hip {}
} // namespace c10
// Since C10 is the core library for caffe2 (and aten), we will simply reroute
// all abstractions defined in c10 to be available in caffe2 as well.
// This is only for backwards compatibility. Please use the symbols from the
// c10 namespace where possible.
namespace caffe2 { using namespace c10; }
namespace at { using namespace c10; }
namespace at { namespace cuda { using namespace c10::cuda; }}
namespace caffe2 {
using namespace c10;
}
namespace at {
using namespace c10;
}
namespace at {
namespace cuda {
using namespace c10::cuda;
}
} // namespace at
// WARNING!!! THIS IS A GIANT HACK!!!
// This line means you cannot simultaneously include c10/hip
@ -154,7 +166,11 @@ namespace at { namespace cuda { using namespace c10::cuda; }}
// from at::cuda. This namespace makes that happen. When
// HIPIFY is no longer out-of-place, we can switch the cuda
// here to hip and everyone is happy.
namespace at { namespace cuda { using namespace c10::hip; }}
namespace at {
namespace cuda {
using namespace c10::hip;
}
} // namespace at
// C10_LIKELY/C10_UNLIKELY
//
@ -169,11 +185,11 @@ namespace at { namespace cuda { using namespace c10::hip; }}
// without it.
//
#if defined(__GNUC__) || defined(__ICL) || defined(__clang__)
#define C10_LIKELY(expr) (__builtin_expect(static_cast<bool>(expr), 1))
#define C10_UNLIKELY(expr) (__builtin_expect(static_cast<bool>(expr), 0))
#define C10_LIKELY(expr) (__builtin_expect(static_cast<bool>(expr), 1))
#define C10_UNLIKELY(expr) (__builtin_expect(static_cast<bool>(expr), 0))
#else
#define C10_LIKELY(expr) (expr)
#define C10_UNLIKELY(expr) (expr)
#define C10_LIKELY(expr) (expr)
#define C10_UNLIKELY(expr) (expr)
#endif
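A small sketch of the intended use: annotate the branch you expect to be cold so the compiler favors the hot path.

#include <c10/macros/Macros.h>
#include <stdexcept>

int checked_divide(int a, int b) {
  if (C10_UNLIKELY(b == 0)) {
    // rarely taken; the hint keeps the hot path straight-line
    throw std::runtime_error("division by zero");
  }
  return a / b;
}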
/// C10_NOINLINE - Functions whose declaration is annotated with this will not
@ -209,11 +225,13 @@ namespace at { namespace cuda { using namespace c10::hip; }}
#define C10_HOST_DEVICE __host__ __device__
#define C10_DEVICE __device__
#define C10_HOST __host__
// constants from (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications)
// The maximum number of threads per multiprocessor is 1024 for Turing architecture (7.5),
// 1536 for Geforce Ampere (8.6),
// and 2048 for all other architectures. You'll get warnings if you exceed these constants.
// Hence, the following macros adjust the input values from the user to resolve potential warnings.
// constants from
// (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications)
// The maximum number of threads per multiprocessor is 1024 for Turing
// architecture (7.5), 1536 for Geforce Ampere (8.6), and 2048 for all other
// architectures. You'll get warnings if you exceed these constants. Hence, the
// following macros adjust the input values from the user to resolve potential
// warnings.
#if __CUDA_ARCH__ == 750
constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1024;
#elif __CUDA_ARCH__ == 860
@ -223,25 +241,39 @@ constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 2048;
#endif
// CUDA_MAX_THREADS_PER_BLOCK is the same for all architectures currently
constexpr uint32_t CUDA_MAX_THREADS_PER_BLOCK = 1024;
// CUDA_THREADS_PER_BLOCK_FALLBACK is the "canonical fallback" choice of block size.
// 256 is a good number for this fallback and should give good occupancy and
// versatility across all architectures.
// CUDA_THREADS_PER_BLOCK_FALLBACK is the "canonical fallback" choice of block
// size. 256 is a good number for this fallback and should give good occupancy
// and versatility across all architectures.
constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;
// NOTE: if you are thinking of constexpr-ify the inputs to launch bounds, it
// turns out that although __launch_bounds__ can take constexpr, it
// can't take a constexpr that has anything to do with templates.
// Currently we use launch_bounds that depend on template arguments in
// Loops.cuh, Reduce.cuh and LossCTC.cuh. Hence, C10_MAX_THREADS_PER_BLOCK and
// C10_MIN_BLOCKS_PER_SM are kept as macros.
// Suppose you were planning to write __launch_bounds__(a, b), based on your performance tuning on a modern GPU.
// Instead, you should write __launch_bounds__(C10_MAX_THREADS_PER_BLOCK(a), C10_MIN_BLOCKS_PER_SM(a, b)),
// Loops.cuh, Reduce.cuh and LossCTC.cuh. Hence, C10_MAX_THREADS_PER_BLOCK
// and C10_MIN_BLOCKS_PER_SM are kept as macros.
// Suppose you were planning to write __launch_bounds__(a, b), based on your
// performance tuning on a modern GPU. Instead, you should write
// __launch_bounds__(C10_MAX_THREADS_PER_BLOCK(a), C10_MIN_BLOCKS_PER_SM(a, b)),
// which will also properly respect limits on old architectures.
#define C10_MAX_THREADS_PER_BLOCK(val) (((val) <= CUDA_MAX_THREADS_PER_BLOCK) ? (val) : CUDA_THREADS_PER_BLOCK_FALLBACK)
#define C10_MIN_BLOCKS_PER_SM(threads_per_block, blocks_per_sm) ((((threads_per_block)*(blocks_per_sm) <= CUDA_MAX_THREADS_PER_SM) ? (blocks_per_sm) : ((CUDA_MAX_THREADS_PER_SM + (threads_per_block) - 1) / (threads_per_block))))
#define C10_MAX_THREADS_PER_BLOCK(val) \
(((val) <= CUDA_MAX_THREADS_PER_BLOCK) ? (val) \
: CUDA_THREADS_PER_BLOCK_FALLBACK)
#define C10_MIN_BLOCKS_PER_SM(threads_per_block, blocks_per_sm) \
((((threads_per_block) * (blocks_per_sm) <= CUDA_MAX_THREADS_PER_SM) \
? (blocks_per_sm) \
: ((CUDA_MAX_THREADS_PER_SM + (threads_per_block)-1) / \
(threads_per_block))))
// C10_LAUNCH_BOUNDS is analogous to __launch_bounds__
#define C10_LAUNCH_BOUNDS_0 __launch_bounds__(256, 4) // default launch bounds that should give good occupancy and versatility across all architectures.
#define C10_LAUNCH_BOUNDS_1(max_threads_per_block) __launch_bounds__((C10_MAX_THREADS_PER_BLOCK((max_threads_per_block))))
#define C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm) __launch_bounds__((C10_MAX_THREADS_PER_BLOCK((max_threads_per_block))), (C10_MIN_BLOCKS_PER_SM((max_threads_per_block), (min_blocks_per_sm))))
#define C10_LAUNCH_BOUNDS_0 \
__launch_bounds__( \
256, 4) // default launch bounds that should give good occupancy and
// versatility across all architectures.
#define C10_LAUNCH_BOUNDS_1(max_threads_per_block) \
__launch_bounds__((C10_MAX_THREADS_PER_BLOCK((max_threads_per_block))))
#define C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm) \
__launch_bounds__( \
(C10_MAX_THREADS_PER_BLOCK((max_threads_per_block))), \
(C10_MIN_BLOCKS_PER_SM((max_threads_per_block), (min_blocks_per_sm))))
#else
#define C10_HOST_DEVICE
#define C10_HOST
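A hypothetical kernel sketch showing the pattern the note above recommends: pass the tuned values through the macros so older architectures get clamped limits instead of warnings.

#include <c10/macros/Macros.h>

// Tuned for 512 threads/block and 4 blocks/SM on a recent GPU; the macros
// clamp these on architectures with lower per-SM thread limits.
C10_LAUNCH_BOUNDS_2(512, 4)
__global__ void scale_kernel(float* data, float alpha, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    data[i] *= alpha;
  }
}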
@ -273,14 +305,12 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;
#elif defined(_MSC_VER)
#if defined(NDEBUG)
extern "C" {
C10_IMPORT
C10_IMPORT
#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__) || defined(__HIP__)
__host__ __device__
__host__ __device__
#endif // __CUDA_ARCH__
void _wassert(
wchar_t const* _Message,
wchar_t const* _File,
unsigned _Line);
void
_wassert(wchar_t const* _Message, wchar_t const* _File, unsigned _Line);
}
#endif
#define CUDA_KERNEL_ASSERT(cond) \
@ -304,8 +334,8 @@ __host__ __device__
#endif // NDEBUG
#define CUDA_KERNEL_ASSERT(cond) \
if (C10_UNLIKELY(!(cond))) { \
__assert_fail(#cond, __FILE__, static_cast<unsigned int>(__LINE__), \
__func__); \
__assert_fail( \
#cond, __FILE__, static_cast<unsigned int>(__LINE__), __func__); \
}
#endif // __APPLE__

View File

@ -19,16 +19,15 @@
* See below for more information.
* Why?
* It has been observed that some mobile platforms, such as pixel 3, return
* memory aggressively to the system. This results in page faults in some cases
* and ends up hurting performance. This caching allocator aims to address that.
* Furthermore it also allows users to specify their own allocator by implementing
* allocate/free virtual interfaces.
* What are the cons?
* There are some cons that were observed where use of caching allocator led to
* worse performance on some platforms. Reason being that the caching mechanism
* used by this allocator left us worse off compared to the corresponding platform's
* tuned memory allocator. In that case it seemed better to not use this allocator.
* Note there are some ideas to fix this in the works.
* memory aggressively to the system. This results in page faults in some
* cases and ends up hurting performance. This caching allocator aims to address
* that. Furthermore it also allows users to specify their own allocator by
* implementing allocate/free virtual interfaces. What are the cons? There are
* some cons that were observed where use of caching allocator led to worse
* performance on some platforms. Reason being that the caching mechanism used
* by this allocator left us worse off compared to the corresponding platform's
* tuned memory allocator. In that case it seemed better to not use this
* allocator. Note there are some ideas to fix this in the works.
*
* Usage:
* Usage pattern:
@ -53,42 +52,44 @@ class C10_API CPUCachingAllocator {
* What it does not do:
* No speculative allocation for any future allocations.
*/
private:
inline void* allocate_and_cache(const size_t bytes);
void free_cached();
protected:
// Invariants.
// 1. If memory is ever allocated via this allocator then
// the pointer will exist in allocation_map_, unless the allocator
// returned the memory to OS via free_cached.
// 1.1. Therefore even when the said memory is "freed" via this
// allocator (and thus cached), it will continue to stay
// in allocation_map_. Furthermore it will also exist in
// available_map_. Thus an allocated memory pointer can be in both
// allocation_map_ and available_map_ simultaneously.
// 2. Memory pointer may be removed from allocation_map_, when it
// is freed outside of the scope of this allocator, but was allocated
// by this allocator.
// 3. Available map only contains that memory which was allocated
// by this allocator and subsequently freed by this allocator.
// As a result of above invariants, allocated memory ptr cannot be in
// available_map_ unless it is in allocation_map_ as well.
ska::flat_hash_map<size_t, c10::SmallVector<void*, 16>> available_map_;
static ska::flat_hash_map<void*, size_t> allocation_map_;
// Since allocation_map, which is a global instance, is mutated/read via
// all public APIs we need a global mutex.
static std::mutex mutex_;
public:
static void record_free(void* ptr);
virtual ~CPUCachingAllocator();
// Checks the cache to see if allocation of size bytes can be found.
// If so return cached memory, else
// allocates memory, records it for caching and returns.
virtual void* allocate(const size_t bytes);
// Checks if the memory being freed was marked for allocation by
// an earlier call to allocate. If so, cache the allocation.
// Otherwise free.
virtual void free(void* ptr);
private:
inline void* allocate_and_cache(const size_t bytes);
void free_cached();
protected:
// Invariants.
// 1. If memory is ever allocated via this allocator then
// the pointer will exist in allocation_map_, unless the allocator
// returned the memory to OS via free_cached.
// 1.1. Therefore even when the said memory is "freed" via this
// allocator (and thus cached), it will continue to stay
// in allocation_map_. Furthermore it will also exist in
// available_map_. Thus an allocated memory pointer can be in both
// allocation_map_ and available_map_ simultaneously.
// 2. Memory pointer may be removed from allocation_map_, when it
// is freed outside of the scope of this allocator, but was allocated
// by this allocator.
// 3. Available map only contains that memory which was allocated
// by this allocator and subsequently freed by this allocator.
// As a result of above invariants, allocated memory ptr cannot be in
// available_map_ unless it is in allocation_map_ as well.
ska::flat_hash_map<size_t, c10::SmallVector<void*, 16>> available_map_;
static ska::flat_hash_map<void*, size_t> allocation_map_;
// Since allocation_map, which is a global instance, is mutated/read via
// all public APIs we need a global mutex.
static std::mutex mutex_;
public:
static void record_free(void* ptr);
virtual ~CPUCachingAllocator();
// Checks the cache to see if allocation of size bytes can be found.
// If so return cached memory, else
// allocates memory, records it for caching and returns.
virtual void* allocate(const size_t bytes);
// Checks if the memory being freed was marked for allocation by
// an earlier call to allocate. If so, cache the allocation.
// Otherwise free.
virtual void free(void* ptr);
};
CPUCachingAllocator* GetDefaultCPUCachingAllocator();
@ -97,11 +98,12 @@ bool ThreadLocalCachingAllocatorEnabled();
CPUCachingAllocator* GetThreadLocalCachingAllocator();
class C10_API WithCPUCachingAllocatorGuard {
public:
WithCPUCachingAllocatorGuard(CPUCachingAllocator* allocator);
~WithCPUCachingAllocatorGuard();
private:
CPUCachingAllocator* prev_caching_allocator_ptr_{nullptr};
public:
WithCPUCachingAllocatorGuard(CPUCachingAllocator* allocator);
~WithCPUCachingAllocatorGuard();
private:
CPUCachingAllocator* prev_caching_allocator_ptr_{nullptr};
};
} // namespace c10
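A sketch of the usage pattern this header hints at, using only what is declared above (include path assumed); allocations made in the scope may be served from, and returned to, the cache rather than the OS:

#include <c10/mobile/CPUCachingAllocator.h> // assumed path for this header

void run_with_cache() {
  c10::CPUCachingAllocator* cache = c10::GetDefaultCPUCachingAllocator();
  c10::WithCPUCachingAllocatorGuard guard(cache);
  void* p = cache->allocate(1024); // reuses a cached block of this size if one is free
  cache->free(p);                  // returned to the cache, not necessarily the OS
}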

View File

@ -19,27 +19,23 @@ struct MemBlock {
}
};
enum class EventType {
Allocate = 0,
Free,
Invalid
};
enum class EventType { Allocate = 0, Free, Invalid };
struct MemEvent {
uint64_t time;
uint64_t allocation_id;
uint64_t size;
EventType type{EventType::Invalid};
MemEvent(uint64_t t, uint64_t id, uint64_t s, EventType e) :
time(t), allocation_id(id), size(s), type(e) {}
MemEvent(uint64_t t, uint64_t id, uint64_t s, EventType e)
: time(t), allocation_id(id), size(s), type(e) {}
};
bool overlaps(const MemBlock& a, const MemBlock& b) {
// two blocks dont overlap if
// |---a--------|--------------b--------|
// strat_a end_a <= start_b end_b
return
!((a.end_offset <= b.start_offset) || (b.end_offset <= a.start_offset));
return !(
(a.end_offset <= b.start_offset) || (b.end_offset <= a.start_offset));
}
bool validate_allocation_plan(
@ -72,7 +68,8 @@ bool validate_allocation_plan(
allocations.emplace(mem_block);
} else if (event.type == EventType::Free) {
auto it = allocations.find(mem_block);
TORCH_CHECK((*it).end_offset == end_offset,
TORCH_CHECK(
(*it).end_offset == end_offset,
"Enf offset of allocation being freed must match the one recorded.");
TORCH_CHECK(
it != allocations.end(),
@ -98,13 +95,15 @@ std::vector<MemEvent> create_and_sort_mem_events(
continue;
}
events.emplace_back(i, i, allocation_sizes[i], EventType::Allocate);
events.emplace_back(allocation_lifetimes[i], i, allocation_sizes[i], EventType::Free);
events.emplace_back(
allocation_lifetimes[i], i, allocation_sizes[i], EventType::Free);
}
std::sort(
events.begin(),
events.end(),
[](const MemEvent& a,
const MemEvent& b) -> bool {return a.time < b.time;});
[](const MemEvent& a, const MemEvent& b) -> bool {
return a.time < b.time;
});
return events;
}
@ -132,8 +131,10 @@ std::vector<uint64_t> formulate_greedy_allocation_plan(
std::map<uint64_t, uint64_t> free_size_to_offset;
// This provides fast lookup when we want to insert freed block
// back, especially when we want to merge blocks.
ska::flat_hash_map<uint64_t, std::map<uint64_t, uint64_t>::iterator> free_start_offset_to_size_iter;
ska::flat_hash_map<uint64_t, std::map<uint64_t, uint64_t>::iterator> free_end_offset_to_size_iter;
ska::flat_hash_map<uint64_t, std::map<uint64_t, uint64_t>::iterator>
free_start_offset_to_size_iter;
ska::flat_hash_map<uint64_t, std::map<uint64_t, uint64_t>::iterator>
free_end_offset_to_size_iter;
// Upon free end_ptr = offset + size
// If end_ptr exists merge freed allocation
// Also find corresponding offset in size_to_offset
@ -146,7 +147,8 @@ std::vector<uint64_t> formulate_greedy_allocation_plan(
std::vector<uint64_t> allocation_offsets(
allocation_sizes.size(), std::numeric_limits<uint64_t>::max());
auto mem_events = create_and_sort_mem_events(allocation_sizes, allocation_lifetimes);
auto mem_events =
create_and_sort_mem_events(allocation_sizes, allocation_lifetimes);
uint64_t max_offset{0};
for (const auto& mem_event : mem_events) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
@ -221,13 +223,14 @@ std::vector<uint64_t> formulate_greedy_allocation_plan(
free_start_offset_to_size_iter.erase(freed_offset);
}
auto freed_block_it =
free_size_to_offset.emplace(freed_size, freed_offset).first;
free_size_to_offset.emplace(freed_size, freed_offset).first;
free_start_offset_to_size_iter.emplace(freed_offset, freed_block_it);
free_end_offset_to_size_iter.emplace(
freed_offset + freed_size, freed_block_it);
}
}
TORCH_CHECK(validate_allocation_plan(mem_events, allocation_offsets),
TORCH_CHECK(
validate_allocation_plan(mem_events, allocation_offsets),
"ProfilingAllocator: Allocation plan invalid.");
return allocation_offsets;
}
@ -241,7 +244,8 @@ void AllocationPlan::clear() {
}
void AllocationPlanner::record_allocation(
const uint64_t size, const void* ptr) {
const uint64_t size,
const void* ptr) {
if (validation_mode_) {
validation_success = validation_success && validate_allocation(size, ptr);
return;
@ -264,13 +268,15 @@ void AllocationPlanner::record_free(const void* ptr) {
return;
}
auto id = it->second;
TORCH_CHECK(id < allocation_plan_->allocation_lifetimes.size(),
TORCH_CHECK(
id < allocation_plan_->allocation_lifetimes.size(),
"Allocation must have been recorded during record_allocation.");
allocation_plan_->allocation_lifetimes[id] = allocation_id_;
}
bool AllocationPlanner::validate_allocation(
const uint64_t size, const void* ptr) {
const uint64_t size,
const void* ptr) {
if (allocation_id_ >= allocation_plan_->allocation_sizes.size() ||
allocation_plan_->allocation_sizes[allocation_id_] != size) {
TORCH_WARN(
@ -286,7 +292,7 @@ bool AllocationPlanner::validate_allocation(
return false;
}
allocation_ptr_to_id_[ptr] = allocation_id_;
allocation_ptr_to_id_[ptr] = allocation_id_;
allocation_id_++;
return true;
}
@ -298,24 +304,27 @@ bool AllocationPlanner::validate_free(const void* ptr) {
return true;
}
auto id = (*it).second;
TORCH_CHECK(id < allocation_plan_->allocation_lifetimes.size(),
TORCH_CHECK(
id < allocation_plan_->allocation_lifetimes.size(),
"Allocation must have been recorded during validate_allocation.");
auto lifetime_id = allocation_plan_->allocation_lifetimes[id];
return (lifetime_id == allocation_id_);
}
void AllocationPlanner::formulate_plan() {
allocation_plan_->allocation_offsets =
formulate_greedy_allocation_plan(
allocation_plan_->allocation_sizes, allocation_plan_->allocation_lifetimes);
allocation_plan_->allocation_offsets = formulate_greedy_allocation_plan(
allocation_plan_->allocation_sizes,
allocation_plan_->allocation_lifetimes);
allocation_plan_->total_size = 0;
for (const auto i : c10::irange(allocation_plan_->allocation_sizes.size())) {
if (allocation_plan_->allocation_lifetimes[i] ==
std::numeric_limits<uint64_t>::max()) {
continue;
}
auto limit = allocation_plan_->allocation_offsets[i] + allocation_plan_->allocation_sizes[i];
allocation_plan_->total_size = std::max(allocation_plan_->total_size, limit);
auto limit = allocation_plan_->allocation_offsets[i] +
allocation_plan_->allocation_sizes[i];
allocation_plan_->total_size =
std::max(allocation_plan_->total_size, limit);
}
}
@ -344,7 +353,8 @@ void CPUProfilingAllocator::unset_plan() {
}
void* CPUProfilingAllocator::allocate(const size_t bytes) {
TORCH_CHECK(bytes == plan_->allocation_sizes[allocation_id_],
TORCH_CHECK(
bytes == plan_->allocation_sizes[allocation_id_],
"Got allocation request that does not match with the plan.");
if (plan_->allocation_lifetimes[allocation_id_] ==
std::numeric_limits<uint64_t>::max()) {
@ -352,9 +362,8 @@ void* CPUProfilingAllocator::allocate(const size_t bytes) {
allocation_id_++;
return c10::alloc_cpu(bytes);
}
void* ptr =
reinterpret_cast<uint8_t*>(blob_) +
plan_->allocation_offsets[allocation_id_];
void* ptr = reinterpret_cast<uint8_t*>(blob_) +
plan_->allocation_offsets[allocation_id_];
allocation_ptr_to_id_[ptr] = allocation_id_;
allocation_id_++;
return ptr;
@ -364,8 +373,8 @@ void CPUProfilingAllocator::free(void* const ptr) {
auto it = allocation_ptr_to_id_.find(ptr);
if (it == allocation_ptr_to_id_.end()) {
// Either
// 1. Allocation that was made outside the validation scope is being freed here
// or
// 1. Allocation that was made outside the validation scope is being freed
// here or
// 2. Allocation that is not managed by profiling allocator is being freed.
// Example of the second type
// Tensor out;
@ -380,7 +389,8 @@ void CPUProfilingAllocator::free(void* const ptr) {
return;
}
auto id = it->second;
TORCH_CHECK(id < plan_->allocation_lifetimes.size(),
TORCH_CHECK(
id < plan_->allocation_lifetimes.size(),
"Freeing allocation that is not accordingly to the plan.");
auto lifetime_id = plan_->allocation_lifetimes[id];
TORCH_CHECK(
@ -397,10 +407,10 @@ CPUProfilingAllocator::~CPUProfilingAllocator() {
c10::free_cpu(blob_);
}
WithProfileAllocationsGuard::WithProfileAllocationsGuard(
AllocationPlan* plan) {
WithProfileAllocationsGuard::WithProfileAllocationsGuard(AllocationPlan* plan) {
// Nesting of allocation profiling does not seem meaningful.
TORCH_CHECK(allocation_planner == nullptr,
TORCH_CHECK(
allocation_planner == nullptr,
"Nesting profiling allocations is not supported.");
planner_ = std::make_unique<AllocationPlanner>(plan);
planner_->clear();
@ -413,9 +423,11 @@ WithProfileAllocationsGuard::~WithProfileAllocationsGuard() {
}
WithValidateAllocationPlanGuard::WithValidateAllocationPlanGuard(
AllocationPlan* plan, bool* success) {
AllocationPlan* plan,
bool* success) {
// Nesting of allocation profiling does not seem meaningful.
TORCH_CHECK(allocation_planner == nullptr,
TORCH_CHECK(
allocation_planner == nullptr,
"Nesting profiling allocations is not supported.");
planner_ = std::make_unique<AllocationPlanner>(plan, true);
success_ = success;
@ -432,9 +444,11 @@ AllocationPlanner* GetThreadLocalAllocationPlanner() {
}
WithProfilingAllocatorGuard::WithProfilingAllocatorGuard(
CPUProfilingAllocator* allocator, const AllocationPlan* plan) {
CPUProfilingAllocator* allocator,
const AllocationPlan* plan) {
// Nesting of profiling allocator is not supported.
TORCH_CHECK(profiling_allocator == nullptr,
TORCH_CHECK(
profiling_allocator == nullptr,
"Nesting profiling allocators is not supported.");
profiling_allocator = allocator;
profiling_allocator->set_plan(plan);

View File

@ -16,29 +16,30 @@ namespace c10 {
* Given a sequence of allocations in a thread, AllocationPlan records
* 1. size of each allocation
* 2. Lifetime of each allocation.
* 3. allocation offsets: Memory offset for each allocation in a single blob of memory
* 3. allocation offsets: Memory offset for each allocation in a single blob of
* memory
* 4. Total size of a blob of memory required to satisfy all the allocations.
*/
class C10_API AllocationPlan {
private:
// Records size of each allocation by their sequential allocation ids.
std::vector<uint64_t> allocation_sizes;
// This maps one allocation id (X) to another allocation id (Y).
// Allocation X is alive until allocation Y. From allocation Y onwards
// allocation X is not referenced.
// Thus Y is the id of the first allocation after X is freed.
// NB: When an allocation is recorded, along with recording its size,
// we also set the lifetime to be numeric_limits::max()
// This is to track allocations that are made during the scope of
// profiling but were not freed until after the scope ended.
// Such allocations are not managed by profiling allocator.
std::vector<uint64_t> allocation_lifetimes;
// Maps an allocation to some offset in a blob of memory.
std::vector<uint64_t> allocation_offsets;
uint64_t total_size{0};
void clear();
friend class AllocationPlanner;
friend class CPUProfilingAllocator;
private:
// Records size of each allocation by their sequential allocation ids.
std::vector<uint64_t> allocation_sizes;
// This maps one allocation id (X) to another allocation id (Y).
// Allocation X is alive until allocation Y. From allocation Y onwards
// allocation X is not referenced.
// Thus Y is the id of the first allocation after X is freed.
// NB: When an allocation is recorded, along with recording its size,
// we also set the lifetime to be numeric_limits::max()
// This is to track allocations that are made during the scope of
// profiling but were not freed until after the scope ended.
// Such allocations are not managed by profiling allocator.
std::vector<uint64_t> allocation_lifetimes;
// Maps an allocation to some offset in a blob of memory.
std::vector<uint64_t> allocation_offsets;
uint64_t total_size{0};
void clear();
friend class AllocationPlanner;
friend class CPUProfilingAllocator;
};
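To make that bookkeeping concrete, here is a minimal sketch of the data a plan captures (sizes, offsets, total blob size); naive_plan() is an illustrative helper and not c10's planning algorithm, which additionally uses the recorded lifetimes to reuse memory.

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustration only: lay the recorded allocations out back to back and
// return the resulting blob size. A real planner would consult the
// recorded lifetimes so that allocations which never overlap in time can
// share the same offsets, keeping the blob small.
uint64_t naive_plan(
    const std::vector<uint64_t>& allocation_sizes,
    std::vector<uint64_t>& allocation_offsets) {
  allocation_offsets.resize(allocation_sizes.size());
  uint64_t total_size = 0;
  for (std::size_t i = 0; i < allocation_sizes.size(); ++i) {
    allocation_offsets[i] = total_size;
    total_size += allocation_sizes[i];
  }
  return total_size;
}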
/*
@ -46,43 +47,45 @@ class C10_API AllocationPlan {
* used to establish lifetime of allocations.
*/
class C10_API AllocationPlanner {
private:
AllocationPlan* allocation_plan_{nullptr};
// Maps allocated ptr to its allocation id.
// This is used when freeing the memory to lookup the allocation id
// in order to establish the lifetime of a particular allocation.
ska::flat_hash_map<const void*, uint64_t> allocation_ptr_to_id_;
uint64_t allocation_id_{0};
bool validation_mode_{false};
private:
AllocationPlan* allocation_plan_{nullptr};
// Maps allocated ptr to its allocation id.
// This is used when freeing the memory to lookup the allocation id
// in order to establish the lifetime of a particular allocation.
ska::flat_hash_map<const void*, uint64_t> allocation_ptr_to_id_;
uint64_t allocation_id_{0};
bool validation_mode_{false};
bool validate_allocation(const uint64_t size, const void* ptr);
bool validate_free(const void* ptr);
public:
bool validation_success{true};
bool validate_allocation(const uint64_t size, const void* ptr);
bool validate_free(const void* ptr);
AllocationPlanner() = delete;
AllocationPlanner(AllocationPlan* plan, bool validate = false) :
allocation_plan_(plan), validation_mode_(validate) {}
void record_allocation(const uint64_t size, const void* ptr);
void record_free(const void* ptr);
void formulate_plan();
void clear();
public:
bool validation_success{true};
AllocationPlanner() = delete;
AllocationPlanner(AllocationPlan* plan, bool validate = false)
: allocation_plan_(plan), validation_mode_(validate) {}
void record_allocation(const uint64_t size, const void* ptr);
void record_free(const void* ptr);
void formulate_plan();
void clear();
};
// NOT THREAD SAFE profiling allocator.
class C10_API CPUProfilingAllocator {
private:
const AllocationPlan* plan_{nullptr};
uint64_t allocation_id_{0};
uint64_t current_size_{0};
void* blob_{nullptr};
ska::flat_hash_map<const void*, uint64_t> allocation_ptr_to_id_;
public:
~CPUProfilingAllocator();
void set_plan(const AllocationPlan* plan);
void unset_plan();
void* allocate(const size_t bytes);
void free(void* const ptr);
private:
const AllocationPlan* plan_{nullptr};
uint64_t allocation_id_{0};
uint64_t current_size_{0};
void* blob_{nullptr};
ska::flat_hash_map<const void*, uint64_t> allocation_ptr_to_id_;
public:
~CPUProfilingAllocator();
void set_plan(const AllocationPlan* plan);
void unset_plan();
void* allocate(const size_t bytes);
void free(void* const ptr);
};
/*
@ -95,11 +98,12 @@ class C10_API CPUProfilingAllocator {
* plan now contains allocation plan.
*/
class C10_API WithProfileAllocationsGuard {
public:
WithProfileAllocationsGuard(AllocationPlan* plan);
~WithProfileAllocationsGuard();
private:
std::unique_ptr<AllocationPlanner> planner_;
public:
WithProfileAllocationsGuard(AllocationPlan* plan);
~WithProfileAllocationsGuard();
private:
std::unique_ptr<AllocationPlanner> planner_;
};
/*
@ -115,12 +119,13 @@ class C10_API WithProfileAllocationsGuard {
* else for some inputs allocation pattern changed.
*/
class C10_API WithValidateAllocationPlanGuard {
public:
WithValidateAllocationPlanGuard(AllocationPlan* plan, bool* success);
~WithValidateAllocationPlanGuard();
private:
std::unique_ptr<AllocationPlanner> planner_;
bool* success_;
public:
WithValidateAllocationPlanGuard(AllocationPlan* plan, bool* success);
~WithValidateAllocationPlanGuard();
private:
std::unique_ptr<AllocationPlanner> planner_;
bool* success_;
};
AllocationPlanner* GetThreadLocalAllocationPlanner();
@ -138,10 +143,11 @@ AllocationPlanner* GetThreadLocalAllocationPlanner();
* }
*/
class C10_API WithProfilingAllocatorGuard {
public:
WithProfilingAllocatorGuard(
CPUProfilingAllocator* allocator, const AllocationPlan* plan);
~WithProfilingAllocatorGuard();
public:
WithProfilingAllocatorGuard(
CPUProfilingAllocator* allocator,
const AllocationPlan* plan);
~WithProfilingAllocatorGuard();
};
CPUProfilingAllocator* GetThreadLocalProfilingAllocator();
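Tying the guards together, a minimal usage sketch of the profile, validate, then replay flow described in the class comments above. run_model() is a hypothetical workload and the include path is an assumption; the guard constructors match the declarations in this header.

#include <c10/mobile/CPUProfilingAllocator.h> // assumed include path

void run_model(); // hypothetical workload that allocates and frees CPU memory

void profile_then_replay() {
  c10::AllocationPlan plan;
  {
    // 1. Record size and lifetime of every allocation the workload makes.
    c10::WithProfileAllocationsGuard profile_guard(&plan);
    run_model();
  }
  bool plan_still_valid = false;
  {
    // 2. Re-run to check that the allocation pattern has not changed.
    c10::WithValidateAllocationPlanGuard validate_guard(&plan, &plan_still_valid);
    run_model();
  }
  if (plan_still_valid) {
    // 3. Serve the same allocation sequence out of one preallocated blob.
    c10::CPUProfilingAllocator profiling_allocator;
    c10::WithProfilingAllocatorGuard allocator_guard(&profiling_allocator, &plan);
    run_model();
  }
}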


@ -5,63 +5,71 @@ namespace test_is_compile_time_function_pointer {
static_assert(!c10::is_compile_time_function_pointer<void()>::value, "");
void dummy() {}
static_assert(c10::is_compile_time_function_pointer<TORCH_FN_TYPE(dummy)>::value, "");
}
static_assert(
c10::is_compile_time_function_pointer<TORCH_FN_TYPE(dummy)>::value,
"");
} // namespace test_is_compile_time_function_pointer
namespace test_access_through_type {
void dummy() {}
using dummy_ptr = TORCH_FN_TYPE(dummy);
static_assert(c10::is_compile_time_function_pointer<dummy_ptr>::value, "");
static_assert(dummy_ptr::func_ptr() == &dummy, "");
static_assert(std::is_same<void(), dummy_ptr::FuncType>::value, "");
}
void dummy() {}
using dummy_ptr = TORCH_FN_TYPE(dummy);
static_assert(c10::is_compile_time_function_pointer<dummy_ptr>::value, "");
static_assert(dummy_ptr::func_ptr() == &dummy, "");
static_assert(std::is_same<void(), dummy_ptr::FuncType>::value, "");
} // namespace test_access_through_type
namespace test_access_through_value {
void dummy() {}
constexpr auto dummy_ptr = TORCH_FN(dummy);
static_assert(dummy_ptr.func_ptr() == &dummy, "");
static_assert(std::is_same<void(), decltype(dummy_ptr)::FuncType>::value, "");
}
void dummy() {}
constexpr auto dummy_ptr = TORCH_FN(dummy);
static_assert(dummy_ptr.func_ptr() == &dummy, "");
static_assert(std::is_same<void(), decltype(dummy_ptr)::FuncType>::value, "");
} // namespace test_access_through_value
namespace test_access_through_type_also_works_if_specified_as_pointer {
void dummy() {}
using dummy_ptr = TORCH_FN_TYPE(&dummy);
static_assert(c10::is_compile_time_function_pointer<dummy_ptr>::value, "");
static_assert(dummy_ptr::func_ptr() == &dummy, "");
static_assert(std::is_same<void(), dummy_ptr::FuncType>::value, "");
}
void dummy() {}
using dummy_ptr = TORCH_FN_TYPE(&dummy);
static_assert(c10::is_compile_time_function_pointer<dummy_ptr>::value, "");
static_assert(dummy_ptr::func_ptr() == &dummy, "");
static_assert(std::is_same<void(), dummy_ptr::FuncType>::value, "");
} // namespace test_access_through_type_also_works_if_specified_as_pointer
namespace test_access_through_value_also_works_if_specified_as_pointer {
void dummy() {}
constexpr auto dummy_ptr = TORCH_FN(&dummy);
static_assert(dummy_ptr.func_ptr() == &dummy, "");
static_assert(std::is_same<void(), decltype(dummy_ptr)::FuncType>::value, "");
}
void dummy() {}
constexpr auto dummy_ptr = TORCH_FN(&dummy);
static_assert(dummy_ptr.func_ptr() == &dummy, "");
static_assert(std::is_same<void(), decltype(dummy_ptr)::FuncType>::value, "");
} // namespace test_access_through_value_also_works_if_specified_as_pointer
namespace test_run_through_type {
int add(int a, int b) {return a + b;}
using Add = TORCH_FN_TYPE(add);
template<class Func> struct Executor {
int execute(int a, int b) {
return Func::func_ptr()(a, b);
}
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(CompileTimeFunctionPointerTest, runFunctionThroughType) {
Executor<Add> executor;
EXPECT_EQ(3, executor.execute(1, 2));
}
int add(int a, int b) {
return a + b;
}
using Add = TORCH_FN_TYPE(add);
template <class Func>
struct Executor {
int execute(int a, int b) {
return Func::func_ptr()(a, b);
}
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(CompileTimeFunctionPointerTest, runFunctionThroughType) {
Executor<Add> executor;
EXPECT_EQ(3, executor.execute(1, 2));
}
} // namespace test_run_through_type
namespace test_run_through_value {
int add(int a, int b) {return a + b;}
template<class Func> int execute(Func, int a, int b) {
return Func::func_ptr()(a, b);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(CompileTimeFunctionPointerTest, runFunctionThroughValue) {
EXPECT_EQ(3, execute(TORCH_FN(add), 1, 2));
}
int add(int a, int b) {
return a + b;
}
template <class Func>
int execute(Func, int a, int b) {
return Func::func_ptr()(a, b);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(CompileTimeFunctionPointerTest, runFunctionThroughValue) {
EXPECT_EQ(3, execute(TORCH_FN(add), 1, 2));
}
} // namespace test_run_through_value
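Condensing the idea the tests above exercise: TORCH_FN wraps a function in an empty tag type, so a callee can recover and call it purely through the type. The names sub, apply, and the extra test below are illustrative additions, not part of the original file.

namespace test_usage_sketch {
int sub(int a, int b) {
  return a - b;
}
template <class Func>
int apply(Func, int a, int b) {
  // Func carries the function in its type; nothing is stored at runtime.
  return Func::func_ptr()(a, b);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(CompileTimeFunctionPointerTest, usageSketch) {
  EXPECT_EQ(3, apply(TORCH_FN(sub), 5, 2));
}
} // namespace test_usage_sketch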


@ -9,7 +9,8 @@ using namespace c10;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(DispatchKeySet, Empty) {
DispatchKeySet empty_set;
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys); i++) {
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
i++) {
auto tid = static_cast<DispatchKey>(i);
ASSERT_FALSE(empty_set.has(tid));
}
@ -21,7 +22,8 @@ TEST(DispatchKeySet, Empty) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(DispatchKeySet, Singleton) {
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys); i++) {
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
i++) {
auto tid = static_cast<DispatchKey>(i);
DispatchKeySet sing(tid);
ASSERT_EQ(sing, sing);
@ -37,8 +39,11 @@ TEST(DispatchKeySet, Singleton) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(DispatchKeySet, Doubleton) {
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys); i++) {
for (uint8_t j = i + 1; j < static_cast<uint8_t>(DispatchKey::NumDispatchKeys); j++) {
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
i++) {
for (uint8_t j = i + 1;
j < static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
j++) {
ASSERT_LT(i, j);
auto tid1 = static_cast<DispatchKey>(i);
auto tid2 = static_cast<DispatchKey>(j);
@ -46,7 +51,7 @@ TEST(DispatchKeySet, Doubleton) {
ASSERT_EQ(doub, DispatchKeySet(tid1) | DispatchKeySet(tid2));
ASSERT_TRUE(doub.has(tid1));
ASSERT_TRUE(doub.has(tid2));
ASSERT_EQ(doub.highestPriorityTypeId(), tid2); // relies on i < j
ASSERT_EQ(doub.highestPriorityTypeId(), tid2); // relies on i < j
}
}
}
@ -54,7 +59,8 @@ TEST(DispatchKeySet, Doubleton) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(DispatchKeySet, Full) {
DispatchKeySet full(DispatchKeySet::FULL);
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys); i++) {
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
i++) {
auto tid = static_cast<DispatchKey>(i);
ASSERT_TRUE(full.has(tid));
}
@ -103,13 +109,12 @@ TEST(DispatchKeySet, IteratorFull) {
ASSERT_EQ(i, static_cast<uint8_t>(DispatchKey::NumDispatchKeys) - 1);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(DispatchKeySet, IteratorRangeFull) {
DispatchKeySet full_set(DispatchKeySet::FULL);
uint8_t i = 0;
for (DispatchKey dispatch_key: full_set) {
for (DispatchKey dispatch_key : full_set) {
i++;
ASSERT_TRUE(dispatch_key == static_cast<DispatchKey>(i));
}
@ -126,17 +131,20 @@ TEST(DispatchKeySet, SpecificKeys) {
static_cast<DispatchKey>(10),
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
static_cast<DispatchKey>(15),
});
});
std::unordered_set<DispatchKey> visited_keys;
for(DispatchKey key: keyset) {
for (DispatchKey key : keyset) {
visited_keys.insert(key);
}
ASSERT_EQ(visited_keys.size(), 3);
ASSERT_TRUE(visited_keys.find(static_cast<DispatchKey>(4)) != visited_keys.end());
ASSERT_TRUE(visited_keys.find(static_cast<DispatchKey>(10)) != visited_keys.end());
ASSERT_TRUE(visited_keys.find(static_cast<DispatchKey>(15)) != visited_keys.end());
ASSERT_TRUE(
visited_keys.find(static_cast<DispatchKey>(4)) != visited_keys.end());
ASSERT_TRUE(
visited_keys.find(static_cast<DispatchKey>(10)) != visited_keys.end());
ASSERT_TRUE(
visited_keys.find(static_cast<DispatchKey>(15)) != visited_keys.end());
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@ -145,9 +153,8 @@ TEST(DispatchKeySet, FailAtEndIterator) {
uint64_t raw_repr = full_set.raw_repr();
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_THROW(DispatchKeySet::iterator(
&raw_repr,
static_cast<uint8_t>(DispatchKey::NumDispatchKeys) + 1
),
c10::Error);
EXPECT_THROW(
DispatchKeySet::iterator(
&raw_repr, static_cast<uint8_t>(DispatchKey::NumDispatchKeys) + 1),
c10::Error);
}
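In short, DispatchKeySet behaves like a small bitset keyed by DispatchKey: sets can be unioned with |, queried with has(), and iterated in increasing key order. The sketch below is an added illustration, not part of the original test file, and the key choices are arbitrary.

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(DispatchKeySet, UsageSketch) {
  DispatchKeySet ks(DispatchKey::CPU);
  ks = ks | DispatchKeySet(DispatchKey::CUDA);
  ASSERT_TRUE(ks.has(DispatchKey::CPU));
  ASSERT_TRUE(ks.has(DispatchKey::CUDA));
  uint8_t visited = 0;
  for (DispatchKey k : ks) {
    (void)k; // keys are visited in increasing enum order
    visited++;
  }
  ASSERT_EQ(visited, 2);
}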


@ -1,7 +1,7 @@
#include <gtest/gtest.h>
#include <c10/core/impl/InlineDeviceGuard.h>
#include <c10/core/impl/FakeGuardImpl.h>
#include <c10/core/impl/InlineDeviceGuard.h>
using namespace c10;
using namespace c10::impl;
@ -55,8 +55,8 @@ TEST(InlineDeviceGuard, Constructor) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InlineDeviceGuard, ConstructorError) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_ANY_THROW(InlineDeviceGuard<FakeGuardImpl<DeviceType::CUDA>>
g(Device(DeviceType::HIP, 1)));
EXPECT_ANY_THROW(InlineDeviceGuard<FakeGuardImpl<DeviceType::CUDA>> g(
Device(DeviceType::HIP, 1)));
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@ -110,7 +110,8 @@ TEST(InlineDeviceGuard, SetIndex) {
ASSERT_EQ(TestGuardImpl::getDeviceIndex(), i2);
}
// -- InlineOptionalDeviceGuard --------------------------------------------------
// -- InlineOptionalDeviceGuard
// --------------------------------------------------
using MaybeTestGuard = InlineOptionalDeviceGuard<TestGuardImpl>;


@ -1,7 +1,7 @@
#include <gtest/gtest.h>
#include <c10/core/impl/InlineStreamGuard.h>
#include <c10/core/impl/FakeGuardImpl.h>
#include <c10/core/impl/InlineStreamGuard.h>
using namespace c10;
using namespace c10::impl;
@ -40,7 +40,6 @@ TEST(InlineStreamGuard, Constructor) {
ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(0), 0);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InlineStreamGuard, ResetStreamSameSameDevice) {
TestGuardImpl::setDeviceIndex(0);
@ -101,7 +100,8 @@ TEST(InlineStreamGuard, ResetStreamDifferentDevice) {
ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(0), 0);
}
// -- OptionalInlineStreamGuard -------------------------------------------------------
// -- OptionalInlineStreamGuard
// -------------------------------------------------------
using OptionalTestGuard = InlineOptionalStreamGuard<TestGuardImpl>;
@ -180,7 +180,8 @@ TEST(InlineOptionalStreamGuard, ResetStreamDifferentDevice) {
ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(0), 0);
}
// -- InlineMultiStreamGuard -------------------------------------------------------
// -- InlineMultiStreamGuard
// -------------------------------------------------------
using MultiTestGuard = InlineMultiStreamGuard<TestGuardImpl>;


@ -6,12 +6,16 @@
using namespace c10;
using namespace c10::impl;
static void checkData(const SizesAndStrides& sz, IntArrayRef sizes, IntArrayRef strides) {
EXPECT_EQ(sizes.size(), strides.size()) << "bad test case: size() of sizes and strides don't match";
static void checkData(
const SizesAndStrides& sz,
IntArrayRef sizes,
IntArrayRef strides) {
EXPECT_EQ(sizes.size(), strides.size())
<< "bad test case: size() of sizes and strides don't match";
EXPECT_EQ(sz.size(), sizes.size());
int idx = 0;
for (auto x: sizes) {
for (auto x : sizes) {
EXPECT_EQ(sz.size_at_unchecked(idx), x) << "index: " << idx;
EXPECT_EQ(sz.size_at(idx), x) << "index: " << idx;
EXPECT_EQ(sz.sizes_data()[idx], x) << "index: " << idx;
@ -21,7 +25,7 @@ static void checkData(const SizesAndStrides& sz, IntArrayRef sizes, IntArrayRef
EXPECT_EQ(sz.sizes_arrayref(), sizes);
idx = 0;
for (auto x: strides) {
for (auto x : strides) {
EXPECT_EQ(sz.stride_at_unchecked(idx), x) << "index: " << idx;
EXPECT_EQ(sz.stride_at(idx), x) << "index: " << idx;
EXPECT_EQ(sz.strides_data()[idx], x) << "index: " << idx;


@ -6,90 +6,92 @@ using c10::guts::to_array;
namespace {
namespace test_equals {
static_assert(array<int, 0>{{}} == array<int, 0>{{}}, "");
static_assert(array<int, 3>{{2, 3, 4}} == array<int, 3>{{2, 3, 4}}, "");
static_assert(!(array<int, 3>{{2, 3, 4}} == array<int, 3>{{1, 3, 4}}), "");
static_assert(!(array<int, 3>{{2, 3, 4}} == array<int, 3>{{2, 1, 4}}), "");
static_assert(!(array<int, 3>{{2, 3, 4}} == array<int, 3>{{2, 3, 1}}), "");
}
static_assert(array<int, 0>{{}} == array<int, 0>{{}}, "");
static_assert(array<int, 3>{{2, 3, 4}} == array<int, 3>{{2, 3, 4}}, "");
static_assert(!(array<int, 3>{{2, 3, 4}} == array<int, 3>{{1, 3, 4}}), "");
static_assert(!(array<int, 3>{{2, 3, 4}} == array<int, 3>{{2, 1, 4}}), "");
static_assert(!(array<int, 3>{{2, 3, 4}} == array<int, 3>{{2, 3, 1}}), "");
} // namespace test_equals
namespace test_notequals {
static_assert(!(array<int, 0>{{}} != array<int, 0>{{}}), "");
static_assert(!(array<int, 3>{{2, 3, 4}} != array<int, 3>{{2, 3, 4}}), "");
static_assert(array<int, 3>{{2, 3, 4}} != array<int, 3>{{1, 3, 4}}, "");
static_assert(array<int, 3>{{2, 3, 4}} != array<int, 3>{{2, 1, 4}}, "");
static_assert(array<int, 3>{{2, 3, 4}} != array<int, 3>{{2, 3, 1}}, "");
}
static_assert(!(array<int, 0>{{}} != array<int, 0>{{}}), "");
static_assert(!(array<int, 3>{{2, 3, 4}} != array<int, 3>{{2, 3, 4}}), "");
static_assert(array<int, 3>{{2, 3, 4}} != array<int, 3>{{1, 3, 4}}, "");
static_assert(array<int, 3>{{2, 3, 4}} != array<int, 3>{{2, 1, 4}}, "");
static_assert(array<int, 3>{{2, 3, 4}} != array<int, 3>{{2, 3, 1}}, "");
} // namespace test_notequals
namespace test_lessthan {
static_assert(!(array<int, 0>{{}} < array<int, 0>{{}}), "");
static_assert(!(array<int, 1>{{2}} < array<int, 1>{{1}}), "");
static_assert(array<int, 1>{{1}} < array<int, 1>{{2}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} < array<int, 3>{{1, 2, 3}}), "");
static_assert(array<int, 3>{{1, 2, 3}} < array<int, 3>{{2, 2, 3}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} < array<int, 3>{{0, 2, 3}}), "");
static_assert(array<int, 3>{{1, 2, 3}} < array<int, 3>{{1, 3, 3}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} < array<int, 3>{{1, 1, 3}}), "");
static_assert(array<int, 3>{{1, 2, 3}} < array<int, 3>{{1, 2, 4}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} < array<int, 3>{{1, 2, 2}}), "");
}
static_assert(!(array<int, 0>{{}} < array<int, 0>{{}}), "");
static_assert(!(array<int, 1>{{2}} < array<int, 1>{{1}}), "");
static_assert(array<int, 1>{{1}} < array<int, 1>{{2}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} < array<int, 3>{{1, 2, 3}}), "");
static_assert(array<int, 3>{{1, 2, 3}} < array<int, 3>{{2, 2, 3}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} < array<int, 3>{{0, 2, 3}}), "");
static_assert(array<int, 3>{{1, 2, 3}} < array<int, 3>{{1, 3, 3}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} < array<int, 3>{{1, 1, 3}}), "");
static_assert(array<int, 3>{{1, 2, 3}} < array<int, 3>{{1, 2, 4}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} < array<int, 3>{{1, 2, 2}}), "");
} // namespace test_lessthan
namespace test_greaterthan {
static_assert(!(array<int, 0>{{}} > array<int, 0>{{}}), "");
static_assert(!(array<int, 1>{{1}} > array<int, 1>{{2}}), "");
static_assert(array<int, 1>{{2}} > array<int, 1>{{1}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} > array<int, 3>{{1, 2, 3}}), "");
static_assert(array<int, 3>{{2, 2, 3}} > array<int, 3>{{1, 2, 3}}, "");
static_assert(!(array<int, 3>{{0, 2, 3}} > array<int, 3>{{1, 2, 3}}), "");
static_assert(array<int, 3>{{1, 3, 3}} > array<int, 3>{{1, 2, 3}}, "");
static_assert(!(array<int, 3>{{1, 1, 3}} > array<int, 3>{{1, 2, 3}}), "");
static_assert(array<int, 3>{{1, 2, 4}} > array<int, 3>{{1, 2, 3}}, "");
static_assert(!(array<int, 3>{{1, 2, 2}} > array<int, 3>{{1, 2, 3}}), "");
}
static_assert(!(array<int, 0>{{}} > array<int, 0>{{}}), "");
static_assert(!(array<int, 1>{{1}} > array<int, 1>{{2}}), "");
static_assert(array<int, 1>{{2}} > array<int, 1>{{1}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} > array<int, 3>{{1, 2, 3}}), "");
static_assert(array<int, 3>{{2, 2, 3}} > array<int, 3>{{1, 2, 3}}, "");
static_assert(!(array<int, 3>{{0, 2, 3}} > array<int, 3>{{1, 2, 3}}), "");
static_assert(array<int, 3>{{1, 3, 3}} > array<int, 3>{{1, 2, 3}}, "");
static_assert(!(array<int, 3>{{1, 1, 3}} > array<int, 3>{{1, 2, 3}}), "");
static_assert(array<int, 3>{{1, 2, 4}} > array<int, 3>{{1, 2, 3}}, "");
static_assert(!(array<int, 3>{{1, 2, 2}} > array<int, 3>{{1, 2, 3}}), "");
} // namespace test_greaterthan
namespace test_lessequals {
static_assert(array<int, 0>{{}} <= array<int, 0>{{}}, "");
static_assert(!(array<int, 1>{{2}} <= array<int, 1>{{1}}), "");
static_assert(array<int, 1>{{1}} <= array<int, 1>{{2}}, "");
static_assert(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{1, 2, 3}}, "");
static_assert(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{2, 2, 3}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{0, 2, 3}}), "");
static_assert(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{1, 3, 3}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{1, 1, 3}}), "");
static_assert(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{1, 2, 4}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{1, 2, 2}}), "");
}
static_assert(array<int, 0>{{}} <= array<int, 0>{{}}, "");
static_assert(!(array<int, 1>{{2}} <= array<int, 1>{{1}}), "");
static_assert(array<int, 1>{{1}} <= array<int, 1>{{2}}, "");
static_assert(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{1, 2, 3}}, "");
static_assert(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{2, 2, 3}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{0, 2, 3}}), "");
static_assert(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{1, 3, 3}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{1, 1, 3}}), "");
static_assert(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{1, 2, 4}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{1, 2, 2}}), "");
} // namespace test_lessequals
namespace test_greaterequals {
static_assert(array<int, 0>{{}} >= array<int, 0>{{}}, "");
static_assert(!(array<int, 1>{{1}} >= array<int, 1>{{2}}), "");
static_assert(array<int, 1>{{2}} >= array<int, 1>{{1}}, "");
static_assert(array<int, 3>{{1, 2, 3}} >= array<int, 3>{{1, 2, 3}}, "");
static_assert(array<int, 3>{{2, 2, 3}} >= array<int, 3>{{1, 2, 3}}, "");
static_assert(!(array<int, 3>{{0, 2, 3}} >= array<int, 3>{{1, 2, 3}}), "");
static_assert(array<int, 3>{{1, 3, 3}} >= array<int, 3>{{1, 2, 3}}, "");
static_assert(!(array<int, 3>{{1, 1, 3}} >= array<int, 3>{{1, 2, 3}}), "");
static_assert(array<int, 3>{{1, 2, 4}} >= array<int, 3>{{1, 2, 3}}, "");
static_assert(!(array<int, 3>{{1, 2, 2}} >= array<int, 3>{{1, 2, 3}}), "");
}
static_assert(array<int, 0>{{}} >= array<int, 0>{{}}, "");
static_assert(!(array<int, 1>{{1}} >= array<int, 1>{{2}}), "");
static_assert(array<int, 1>{{2}} >= array<int, 1>{{1}}, "");
static_assert(array<int, 3>{{1, 2, 3}} >= array<int, 3>{{1, 2, 3}}, "");
static_assert(array<int, 3>{{2, 2, 3}} >= array<int, 3>{{1, 2, 3}}, "");
static_assert(!(array<int, 3>{{0, 2, 3}} >= array<int, 3>{{1, 2, 3}}), "");
static_assert(array<int, 3>{{1, 3, 3}} >= array<int, 3>{{1, 2, 3}}, "");
static_assert(!(array<int, 3>{{1, 1, 3}} >= array<int, 3>{{1, 2, 3}}), "");
static_assert(array<int, 3>{{1, 2, 4}} >= array<int, 3>{{1, 2, 3}}, "");
static_assert(!(array<int, 3>{{1, 2, 2}} >= array<int, 3>{{1, 2, 3}}), "");
} // namespace test_greaterequals
namespace test_tail {
static_assert(array < int, 2 > {{3, 4}} == tail(array < int, 3 > {{2, 3, 4}}), "");
static_assert(array < int, 0 > {{}} == tail(array < int, 1 > {{3}}), "");
}
static_assert(array<int, 2>{{3, 4}} == tail(array<int, 3>{{2, 3, 4}}), "");
static_assert(array<int, 0>{{}} == tail(array<int, 1>{{3}}), "");
} // namespace test_tail
namespace test_prepend {
static_assert(array < int, 3 > {{2, 3, 4}} == prepend(2, array < int, 2 > {{3, 4}}), "");
static_assert(array < int, 1 > {{3}} == prepend(3, array < int, 0 > {{}}), "");
}
static_assert(
array<int, 3>{{2, 3, 4}} == prepend(2, array<int, 2>{{3, 4}}),
"");
static_assert(array<int, 1>{{3}} == prepend(3, array<int, 0>{{}}), "");
} // namespace test_prepend
namespace test_to_std_array {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
constexpr int obj2[3] = {3, 5, 6};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
static_assert(array < int, 3 > {{3, 5, 6}} == to_array(obj2), "");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
static_assert(array < int, 3 > {{3, 5, 6}} == to_array<int, 3>({3, 5, 6}), "");
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
constexpr int obj2[3] = {3, 5, 6};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
static_assert(array<int, 3>{{3, 5, 6}} == to_array(obj2), "");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
static_assert(array<int, 3>{{3, 5, 6}} == to_array<int, 3>({3, 5, 6}), "");
} // namespace test_to_std_array
}
} // namespace


@ -12,7 +12,7 @@ static_assert(min(5, 3) == 3, "");
static_assert(min(3, 3) == 3, "");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
static_assert(min(3.0, 3.1) == 3.0, "");
}
} // namespace test_min
namespace test_max {
using c10::guts::max;
@ -23,7 +23,7 @@ static_assert(max(5, 3) == 5, "");
static_assert(max(3, 3) == 3, "");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
static_assert(max(3.0, 3.1) == 3.1, "");
}
} // namespace test_max
namespace test_if_constexpr {
@ -31,129 +31,125 @@ using c10::guts::if_constexpr;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(if_constexpr, whenIsTrue_thenReturnsTrueCase) {
EXPECT_EQ(4, if_constexpr<true>([](auto) { return 4; }, [](auto) { return 5; }));
EXPECT_EQ(
4, if_constexpr<true>([](auto) { return 4; }, [](auto) { return 5; }));
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(if_constexpr, whenIsFalse_thenReturnsFalseCase) {
EXPECT_EQ(5, if_constexpr<false>([](auto) { return 4; }, [](auto) { return 5; }));
EXPECT_EQ(
5, if_constexpr<false>([](auto) { return 4; }, [](auto) { return 5; }));
}
struct MovableOnly final {
int value;
int value;
MovableOnly(int v) : value(v) {}
MovableOnly(MovableOnly&&) = default;
MovableOnly(const MovableOnly&) = delete;
MovableOnly& operator=(MovableOnly&&) = default;
MovableOnly& operator=(const MovableOnly&) = delete;
MovableOnly(int v) : value(v) {}
MovableOnly(MovableOnly&&) = default;
MovableOnly(const MovableOnly&) = delete;
MovableOnly& operator=(MovableOnly&&) = default;
MovableOnly& operator=(const MovableOnly&) = delete;
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(if_constexpr, worksWithMovableOnlyTypes_withIdentityArg) {
EXPECT_EQ(
4,
if_constexpr<true>([](auto) { return MovableOnly(4); }, [](auto) { return MovableOnly(5); })
.value);
EXPECT_EQ(
5,
if_constexpr<false>([](auto) { return MovableOnly(4); }, [](auto) { return MovableOnly(5); })
.value);
EXPECT_EQ(
4,
if_constexpr<true>(
[](auto) { return MovableOnly(4); },
[](auto) { return MovableOnly(5); })
.value);
EXPECT_EQ(
5,
if_constexpr<false>(
[](auto) { return MovableOnly(4); },
[](auto) { return MovableOnly(5); })
.value);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(if_constexpr, worksWithMovableOnlyTypes_withoutIdentityArg) {
EXPECT_EQ(
4,
if_constexpr<true>([] { return MovableOnly(4); }, [] { return MovableOnly(5); })
.value);
EXPECT_EQ(
5,
if_constexpr<false>([] { return MovableOnly(4); }, [] { return MovableOnly(5); })
.value);
EXPECT_EQ(
4,
if_constexpr<true>(
[] { return MovableOnly(4); }, [] { return MovableOnly(5); })
.value);
EXPECT_EQ(
5,
if_constexpr<false>(
[] { return MovableOnly(4); }, [] { return MovableOnly(5); })
.value);
}
struct MyClass1 {
int value;
int value;
};
struct MyClass2 {
int val;
int val;
};
template<class T>
template <class T>
int func(T t) {
return if_constexpr<std::is_same<T, MyClass1>::value>(
[&](auto _) { return _(t).value; }, // this code is invalid for T == MyClass2
[&](auto _) { return _(t).val; } // this code is invalid for T == MyClass1
);
return if_constexpr<std::is_same<T, MyClass1>::value>(
[&](auto _) {
return _(t).value;
}, // this code is invalid for T == MyClass2
[&](auto _) { return _(t).val; } // this code is invalid for T == MyClass1
);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(if_constexpr, otherCaseCanHaveInvalidCode) {
EXPECT_EQ(8, func(MyClass1{/* .value = */ 8}));
EXPECT_EQ(4, func(MyClass2{/* .val = */ 4}));
EXPECT_EQ(8, func(MyClass1{/* .value = */ 8}));
EXPECT_EQ(4, func(MyClass2{/* .val = */ 4}));
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(if_constexpr, worksWithoutElseCase_withIdentityArg) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
int var = 5;
if_constexpr<false>(
[&](auto) { var = 3; }
);
EXPECT_EQ(5, var);
if_constexpr<true>(
[&](auto) { var = 3; }
);
EXPECT_EQ(3, var);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
int var = 5;
if_constexpr<false>([&](auto) { var = 3; });
EXPECT_EQ(5, var);
if_constexpr<true>([&](auto) { var = 3; });
EXPECT_EQ(3, var);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(if_constexpr, worksWithoutElseCase_withoutIdentityArg) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
int var = 5;
if_constexpr<false>(
[&] { var = 3; }
);
EXPECT_EQ(5, var);
if_constexpr<true>(
[&] { var = 3; }
);
EXPECT_EQ(3, var);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
int var = 5;
if_constexpr<false>([&] { var = 3; });
EXPECT_EQ(5, var);
if_constexpr<true>([&] { var = 3; });
EXPECT_EQ(3, var);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(if_constexpr, returnTypeCanDiffer_withIdentityArg) {
auto a_string = if_constexpr<false>(
[&](auto) -> int64_t { return 3; },
[&](auto) -> std::string { return "3"; }
);
static_assert(std::is_same<std::string, decltype(a_string)>::value, "");
auto a_string = if_constexpr<false>(
[&](auto) -> int64_t { return 3; },
[&](auto) -> std::string { return "3"; });
static_assert(std::is_same<std::string, decltype(a_string)>::value, "");
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
auto an_int = if_constexpr<true>(
[&](auto) -> int64_t { return 3; },
[&](auto) -> std::string { return "3"; }
);
static_assert(std::is_same<int64_t, decltype(an_int)>::value, "");
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
auto an_int = if_constexpr<true>(
[&](auto) -> int64_t { return 3; },
[&](auto) -> std::string { return "3"; });
static_assert(std::is_same<int64_t, decltype(an_int)>::value, "");
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(if_constexpr, returnTypeCanDiffer_withoutIdentityArg) {
auto a_string = if_constexpr<false>(
[&] () -> int64_t { return 3; },
[&] () -> std::string { return "3"; }
);
static_assert(std::is_same<std::string, decltype(a_string)>::value, "");
auto a_string = if_constexpr<false>(
[&]() -> int64_t { return 3; }, [&]() -> std::string { return "3"; });
static_assert(std::is_same<std::string, decltype(a_string)>::value, "");
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
auto an_int = if_constexpr<true>(
[&] () -> int64_t { return 3; },
[&] () -> std::string { return "3"; }
);
static_assert(std::is_same<int64_t, decltype(an_int)>::value, "");
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
auto an_int = if_constexpr<true>(
[&]() -> int64_t { return 3; }, [&]() -> std::string { return "3"; });
static_assert(std::is_same<int64_t, decltype(an_int)>::value, "");
}
}
}
} // namespace test_if_constexpr
} // namespace
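For readers unfamiliar with the pattern: c10::guts::if_constexpr emulates C++17's if constexpr in C++14 code, and the identity argument `_` is what keeps the not-taken branch dependent so it is never instantiated (which is why "the other case can have invalid code"). As a point of comparison, with a real C++17 compiler the func() from the test above could be written directly as the sketch below (func17 is illustrative and reuses MyClass1/MyClass2 from the test).

template <class T>
int func17(T t) {
  if constexpr (std::is_same<T, MyClass1>::value) {
    return t.value; // only instantiated when T is MyClass1
  } else {
    return t.val; // only instantiated for the other case, e.g. MyClass2
  }
}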


@ -10,259 +10,257 @@ TEST(LeftRightTest, givenInt_whenWritingAndReading_thenChangesArePresent) {
LeftRight<int> obj;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
obj.write([] (int& obj) {obj = 5;});
int read = obj.read([] (const int& obj) {return obj;});
obj.write([](int& obj) { obj = 5; });
int read = obj.read([](const int& obj) { return obj; });
EXPECT_EQ(5, read);
// check changes are also present in background copy
obj.write([] (int&) {}); // this switches to the background copy
read = obj.read([] (const int& obj) {return obj;});
obj.write([](int&) {}); // this switches to the background copy
read = obj.read([](const int& obj) { return obj; });
EXPECT_EQ(5, read);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LeftRightTest, givenVector_whenWritingAndReading_thenChangesArePresent) {
LeftRight<vector<int>> obj;
LeftRight<vector<int>> obj;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
obj.write([] (vector<int>& obj) {obj.push_back(5);});
vector<int> read = obj.read([] (const vector<int>& obj) {return obj;});
EXPECT_EQ((vector<int>{5}), read);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
obj.write([](vector<int>& obj) { obj.push_back(5); });
vector<int> read = obj.read([](const vector<int>& obj) { return obj; });
EXPECT_EQ((vector<int>{5}), read);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
obj.write([] (vector<int>& obj) {obj.push_back(6);});
read = obj.read([] (const vector<int>& obj) {return obj;});
EXPECT_EQ((vector<int>{5, 6}), read);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
obj.write([](vector<int>& obj) { obj.push_back(6); });
read = obj.read([](const vector<int>& obj) { return obj; });
EXPECT_EQ((vector<int>{5, 6}), read);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LeftRightTest, givenVector_whenWritingReturnsValue_thenValueIsReturned) {
LeftRight<vector<int>> obj;
LeftRight<vector<int>> obj;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
auto a = obj.write([] (vector<int>&) -> int {return 5;});
static_assert(std::is_same<int, decltype(a)>::value, "");
EXPECT_EQ(5, a);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
auto a = obj.write([](vector<int>&) -> int { return 5; });
static_assert(std::is_same<int, decltype(a)>::value, "");
EXPECT_EQ(5, a);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LeftRightTest, readsCanBeConcurrent) {
LeftRight<int> obj;
std::atomic<int> num_running_readers{0};
LeftRight<int> obj;
std::atomic<int> num_running_readers{0};
std::thread reader1([&] () {
obj.read([&] (const int&) {
++num_running_readers;
while(num_running_readers.load() < 2) {}
});
std::thread reader1([&]() {
obj.read([&](const int&) {
++num_running_readers;
while (num_running_readers.load() < 2) {
}
});
});
std::thread reader2([&] () {
obj.read([&] (const int&) {
++num_running_readers;
while(num_running_readers.load() < 2) {}
});
std::thread reader2([&]() {
obj.read([&](const int&) {
++num_running_readers;
while (num_running_readers.load() < 2) {
}
});
});
// the threads only finish after both entered the read function.
// if LeftRight didn't allow concurrency, this would cause a deadlock.
reader1.join();
reader2.join();
// the threads only finish after both entered the read function.
// if LeftRight didn't allow concurrency, this would cause a deadlock.
reader1.join();
reader2.join();
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LeftRightTest, writesCanBeConcurrentWithReads_readThenWrite) {
LeftRight<int> obj;
std::atomic<bool> reader_running{false};
std::atomic<bool> writer_running{false};
LeftRight<int> obj;
std::atomic<bool> reader_running{false};
std::atomic<bool> writer_running{false};
std::thread reader([&] () {
obj.read([&] (const int&) {
reader_running = true;
while(!writer_running.load()) {}
});
std::thread reader([&]() {
obj.read([&](const int&) {
reader_running = true;
while (!writer_running.load()) {
}
});
});
std::thread writer([&] () {
// run read first, write second
while (!reader_running.load()) {}
std::thread writer([&]() {
// run read first, write second
while (!reader_running.load()) {
}
obj.write([&] (int&) {
writer_running = true;
});
});
obj.write([&](int&) { writer_running = true; });
});
// the threads only finish after both entered the read function.
// if LeftRight didn't allow concurrency, this would cause a deadlock.
reader.join();
writer.join();
// the threads only finish after both entered the read function.
// if LeftRight didn't allow concurrency, this would cause a deadlock.
reader.join();
writer.join();
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LeftRightTest, writesCanBeConcurrentWithReads_writeThenRead) {
LeftRight<int> obj;
std::atomic<bool> writer_running{false};
std::atomic<bool> reader_running{false};
LeftRight<int> obj;
std::atomic<bool> writer_running{false};
std::atomic<bool> reader_running{false};
std::thread writer([&] () {
obj.read([&] (const int&) {
writer_running = true;
while(!reader_running.load()) {}
});
std::thread writer([&]() {
obj.read([&](const int&) {
writer_running = true;
while (!reader_running.load()) {
}
});
});
std::thread reader([&] () {
// run write first, read second
while (!writer_running.load()) {}
std::thread reader([&]() {
// run write first, read second
while (!writer_running.load()) {
}
obj.read([&] (const int&) {
reader_running = true;
});
});
obj.read([&](const int&) { reader_running = true; });
});
// the threads only finish after both entered the read function.
// if LeftRight didn't allow concurrency, this would cause a deadlock.
writer.join();
reader.join();
// the threads only finish after both entered the read function.
// if LeftRight didn't allow concurrency, this would cause a deadlock.
writer.join();
reader.join();
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LeftRightTest, writesCannotBeConcurrentWithWrites) {
LeftRight<int> obj;
std::atomic<bool> first_writer_started{false};
std::atomic<bool> first_writer_finished{false};
LeftRight<int> obj;
std::atomic<bool> first_writer_started{false};
std::atomic<bool> first_writer_finished{false};
std::thread writer1([&] () {
obj.write([&] (int&) {
first_writer_started = true;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
std::this_thread::sleep_for(std::chrono::milliseconds(50));
first_writer_finished = true;
});
std::thread writer1([&]() {
obj.write([&](int&) {
first_writer_started = true;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
std::this_thread::sleep_for(std::chrono::milliseconds(50));
first_writer_finished = true;
});
});
std::thread writer2([&] () {
// make sure the other writer runs first
while (!first_writer_started.load()) {}
std::thread writer2([&]() {
// make sure the other writer runs first
while (!first_writer_started.load()) {
}
obj.write([&] (int&) {
// expect the other writer finished before this one starts
EXPECT_TRUE(first_writer_finished.load());
});
obj.write([&](int&) {
// expect the other writer finished before this one starts
EXPECT_TRUE(first_writer_finished.load());
});
});
writer1.join();
writer2.join();
writer1.join();
writer2.join();
}
namespace {
class MyException : public std::exception {};
}
} // namespace
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LeftRightTest, whenReadThrowsException_thenThrowsThrough) {
LeftRight<int> obj;
LeftRight<int> obj;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_THROW(
obj.read([](const int&) {throw MyException();}),
MyException
);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_THROW(obj.read([](const int&) { throw MyException(); }), MyException);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LeftRightTest, whenWriteThrowsException_thenThrowsThrough) {
LeftRight<int> obj;
LeftRight<int> obj;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_THROW(
obj.write([](int&) {throw MyException();}),
MyException
);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_THROW(obj.write([](int&) { throw MyException(); }), MyException);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LeftRightTest, givenInt_whenWriteThrowsExceptionOnFirstCall_thenResetsToOldState) {
LeftRight<int> obj;
TEST(
LeftRightTest,
givenInt_whenWriteThrowsExceptionOnFirstCall_thenResetsToOldState) {
LeftRight<int> obj;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
obj.write([](int& obj) {obj = 5;});
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
obj.write([](int& obj) { obj = 5; });
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_THROW(
obj.write([](int& obj) {
obj = 6;
throw MyException();
}),
MyException
);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_THROW(
obj.write([](int& obj) {
obj = 6;
throw MyException();
}),
MyException);
// check reading it returns old value
int read = obj.read([] (const int& obj) {return obj;});
EXPECT_EQ(5, read);
// check reading it returns old value
int read = obj.read([](const int& obj) { return obj; });
EXPECT_EQ(5, read);
// check changes are also present in background copy
obj.write([] (int&) {}); // this switches to the background copy
read = obj.read([] (const int& obj) {return obj;});
EXPECT_EQ(5, read);
// check changes are also present in background copy
obj.write([](int&) {}); // this switches to the background copy
read = obj.read([](const int& obj) { return obj; });
EXPECT_EQ(5, read);
}
// note: each write is executed twice, on the foreground and background copy.
// We need to test a thrown exception in either call is handled correctly.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LeftRightTest, givenInt_whenWriteThrowsExceptionOnSecondCall_thenKeepsNewState) {
LeftRight<int> obj;
TEST(
LeftRightTest,
givenInt_whenWriteThrowsExceptionOnSecondCall_thenKeepsNewState) {
LeftRight<int> obj;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
obj.write([](int& obj) {obj = 5;});
bool write_called = false;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
obj.write([](int& obj) { obj = 5; });
bool write_called = false;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_THROW(
obj.write([&](int& obj) {
obj = 6;
if (write_called) {
// this is the second time the write callback is executed
throw MyException();
} else {
write_called = true;
}
}),
MyException
);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_THROW(
obj.write([&](int& obj) {
obj = 6;
if (write_called) {
// this is the second time the write callback is executed
throw MyException();
} else {
write_called = true;
}
}),
MyException);
// check reading it returns new value
int read = obj.read([] (const int& obj) {return obj;});
EXPECT_EQ(6, read);
// check reading it returns new value
int read = obj.read([](const int& obj) { return obj; });
EXPECT_EQ(6, read);
// check changes are also present in background copy
obj.write([] (int&) {}); // this switches to the background copy
read = obj.read([] (const int& obj) {return obj;});
EXPECT_EQ(6, read);
// check changes are also present in background copy
obj.write([](int&) {}); // this switches to the background copy
read = obj.read([](const int& obj) { return obj; });
EXPECT_EQ(6, read);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LeftRightTest, givenVector_whenWriteThrowsException_thenResetsToOldState) {
LeftRight<vector<int>> obj;
LeftRight<vector<int>> obj;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
obj.write([](vector<int>& obj) {obj.push_back(5);});
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
obj.write([](vector<int>& obj) { obj.push_back(5); });
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_THROW(
obj.write([](vector<int>& obj) {
obj.push_back(6);
throw MyException();
}),
MyException
);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_THROW(
obj.write([](vector<int>& obj) {
obj.push_back(6);
throw MyException();
}),
MyException);
// check reading it returns old value
vector<int> read = obj.read([] (const vector<int>& obj) {return obj;});
EXPECT_EQ((vector<int>{5}), read);
// check reading it returns old value
vector<int> read = obj.read([](const vector<int>& obj) { return obj; });
EXPECT_EQ((vector<int>{5}), read);
// check changes are also present in background copy
obj.write([] (vector<int>&) {}); // this switches to the background copy
read = obj.read([] (const vector<int>& obj) {return obj;});
EXPECT_EQ((vector<int>{5}), read);
// check changes are also present in background copy
obj.write([](vector<int>&) {}); // this switches to the background copy
read = obj.read([](const vector<int>& obj) { return obj; });
EXPECT_EQ((vector<int>{5}), read);
}
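Summing up what these tests exercise: LeftRight keeps two copies of the value, reads run concurrently against one copy while a writer has exclusive access to the other, and each write callback is applied to both copies in turn so they converge. The test below is an added usage sketch, not part of the original file; it relies only on the read/write API shown above.

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LeftRightTest, usageSketch) {
  LeftRight<vector<int>> table;
  // The write callback runs on both copies in turn, keeping them in sync.
  table.write([](vector<int>& v) { v.push_back(42); });
  // Reads take const access and may run concurrently with other readers.
  int front = table.read([](const vector<int>& v) { return v.front(); });
  EXPECT_EQ(42, front);
}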

File diff suppressed because it is too large


@ -61,12 +61,10 @@ static_assert(
#endif
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeIndex, TopLevelName) {
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<Dummy>().find("Dummy")
);
}
EXPECT_NE(
string_view::npos, get_fully_qualified_type_name<Dummy>().find("Dummy"));
}
} // namespace test_top_level_name
namespace test_nested_name {
struct Dummy final {};
@ -79,10 +77,9 @@ static_assert(
#endif
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeIndex, NestedName) {
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<Dummy>().find("test_nested_name::Dummy")
);
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<Dummy>().find("test_nested_name::Dummy"));
}
} // namespace test_nested_name
@ -105,16 +102,14 @@ static_assert(
#endif
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeIndex, TypeTemplateParameter) {
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<Outer<Inner>>().find(
"test_type_template_parameter::Outer")
);
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<Outer<Inner>>().find(
"test_type_template_parameter::Inner")
);
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<Outer<Inner>>().find(
"test_type_template_parameter::Outer"));
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<Outer<Inner>>().find(
"test_type_template_parameter::Inner"));
}
} // namespace test_type_template_parameter
@ -131,10 +126,9 @@ static_assert(
#endif
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeIndex, NonTypeTemplateParameter) {
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<Class<38474355>>().find("38474355")
);
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<Class<38474355>>().find("38474355"));
}
} // namespace test_nontype_template_parameter
@ -164,21 +158,18 @@ static_assert(
#endif
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeIndex, TypeComputationsAreResolved) {
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<typename Type<int>::type>().find("int")
);
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<typename Type<int>::type>().find("*")
);
// but with remove_pointer applied, there is no '*' in the type name anymore
EXPECT_EQ(
string_view::npos,
get_fully_qualified_type_name<
typename std::remove_pointer<typename Type<int>::type>::type>()
.find("*")
);
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<typename Type<int>::type>().find("int"));
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<typename Type<int>::type>().find("*"));
// but with remove_pointer applied, there is no '*' in the type name anymore
EXPECT_EQ(
string_view::npos,
get_fully_qualified_type_name<
typename std::remove_pointer<typename Type<int>::type>::type>()
.find("*"));
}
struct Functor final {
@ -193,11 +184,10 @@ static_assert(
#endif
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeIndex, FunctionTypeComputationsAreResolved) {
EXPECT_EQ(
get_fully_qualified_type_name<std::string(int64_t, const Type<int>&)>(),
get_fully_qualified_type_name<
typename c10::guts::infer_function_traits_t<Functor>::func_type>()
);
EXPECT_EQ(
get_fully_qualified_type_name<std::string(int64_t, const Type<int>&)>(),
get_fully_qualified_type_name<
typename c10::guts::infer_function_traits_t<Functor>::func_type>());
}
} // namespace test_type_computations_are_resolved
@ -218,16 +208,14 @@ static_assert(
#endif
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeIndex, FunctionArgumentsAndReturns) {
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<Dummy(int)>().find(
"test_function_arguments_and_returns::Dummy")
);
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<void(Dummy)>().find(
"test_function_arguments_and_returns::Dummy")
);
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<Dummy(int)>().find(
"test_function_arguments_and_returns::Dummy"));
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<void(Dummy)>().find(
"test_function_arguments_and_returns::Dummy"));
}
} // namespace test_function_arguments_and_returns
} // namespace


@ -5,204 +5,382 @@
using namespace c10::guts::typelist;
namespace test_size {
class MyClass {};
static_assert(0 == size<typelist<>>::value, "");
static_assert(1 == size<typelist<int>>::value, "");
static_assert(3 == size<typelist<int, float&, const MyClass&&>>::value, "");
}
class MyClass {};
static_assert(0 == size<typelist<>>::value, "");
static_assert(1 == size<typelist<int>>::value, "");
static_assert(3 == size<typelist<int, float&, const MyClass&&>>::value, "");
} // namespace test_size
namespace test_from_tuple {
class MyClass {};
static_assert(std::is_same<typelist<int, float&, const MyClass&&>, from_tuple_t<std::tuple<int, float&, const MyClass&&>>>::value, "");
static_assert(std::is_same<typelist<>, from_tuple_t<std::tuple<>>>::value, "");
}
class MyClass {};
static_assert(
std::is_same<
typelist<int, float&, const MyClass&&>,
from_tuple_t<std::tuple<int, float&, const MyClass&&>>>::value,
"");
static_assert(std::is_same<typelist<>, from_tuple_t<std::tuple<>>>::value, "");
} // namespace test_from_tuple
namespace test_to_tuple {
class MyClass {};
static_assert(std::is_same<std::tuple<int, float&, const MyClass&&>, to_tuple_t<typelist<int, float&, const MyClass&&>>>::value, "");
static_assert(std::is_same<std::tuple<>, to_tuple_t<typelist<>>>::value, "");
}
class MyClass {};
static_assert(
std::is_same<
std::tuple<int, float&, const MyClass&&>,
to_tuple_t<typelist<int, float&, const MyClass&&>>>::value,
"");
static_assert(std::is_same<std::tuple<>, to_tuple_t<typelist<>>>::value, "");
} // namespace test_to_tuple
namespace test_concat {
class MyClass {};
static_assert(std::is_same<typelist<>, concat_t<>>::value, "");
static_assert(std::is_same<typelist<>, concat_t<typelist<>>>::value, "");
static_assert(std::is_same<typelist<>, concat_t<typelist<>, typelist<>>>::value, "");
static_assert(std::is_same<typelist<int>, concat_t<typelist<int>>>::value, "");
static_assert(std::is_same<typelist<int>, concat_t<typelist<int>, typelist<>>>::value, "");
static_assert(std::is_same<typelist<int>, concat_t<typelist<>, typelist<int>>>::value, "");
static_assert(std::is_same<typelist<int>, concat_t<typelist<>, typelist<int>, typelist<>>>::value, "");
static_assert(std::is_same<typelist<int, float&>, concat_t<typelist<int>, typelist<float&>>>::value, "");
static_assert(std::is_same<typelist<int, float&>, concat_t<typelist<>, typelist<int, float&>, typelist<>>>::value, "");
static_assert(std::is_same<typelist<int, float&, const MyClass&&>, concat_t<typelist<>, typelist<int, float&>, typelist<const MyClass&&>>>::value, "");
}
class MyClass {};
static_assert(std::is_same<typelist<>, concat_t<>>::value, "");
static_assert(std::is_same<typelist<>, concat_t<typelist<>>>::value, "");
static_assert(
std::is_same<typelist<>, concat_t<typelist<>, typelist<>>>::value,
"");
static_assert(std::is_same<typelist<int>, concat_t<typelist<int>>>::value, "");
static_assert(
std::is_same<typelist<int>, concat_t<typelist<int>, typelist<>>>::value,
"");
static_assert(
std::is_same<typelist<int>, concat_t<typelist<>, typelist<int>>>::value,
"");
static_assert(
std::is_same<
typelist<int>,
concat_t<typelist<>, typelist<int>, typelist<>>>::value,
"");
static_assert(
std::is_same<
typelist<int, float&>,
concat_t<typelist<int>, typelist<float&>>>::value,
"");
static_assert(
std::is_same<
typelist<int, float&>,
concat_t<typelist<>, typelist<int, float&>, typelist<>>>::value,
"");
static_assert(
std::is_same<
typelist<int, float&, const MyClass&&>,
concat_t<
typelist<>,
typelist<int, float&>,
typelist<const MyClass&&>>>::value,
"");
} // namespace test_concat
namespace test_filter {
class MyClass {};
static_assert(std::is_same<typelist<>, filter_t<std::is_reference, typelist<>>>::value, "");
static_assert(std::is_same<typelist<>, filter_t<std::is_reference, typelist<int, float, double, MyClass>>>::value, "");
static_assert(std::is_same<typelist<float&, const MyClass&&>, filter_t<std::is_reference, typelist<int, float&, double, const MyClass&&>>>::value, "");
}
class MyClass {};
static_assert(
std::is_same<typelist<>, filter_t<std::is_reference, typelist<>>>::value,
"");
static_assert(
std::is_same<
typelist<>,
filter_t<std::is_reference, typelist<int, float, double, MyClass>>>::
value,
"");
static_assert(
std::is_same<
typelist<float&, const MyClass&&>,
filter_t<
std::is_reference,
typelist<int, float&, double, const MyClass&&>>>::value,
"");
} // namespace test_filter
namespace test_count_if {
class MyClass final {};
static_assert(count_if<std::is_reference, typelist<int, bool&, const MyClass&&, float, double>>::value == 2, "");
static_assert(count_if<std::is_reference, typelist<int, bool>>::value == 0, "");
static_assert(count_if<std::is_reference, typelist<>>::value == 0, "");
}
class MyClass final {};
static_assert(
count_if<
std::is_reference,
typelist<int, bool&, const MyClass&&, float, double>>::value == 2,
"");
static_assert(count_if<std::is_reference, typelist<int, bool>>::value == 0, "");
static_assert(count_if<std::is_reference, typelist<>>::value == 0, "");
} // namespace test_count_if
namespace test_true_for_each_type {
template<class> class Test;
class MyClass {};
static_assert(all<std::is_reference, typelist<int&, const float&&, const MyClass&>>::value, "");
static_assert(!all<std::is_reference, typelist<int&, const float, const MyClass&>>::value, "");
static_assert(all<std::is_reference, typelist<>>::value, "");
}
template <class>
class Test;
class MyClass {};
static_assert(
all<std::is_reference,
typelist<int&, const float&&, const MyClass&>>::value,
"");
static_assert(
!all<std::is_reference, typelist<int&, const float, const MyClass&>>::value,
"");
static_assert(all<std::is_reference, typelist<>>::value, "");
} // namespace test_true_for_each_type
namespace test_true_for_any_type {
template<class> class Test;
class MyClass {};
static_assert(true_for_any_type<std::is_reference, typelist<int&, const float&&, const MyClass&>>::value, "");
static_assert(true_for_any_type<std::is_reference, typelist<int&, const float, const MyClass&>>::value, "");
static_assert(!true_for_any_type<std::is_reference, typelist<int, const float, const MyClass>>::value, "");
static_assert(!true_for_any_type<std::is_reference, typelist<>>::value, "");
}
template <class>
class Test;
class MyClass {};
static_assert(
true_for_any_type<
std::is_reference,
typelist<int&, const float&&, const MyClass&>>::value,
"");
static_assert(
true_for_any_type<
std::is_reference,
typelist<int&, const float, const MyClass&>>::value,
"");
static_assert(
!true_for_any_type<
std::is_reference,
typelist<int, const float, const MyClass>>::value,
"");
static_assert(!true_for_any_type<std::is_reference, typelist<>>::value, "");
} // namespace test_true_for_any_type
namespace test_map {
class MyClass {};
static_assert(std::is_same<typelist<>, map_t<std::add_lvalue_reference_t, typelist<>>>::value, "");
static_assert(std::is_same<typelist<int&>, map_t<std::add_lvalue_reference_t, typelist<int>>>::value, "");
static_assert(std::is_same<typelist<int&, double&, const MyClass&>, map_t<std::add_lvalue_reference_t, typelist<int, double, const MyClass>>>::value, "");
}
class MyClass {};
static_assert(
std::is_same<typelist<>, map_t<std::add_lvalue_reference_t, typelist<>>>::
value,
"");
static_assert(
std::is_same<
typelist<int&>,
map_t<std::add_lvalue_reference_t, typelist<int>>>::value,
"");
static_assert(
std::is_same<
typelist<int&, double&, const MyClass&>,
map_t<
std::add_lvalue_reference_t,
typelist<int, double, const MyClass>>>::value,
"");
} // namespace test_map
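These utilities compose. The namespace below is an added sketch (not in the original file) that chains the filter_t and map_t primitives exercised above: keep only the reference members of a typelist, then map them back to their unreferenced types.

namespace test_compose_filter_and_map {
class MyClass {};
using input = typelist<int, MyClass&, const float&&>;
// Keep only the reference types.
using refs = filter_t<std::is_reference, input>;
static_assert(
    std::is_same<typelist<MyClass&, const float&&>, refs>::value,
    "");
// Strip the references off again.
using values = map_t<std::remove_reference_t, refs>;
static_assert(
    std::is_same<typelist<MyClass, const float>, values>::value,
    "");
} // namespace test_compose_filter_and_map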
namespace test_head {
class MyClass {};
static_assert(std::is_same<int, head_t<typelist<int, double>>>::value, "");
static_assert(std::is_same<const MyClass&, head_t<typelist<const MyClass&, double>>>::value, "");
static_assert(std::is_same<MyClass&&, head_t<typelist<MyClass&&, MyClass>>>::value, "");
static_assert(std::is_same<bool, head_t<typelist<bool>>>::value, "");
}
class MyClass {};
static_assert(std::is_same<int, head_t<typelist<int, double>>>::value, "");
static_assert(
std::is_same<const MyClass&, head_t<typelist<const MyClass&, double>>>::
value,
"");
static_assert(
std::is_same<MyClass&&, head_t<typelist<MyClass&&, MyClass>>>::value,
"");
static_assert(std::is_same<bool, head_t<typelist<bool>>>::value, "");
} // namespace test_head
namespace test_head_with_default {
class MyClass {};
static_assert(std::is_same<int, head_with_default_t<bool, typelist<int, double>>>::value, "");
static_assert(std::is_same<const MyClass&, head_with_default_t<bool, typelist<const MyClass&, double>>>::value, "");
static_assert(std::is_same<MyClass&&, head_with_default_t<bool, typelist<MyClass&&, MyClass>>>::value, "");
static_assert(std::is_same<int, head_with_default_t<bool, typelist<int>>>::value, "");
static_assert(std::is_same<bool, head_with_default_t<bool, typelist<>>>::value, "");
}
class MyClass {};
static_assert(
std::is_same<int, head_with_default_t<bool, typelist<int, double>>>::value,
"");
static_assert(
std::is_same<
const MyClass&,
head_with_default_t<bool, typelist<const MyClass&, double>>>::value,
"");
static_assert(
std::is_same<
MyClass&&,
head_with_default_t<bool, typelist<MyClass&&, MyClass>>>::value,
"");
static_assert(
std::is_same<int, head_with_default_t<bool, typelist<int>>>::value,
"");
static_assert(
std::is_same<bool, head_with_default_t<bool, typelist<>>>::value,
"");
} // namespace test_head_with_default
namespace test_reverse {
class MyClass {};
static_assert(std::is_same<
typelist<int, double, MyClass*, const MyClass&&>,
reverse_t<typelist<const MyClass&&, MyClass*, double, int>>
>::value, "");
static_assert(std::is_same<
typelist<>,
reverse_t<typelist<>>
>::value, "");
}
class MyClass {};
static_assert(
std::is_same<
typelist<int, double, MyClass*, const MyClass&&>,
reverse_t<typelist<const MyClass&&, MyClass*, double, int>>>::value,
"");
static_assert(std::is_same<typelist<>, reverse_t<typelist<>>>::value, "");
} // namespace test_reverse
namespace test_map_types_to_values {
struct map_to_size {
template<class T> constexpr size_t operator()(T) const {return sizeof(typename T::type);}
};
struct map_to_size {
template <class T>
constexpr size_t operator()(T) const {
return sizeof(typename T::type);
}
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeListTest, MapTypesToValues_sametype) {
auto sizes =
map_types_to_values<typelist<int64_t, bool, uint32_t>>(map_to_size());
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
std::tuple<size_t, size_t, size_t> expected(8, 1, 4);
static_assert(std::is_same<decltype(expected), decltype(sizes)>::value, "");
EXPECT_EQ(expected, sizes);
}
struct map_make_shared {
template<class T> std::shared_ptr<typename T::type> operator()(T) {
return std::make_shared<typename T::type>();
}
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeListTest, MapTypesToValues_differenttypes) {
auto shared_ptrs =
map_types_to_values<typelist<int, double>>(map_make_shared());
static_assert(std::is_same<std::tuple<std::shared_ptr<int>, std::shared_ptr<double>>, decltype(shared_ptrs)>::value, "");
}
struct Class1 {static int func() {return 3;}};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
struct Class2 {static double func() {return 2.0;}};
struct mapper_call_func {
template<class T> decltype(auto) operator()(T) { return T::type::func(); }
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeListTest, MapTypesToValues_members) {
auto result =
map_types_to_values<typelist<Class1, Class2>>(mapper_call_func());
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
std::tuple<int, double> expected(3, 2.0);
static_assert(std::is_same<decltype(expected), decltype(result)>::value, "");
EXPECT_EQ(expected, result);
}
struct mapper_call_nonexistent_function {
template<class T> decltype(auto) operator()(T) { return T::type::this_doesnt_exist(); }
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeListTest, MapTypesToValues_empty) {
auto result =
map_types_to_values<typelist<>>(mapper_call_nonexistent_function());
std::tuple<> expected;
static_assert(std::is_same<decltype(expected), decltype(result)>::value, "");
EXPECT_EQ(expected, result);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeListTest, MapTypesToValues_sametype) {
auto sizes =
map_types_to_values<typelist<int64_t, bool, uint32_t>>(map_to_size());
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
std::tuple<size_t, size_t, size_t> expected(8, 1, 4);
static_assert(std::is_same<decltype(expected), decltype(sizes)>::value, "");
EXPECT_EQ(expected, sizes);
}
struct map_make_shared {
template <class T>
std::shared_ptr<typename T::type> operator()(T) {
return std::make_shared<typename T::type>();
}
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeListTest, MapTypesToValues_differenttypes) {
auto shared_ptrs =
map_types_to_values<typelist<int, double>>(map_make_shared());
static_assert(
std::is_same<
std::tuple<std::shared_ptr<int>, std::shared_ptr<double>>,
decltype(shared_ptrs)>::value,
"");
}
struct Class1 {
static int func() {
return 3;
}
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
struct Class2 {
static double func() {
return 2.0;
}
};
struct mapper_call_func {
template <class T>
decltype(auto) operator()(T) {
return T::type::func();
}
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeListTest, MapTypesToValues_members) {
auto result =
map_types_to_values<typelist<Class1, Class2>>(mapper_call_func());
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
std::tuple<int, double> expected(3, 2.0);
static_assert(std::is_same<decltype(expected), decltype(result)>::value, "");
EXPECT_EQ(expected, result);
}
struct mapper_call_nonexistent_function {
template <class T>
decltype(auto) operator()(T) {
return T::type::this_doesnt_exist();
}
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeListTest, MapTypesToValues_empty) {
auto result =
map_types_to_values<typelist<>>(mapper_call_nonexistent_function());
std::tuple<> expected;
static_assert(std::is_same<decltype(expected), decltype(result)>::value, "");
EXPECT_EQ(expected, result);
}
} // namespace test_map_types_to_values
namespace test_find_if {
static_assert(0 == find_if<typelist<char&>, std::is_reference>::value, "");
static_assert(0 == find_if<typelist<char&, int, char&, int&>, std::is_reference>::value, "");
static_assert(2 == find_if<typelist<char, int, char&, int&>, std::is_reference>::value, "");
static_assert(3 == find_if<typelist<char, int, char, int&>, std::is_reference>::value, "");
}
static_assert(0 == find_if<typelist<char&>, std::is_reference>::value, "");
static_assert(
0 == find_if<typelist<char&, int, char&, int&>, std::is_reference>::value,
"");
static_assert(
2 == find_if<typelist<char, int, char&, int&>, std::is_reference>::value,
"");
static_assert(
3 == find_if<typelist<char, int, char, int&>, std::is_reference>::value,
"");
} // namespace test_find_if
namespace test_contains {
static_assert(contains<typelist<double>, double>::value, "");
static_assert(contains<typelist<int, double>, double>::value, "");
static_assert(!contains<typelist<int, double>, float>::value, "");
static_assert(!contains<typelist<>, double>::value, "");
}
static_assert(contains<typelist<double>, double>::value, "");
static_assert(contains<typelist<int, double>, double>::value, "");
static_assert(!contains<typelist<int, double>, float>::value, "");
static_assert(!contains<typelist<>, double>::value, "");
} // namespace test_contains
namespace test_take {
static_assert(std::is_same<typelist<>, take_t<typelist<>, 0>>::value, "");
static_assert(std::is_same<typelist<>, take_t<typelist<int64_t>, 0>>::value, "");
static_assert(std::is_same<typelist<int64_t>, take_t<typelist<int64_t>, 1>>::value, "");
static_assert(std::is_same<typelist<>, take_t<typelist<int64_t, int32_t>, 0>>::value, "");
static_assert(std::is_same<typelist<int64_t>, take_t<typelist<int64_t, int32_t>, 1>>::value, "");
static_assert(std::is_same<typelist<int64_t, int32_t>, take_t<typelist<int64_t, int32_t>, 2>>::value, "");
}
static_assert(std::is_same<typelist<>, take_t<typelist<>, 0>>::value, "");
static_assert(
std::is_same<typelist<>, take_t<typelist<int64_t>, 0>>::value,
"");
static_assert(
std::is_same<typelist<int64_t>, take_t<typelist<int64_t>, 1>>::value,
"");
static_assert(
std::is_same<typelist<>, take_t<typelist<int64_t, int32_t>, 0>>::value,
"");
static_assert(
std::is_same<typelist<int64_t>, take_t<typelist<int64_t, int32_t>, 1>>::
value,
"");
static_assert(
std::is_same<
typelist<int64_t, int32_t>,
take_t<typelist<int64_t, int32_t>, 2>>::value,
"");
} // namespace test_take
namespace test_drop {
static_assert(std::is_same<typelist<>, drop_t<typelist<>, 0>>::value, "");
static_assert(std::is_same<typelist<int64_t>, drop_t<typelist<int64_t>, 0>>::value, "");
static_assert(std::is_same<typelist<>, drop_t<typelist<int64_t>, 1>>::value, "");
static_assert(std::is_same<typelist<int64_t, int32_t>, drop_t<typelist<int64_t, int32_t>, 0>>::value, "");
static_assert(std::is_same<typelist<int32_t>, drop_t<typelist<int64_t, int32_t>, 1>>::value, "");
static_assert(std::is_same<typelist<>, drop_t<typelist<int64_t, int32_t>, 2>>::value, "");
}
static_assert(std::is_same<typelist<>, drop_t<typelist<>, 0>>::value, "");
static_assert(
std::is_same<typelist<int64_t>, drop_t<typelist<int64_t>, 0>>::value,
"");
static_assert(
std::is_same<typelist<>, drop_t<typelist<int64_t>, 1>>::value,
"");
static_assert(
std::is_same<
typelist<int64_t, int32_t>,
drop_t<typelist<int64_t, int32_t>, 0>>::value,
"");
static_assert(
std::is_same<typelist<int32_t>, drop_t<typelist<int64_t, int32_t>, 1>>::
value,
"");
static_assert(
std::is_same<typelist<>, drop_t<typelist<int64_t, int32_t>, 2>>::value,
"");
} // namespace test_drop
namespace test_drop_if_nonempty {
static_assert(std::is_same<typelist<>, drop_if_nonempty_t<typelist<>, 0>>::value, "");
static_assert(std::is_same<typelist<int64_t>, drop_if_nonempty_t<typelist<int64_t>, 0>>::value, "");
static_assert(std::is_same<typelist<>, drop_if_nonempty_t<typelist<int64_t>, 1>>::value, "");
static_assert(std::is_same<typelist<int64_t, int32_t>, drop_if_nonempty_t<typelist<int64_t, int32_t>, 0>>::value, "");
static_assert(std::is_same<typelist<int32_t>, drop_if_nonempty_t<typelist<int64_t, int32_t>, 1>>::value, "");
static_assert(std::is_same<typelist<>, drop_if_nonempty_t<typelist<int64_t, int32_t>, 2>>::value, "");
static_assert(std::is_same<typelist<>, drop_if_nonempty_t<typelist<>, 1>>::value, "");
static_assert(std::is_same<typelist<>, drop_if_nonempty_t<typelist<int64_t, int32_t>, 3>>::value, "");
}
static_assert(
std::is_same<typelist<>, drop_if_nonempty_t<typelist<>, 0>>::value,
"");
static_assert(
std::is_same<typelist<int64_t>, drop_if_nonempty_t<typelist<int64_t>, 0>>::
value,
"");
static_assert(
std::is_same<typelist<>, drop_if_nonempty_t<typelist<int64_t>, 1>>::value,
"");
static_assert(
std::is_same<
typelist<int64_t, int32_t>,
drop_if_nonempty_t<typelist<int64_t, int32_t>, 0>>::value,
"");
static_assert(
std::is_same<
typelist<int32_t>,
drop_if_nonempty_t<typelist<int64_t, int32_t>, 1>>::value,
"");
static_assert(
std::is_same<
typelist<>,
drop_if_nonempty_t<typelist<int64_t, int32_t>, 2>>::value,
"");
static_assert(
std::is_same<typelist<>, drop_if_nonempty_t<typelist<>, 1>>::value,
"");
static_assert(
std::is_same<
typelist<>,
drop_if_nonempty_t<typelist<int64_t, int32_t>, 3>>::value,
"");
} // namespace test_drop_if_nonempty
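// Illustrative aside (annotation, not part of this change): the utilities
// exercised above compose. Assuming only the semantics pinned down by the
// asserts in this file, the following would also hold:
static_assert(
    std::is_same<
        map_t<
            std::add_lvalue_reference_t,
            take_t<typelist<int, double&, float>, 2>>,
        typelist<int&, double&>>::value,
    "");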


@@ -6,152 +6,190 @@ using namespace c10::guts;
namespace {
namespace test_is_equality_comparable {
class NotEqualityComparable {};
class EqualityComparable {};
class NotEqualityComparable {};
class EqualityComparable {};
inline bool operator==(const EqualityComparable &, const EqualityComparable &) { return false; }
static_assert(!is_equality_comparable<NotEqualityComparable>::value, "");
static_assert(is_equality_comparable<EqualityComparable>::value, "");
static_assert(is_equality_comparable<int>::value, "");
// v_ just exists to silence a compiler warning about operator==(EqualityComparable, EqualityComparable) not being needed
const bool v_ = EqualityComparable() == EqualityComparable();
inline bool operator==(const EqualityComparable&, const EqualityComparable&) {
return false;
}
static_assert(!is_equality_comparable<NotEqualityComparable>::value, "");
static_assert(is_equality_comparable<EqualityComparable>::value, "");
static_assert(is_equality_comparable<int>::value, "");
// v_ just exists to silence a compiler warning about
// operator==(EqualityComparable, EqualityComparable) not being needed
const bool v_ = EqualityComparable() == EqualityComparable();
} // namespace test_is_equality_comparable
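// Illustrative aside (annotation, not part of this change): a trait like
// is_equality_comparable can be probed with expression SFINAE. The sketch
// below uses an invented name (has_equals_sketch), assumes C++17 std::void_t
// plus <type_traits>/<utility>, and makes no claim about how c10 actually
// implements the trait:
template <class T, class = void>
struct has_equals_sketch : std::false_type {};
template <class T>
struct has_equals_sketch<
    T,
    std::void_t<decltype(std::declval<T>() == std::declval<T>())>>
    : std::true_type {};
static_assert(has_equals_sketch<int>::value, "");
static_assert(
    !has_equals_sketch<test_is_equality_comparable::NotEqualityComparable>::
        value,
    "");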
namespace test_is_hashable {
class NotHashable {};
class Hashable {};
}
}
class NotHashable {};
class Hashable {};
} // namespace test_is_hashable
} // namespace
namespace std {
template<> struct hash<test_is_hashable::Hashable> final {
size_t operator()(const test_is_hashable::Hashable &) { return 0; }
};
}
template <>
struct hash<test_is_hashable::Hashable> final {
size_t operator()(const test_is_hashable::Hashable&) {
return 0;
}
};
} // namespace std
namespace {
namespace test_is_hashable {
static_assert(is_hashable<int>::value, "");
static_assert(is_hashable<Hashable>::value, "");
static_assert(!is_hashable<NotHashable>::value, "");
}
static_assert(is_hashable<int>::value, "");
static_assert(is_hashable<Hashable>::value, "");
static_assert(!is_hashable<NotHashable>::value, "");
} // namespace test_is_hashable
namespace test_is_function_type {
class MyClass {};
struct Functor {
void operator()() {}
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
auto lambda = [] () {};
// func() and func__ just exist to silence a compiler warning about lambda being unused
bool func() {
lambda();
return true;
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
bool func__ = func();
static_assert(is_function_type<void()>::value, "");
static_assert(is_function_type<int()>::value, "");
static_assert(is_function_type<MyClass()>::value, "");
static_assert(is_function_type<void(MyClass)>::value, "");
static_assert(is_function_type<void(int)>::value, "");
static_assert(is_function_type<void(void*)>::value, "");
static_assert(is_function_type<int()>::value, "");
static_assert(is_function_type<int(MyClass)>::value, "");
static_assert(is_function_type<int(const MyClass&)>::value, "");
static_assert(is_function_type<int(MyClass&&)>::value, "");
static_assert(is_function_type<MyClass&&()>::value, "");
static_assert(is_function_type<MyClass&&(MyClass&&)>::value, "");
static_assert(is_function_type<const MyClass&(int, float, MyClass)>::value, "");
static_assert(!is_function_type<void>::value, "");
static_assert(!is_function_type<int>::value, "");
static_assert(!is_function_type<MyClass>::value, "");
static_assert(!is_function_type<void*>::value, "");
static_assert(!is_function_type<const MyClass&>::value, "");
static_assert(!is_function_type<MyClass&&>::value, "");
static_assert(!is_function_type<void (*)()>::value, "function pointers aren't plain functions");
static_assert(!is_function_type<Functor>::value, "Functors aren't plain functions");
static_assert(!is_function_type<decltype(lambda)>::value, "Lambdas aren't plain functions");
class MyClass {};
struct Functor {
void operator()() {}
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
auto lambda = []() {};
// func() and func__ just exist to silence a compiler warning about lambda
// being unused
bool func() {
lambda();
return true;
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
bool func__ = func();
static_assert(is_function_type<void()>::value, "");
static_assert(is_function_type<int()>::value, "");
static_assert(is_function_type<MyClass()>::value, "");
static_assert(is_function_type<void(MyClass)>::value, "");
static_assert(is_function_type<void(int)>::value, "");
static_assert(is_function_type<void(void*)>::value, "");
static_assert(is_function_type<int()>::value, "");
static_assert(is_function_type<int(MyClass)>::value, "");
static_assert(is_function_type<int(const MyClass&)>::value, "");
static_assert(is_function_type<int(MyClass&&)>::value, "");
static_assert(is_function_type < MyClass && () > ::value, "");
static_assert(is_function_type < MyClass && (MyClass &&) > ::value, "");
static_assert(is_function_type<const MyClass&(int, float, MyClass)>::value, "");
static_assert(!is_function_type<void>::value, "");
static_assert(!is_function_type<int>::value, "");
static_assert(!is_function_type<MyClass>::value, "");
static_assert(!is_function_type<void*>::value, "");
static_assert(!is_function_type<const MyClass&>::value, "");
static_assert(!is_function_type<MyClass&&>::value, "");
static_assert(
!is_function_type<void (*)()>::value,
"function pointers aren't plain functions");
static_assert(
!is_function_type<Functor>::value,
"Functors aren't plain functions");
static_assert(
!is_function_type<decltype(lambda)>::value,
"Lambdas aren't plain functions");
} // namespace test_is_function_type
namespace test_is_instantiation_of {
class MyClass {};
template<class T> class Single {};
template<class T1, class T2> class Double {};
template<class... T> class Multiple {};
class MyClass {};
template <class T>
class Single {};
template <class T1, class T2>
class Double {};
template <class... T>
class Multiple {};
static_assert(is_instantiation_of<Single, Single<void>>::value, "");
static_assert(is_instantiation_of<Single, Single<MyClass>>::value, "");
static_assert(is_instantiation_of<Single, Single<int>>::value, "");
static_assert(is_instantiation_of<Single, Single<void*>>::value, "");
static_assert(is_instantiation_of<Single, Single<int*>>::value, "");
static_assert(is_instantiation_of<Single, Single<const MyClass&>>::value, "");
static_assert(is_instantiation_of<Single, Single<MyClass&&>>::value, "");
static_assert(is_instantiation_of<Double, Double<int, void>>::value, "");
static_assert(is_instantiation_of<Double, Double<const int&, MyClass*>>::value, "");
static_assert(is_instantiation_of<Multiple, Multiple<>>::value, "");
static_assert(is_instantiation_of<Multiple, Multiple<int>>::value, "");
static_assert(is_instantiation_of<Multiple, Multiple<MyClass&, int>>::value, "");
static_assert(is_instantiation_of<Multiple, Multiple<MyClass&, int, MyClass>>::value, "");
static_assert(is_instantiation_of<Multiple, Multiple<MyClass&, int, MyClass, void*>>::value, "");
static_assert(is_instantiation_of<Single, Single<void>>::value, "");
static_assert(is_instantiation_of<Single, Single<MyClass>>::value, "");
static_assert(is_instantiation_of<Single, Single<int>>::value, "");
static_assert(is_instantiation_of<Single, Single<void*>>::value, "");
static_assert(is_instantiation_of<Single, Single<int*>>::value, "");
static_assert(is_instantiation_of<Single, Single<const MyClass&>>::value, "");
static_assert(is_instantiation_of<Single, Single<MyClass&&>>::value, "");
static_assert(is_instantiation_of<Double, Double<int, void>>::value, "");
static_assert(
is_instantiation_of<Double, Double<const int&, MyClass*>>::value,
"");
static_assert(is_instantiation_of<Multiple, Multiple<>>::value, "");
static_assert(is_instantiation_of<Multiple, Multiple<int>>::value, "");
static_assert(
is_instantiation_of<Multiple, Multiple<MyClass&, int>>::value,
"");
static_assert(
is_instantiation_of<Multiple, Multiple<MyClass&, int, MyClass>>::value,
"");
static_assert(
is_instantiation_of<Multiple, Multiple<MyClass&, int, MyClass, void*>>::
value,
"");
static_assert(!is_instantiation_of<Single, Double<int, int>>::value, "");
static_assert(!is_instantiation_of<Single, Double<int, void>>::value, "");
static_assert(!is_instantiation_of<Single, Multiple<int>>::value, "");
static_assert(!is_instantiation_of<Double, Single<int>>::value, "");
static_assert(!is_instantiation_of<Double, Multiple<int, int>>::value, "");
static_assert(!is_instantiation_of<Double, Multiple<>>::value, "");
static_assert(!is_instantiation_of<Multiple, Double<int, int>>::value, "");
static_assert(!is_instantiation_of<Multiple, Single<int>>::value, "");
}
static_assert(!is_instantiation_of<Single, Double<int, int>>::value, "");
static_assert(!is_instantiation_of<Single, Double<int, void>>::value, "");
static_assert(!is_instantiation_of<Single, Multiple<int>>::value, "");
static_assert(!is_instantiation_of<Double, Single<int>>::value, "");
static_assert(!is_instantiation_of<Double, Multiple<int, int>>::value, "");
static_assert(!is_instantiation_of<Double, Multiple<>>::value, "");
static_assert(!is_instantiation_of<Multiple, Double<int, int>>::value, "");
static_assert(!is_instantiation_of<Multiple, Single<int>>::value, "");
} // namespace test_is_instantiation_of
namespace test_is_type_condition {
template<class> class NotATypeCondition {};
static_assert(is_type_condition<std::is_reference>::value, "");
static_assert(!is_type_condition<NotATypeCondition>::value, "");
}
}
template <class>
class NotATypeCondition {};
static_assert(is_type_condition<std::is_reference>::value, "");
static_assert(!is_type_condition<NotATypeCondition>::value, "");
} // namespace test_is_type_condition
} // namespace
namespace test_lambda_is_stateless {
template<class Result, class... Args>
struct MyStatelessFunctor final {
Result operator()(Args...) {}
};
template <class Result, class... Args>
struct MyStatelessFunctor final {
Result operator()(Args...) {}
};
template<class Result, class... Args>
struct MyStatelessConstFunctor final {
Result operator()(Args...) const {}
};
template <class Result, class... Args>
struct MyStatelessConstFunctor final {
Result operator()(Args...) const {}
};
void func() {
auto stateless_lambda = [] (int a) {return a;};
static_assert(is_stateless_lambda<decltype(stateless_lambda)>::value, "");
void func() {
auto stateless_lambda = [](int a) { return a; };
static_assert(is_stateless_lambda<decltype(stateless_lambda)>::value, "");
int b = 4;
auto stateful_lambda_1 = [&] (int a) {return a + b;};
static_assert(!is_stateless_lambda<decltype(stateful_lambda_1)>::value, "");
int b = 4;
auto stateful_lambda_1 = [&](int a) { return a + b; };
static_assert(!is_stateless_lambda<decltype(stateful_lambda_1)>::value, "");
auto stateful_lambda_2 = [=] (int a) {return a + b;};
static_assert(!is_stateless_lambda<decltype(stateful_lambda_2)>::value, "");
auto stateful_lambda_2 = [=](int a) { return a + b; };
static_assert(!is_stateless_lambda<decltype(stateful_lambda_2)>::value, "");
auto stateful_lambda_3 = [b] (int a) {return a + b;};
static_assert(!is_stateless_lambda<decltype(stateful_lambda_3)>::value, "");
auto stateful_lambda_3 = [b](int a) { return a + b; };
static_assert(!is_stateless_lambda<decltype(stateful_lambda_3)>::value, "");
static_assert(!is_stateless_lambda<MyStatelessFunctor<int, int>>::value, "even if stateless, a functor is not a lambda, so it's false");
static_assert(!is_stateless_lambda<MyStatelessFunctor<void, int>>::value, "even if stateless, a functor is not a lambda, so it's false");
static_assert(!is_stateless_lambda<MyStatelessConstFunctor<int, int>>::value, "even if stateless, a functor is not a lambda, so it's false");
static_assert(!is_stateless_lambda<MyStatelessConstFunctor<void, int>>::value, "even if stateless, a functor is not a lambda, so it's false");
static_assert(
!is_stateless_lambda<MyStatelessFunctor<int, int>>::value,
"even if stateless, a functor is not a lambda, so it's false");
static_assert(
!is_stateless_lambda<MyStatelessFunctor<void, int>>::value,
"even if stateless, a functor is not a lambda, so it's false");
static_assert(
!is_stateless_lambda<MyStatelessConstFunctor<int, int>>::value,
"even if stateless, a functor is not a lambda, so it's false");
static_assert(
!is_stateless_lambda<MyStatelessConstFunctor<void, int>>::value,
"even if stateless, a functor is not a lambda, so it's false");
class Dummy final {};
static_assert(!is_stateless_lambda<Dummy>::value, "A non-functor type is also not a lambda");
class Dummy final {};
static_assert(
!is_stateless_lambda<Dummy>::value,
"A non-functor type is also not a lambda");
static_assert(!is_stateless_lambda<int>::value, "An int is not a lambda");
static_assert(!is_stateless_lambda<int>::value, "An int is not a lambda");
using Func = int(int);
static_assert(!is_stateless_lambda<Func>::value, "A function is not a lambda");
static_assert(!is_stateless_lambda<Func*>::value, "A function pointer is not a lambda");
}
using Func = int(int);
static_assert(
!is_stateless_lambda<Func>::value, "A function is not a lambda");
static_assert(
!is_stateless_lambda<Func*>::value, "A function pointer is not a lambda");
}
} // namespace test_lambda_is_stateless
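// Illustrative aside (annotation, not part of this change): one property that
// makes a captureless lambda "stateless" is that it converts to a plain
// function pointer, while a capturing lambda does not. The namespace and
// probe() below are invented for illustration and are not necessarily how c10
// detects statelessness:
namespace stateless_lambda_aside {
inline void probe() {
  auto captureless = [](int a) { return a; };
  int b = 4;
  auto capturing = [b](int a) { return a + b; };
  static_assert(
      std::is_convertible<decltype(captureless), int (*)(int)>::value,
      "a captureless lambda converts to a function pointer");
  static_assert(
      !std::is_convertible<decltype(capturing), int (*)(int)>::value,
      "a capturing lambda does not");
  // Calls exist only to keep the lambdas (and b) used.
  captureless(0);
  capturing(0);
}
} // namespace stateless_lambda_aside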


@@ -11,70 +11,73 @@ using namespace ::testing;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(accumulate_test, vector_test) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
std::vector<int> ints = {1, 2, 3, 4, 5};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
std::vector<int> ints = {1, 2, 3, 4, 5};
EXPECT_EQ(c10::sum_integers(ints), 1+2+3+4+5);
EXPECT_EQ(c10::multiply_integers(ints), 1*2*3*4*5);
EXPECT_EQ(c10::sum_integers(ints), 1 + 2 + 3 + 4 + 5);
EXPECT_EQ(c10::multiply_integers(ints), 1 * 2 * 3 * 4 * 5);
EXPECT_EQ(c10::sum_integers(ints.begin(), ints.end()), 1+2+3+4+5);
EXPECT_EQ(c10::multiply_integers(ints.begin(), ints.end()), 1*2*3*4*5);
EXPECT_EQ(c10::sum_integers(ints.begin(), ints.end()), 1 + 2 + 3 + 4 + 5);
EXPECT_EQ(
c10::multiply_integers(ints.begin(), ints.end()), 1 * 2 * 3 * 4 * 5);
EXPECT_EQ(c10::sum_integers(ints.begin()+1, ints.end()-1), 2+3+4);
EXPECT_EQ(c10::multiply_integers(ints.begin()+1, ints.end()-1), 2*3*4);
EXPECT_EQ(c10::sum_integers(ints.begin() + 1, ints.end() - 1), 2 + 3 + 4);
EXPECT_EQ(
c10::multiply_integers(ints.begin() + 1, ints.end() - 1), 2 * 3 * 4);
EXPECT_EQ(c10::numelements_from_dim(2, ints), 3*4*5);
EXPECT_EQ(c10::numelements_to_dim(3, ints), 1*2*3);
EXPECT_EQ(c10::numelements_between_dim(2, 4, ints), 3*4);
EXPECT_EQ(c10::numelements_between_dim(4, 2, ints), 3*4);
EXPECT_EQ(c10::numelements_from_dim(2, ints), 3 * 4 * 5);
EXPECT_EQ(c10::numelements_to_dim(3, ints), 1 * 2 * 3);
EXPECT_EQ(c10::numelements_between_dim(2, 4, ints), 3 * 4);
EXPECT_EQ(c10::numelements_between_dim(4, 2, ints), 3 * 4);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(accumulate_test, list_test) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
std::list<int> ints = {1, 2, 3, 4, 5};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
std::list<int> ints = {1, 2, 3, 4, 5};
EXPECT_EQ(c10::sum_integers(ints), 1+2+3+4+5);
EXPECT_EQ(c10::multiply_integers(ints), 1*2*3*4*5);
EXPECT_EQ(c10::sum_integers(ints), 1 + 2 + 3 + 4 + 5);
EXPECT_EQ(c10::multiply_integers(ints), 1 * 2 * 3 * 4 * 5);
EXPECT_EQ(c10::sum_integers(ints.begin(), ints.end()), 1+2+3+4+5);
EXPECT_EQ(c10::multiply_integers(ints.begin(), ints.end()), 1*2*3*4*5);
EXPECT_EQ(c10::sum_integers(ints.begin(), ints.end()), 1 + 2 + 3 + 4 + 5);
EXPECT_EQ(
c10::multiply_integers(ints.begin(), ints.end()), 1 * 2 * 3 * 4 * 5);
EXPECT_EQ(c10::numelements_from_dim(2, ints), 3*4*5);
EXPECT_EQ(c10::numelements_to_dim(3, ints), 1*2*3);
EXPECT_EQ(c10::numelements_between_dim(2, 4, ints), 3*4);
EXPECT_EQ(c10::numelements_between_dim(4, 2, ints), 3*4);
EXPECT_EQ(c10::numelements_from_dim(2, ints), 3 * 4 * 5);
EXPECT_EQ(c10::numelements_to_dim(3, ints), 1 * 2 * 3);
EXPECT_EQ(c10::numelements_between_dim(2, 4, ints), 3 * 4);
EXPECT_EQ(c10::numelements_between_dim(4, 2, ints), 3 * 4);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(accumulate_test, base_cases) {
std::vector<int> ints = {};
std::vector<int> ints = {};
EXPECT_EQ(c10::sum_integers(ints), 0);
EXPECT_EQ(c10::multiply_integers(ints), 1);
EXPECT_EQ(c10::sum_integers(ints), 0);
EXPECT_EQ(c10::multiply_integers(ints), 1);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(accumulate_test, errors) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
std::vector<int> ints = {1,2,3,4,5};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
std::vector<int> ints = {1, 2, 3, 4, 5};
#ifndef NDEBUG
EXPECT_THROW(c10::numelements_from_dim(-1, ints), c10::Error);
#endif
#ifndef NDEBUG
EXPECT_THROW(c10::numelements_from_dim(-1, ints), c10::Error);
#endif
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW(c10::numelements_to_dim(-1, ints), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW(c10::numelements_between_dim(-1, 10, ints), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW(c10::numelements_between_dim(10, -1, ints), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW(c10::numelements_to_dim(-1, ints), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW(c10::numelements_between_dim(-1, 10, ints), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW(c10::numelements_between_dim(10, -1, ints), c10::Error);
EXPECT_EQ(c10::numelements_from_dim(10, ints),1);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW(c10::numelements_to_dim(10, ints), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW(c10::numelements_between_dim(10, 4, ints), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW(c10::numelements_between_dim(4, 10, ints), c10::Error);
EXPECT_EQ(c10::numelements_from_dim(10, ints), 1);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW(c10::numelements_to_dim(10, ints), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW(c10::numelements_between_dim(10, 4, ints), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW(c10::numelements_between_dim(4, 10, ints), c10::Error);
}
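// Annotation (not part of this change): with dims = {1, 2, 3, 4, 5} as above,
// numelements_from_dim(k, dims) multiplies dims[k..end), numelements_to_dim(k,
// dims) multiplies dims[0..k), and numelements_between_dim(a, b, dims)
// multiplies the half-open range between the two bounds in either order. A
// bound past the end is an empty product (1) for "from" but an error for "to"
// and "between", matching the expectations above.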


@@ -2,196 +2,194 @@
#include <gtest/gtest.h>
namespace {
float float_from_bytes(
uint32_t sign,
uint32_t exponent,
uint32_t fraction
) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint32_t bytes;
bytes = 0;
bytes |= sign;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
bytes <<= 8;
bytes |= exponent;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
bytes <<= 23;
bytes |= fraction;
float float_from_bytes(uint32_t sign, uint32_t exponent, uint32_t fraction) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint32_t bytes;
bytes = 0;
bytes |= sign;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
bytes <<= 8;
bytes |= exponent;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
bytes <<= 23;
bytes |= fraction;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
float res;
std::memcpy(&res, &bytes, sizeof(res));
return res;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
float res;
std::memcpy(&res, &bytes, sizeof(res));
return res;
}
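// Worked example (annotation, not part of this change): float_from_bytes packs
// sign, exponent, and fraction into an IEEE-754 single-precision bit pattern.
// float_from_bytes(0, 0x80, 0) produces bits 0x40000000, i.e.
// 2^(128 - 127) = 2.0f, and float_from_bytes(1, 0x7F, 0) produces 0xBF800000,
// i.e. -1.0f.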
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Conversion, FloatToBFloat16AndBack) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
float in[100];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
for (int i = 0; i < 100; ++i) {
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers)
in[i] = i + 1.25;
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Conversion, FloatToBFloat16AndBack) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
float in[100];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
for (int i = 0; i < 100; ++i) {
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers)
in[i] = i + 1.25;
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
c10::BFloat16 bfloats[100];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
float out[100];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
c10::BFloat16 bfloats[100];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
float out[100];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
for (int i = 0; i < 100; ++i) {
bfloats[i].x = c10::detail::bits_from_f32(in[i]);
out[i] = c10::detail::f32_from_bits(bfloats[i].x);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
for (int i = 0; i < 100; ++i) {
bfloats[i].x = c10::detail::bits_from_f32(in[i]);
out[i] = c10::detail::f32_from_bits(bfloats[i].x);
// The relative error should be less than 1/(2^7) since BFloat16
// has a 7-bit mantissa.
EXPECT_LE(fabs(out[i] - in[i]) / in[i], 1.0 / 128);
}
}
// The relative error should be less than 1/(2^7) since BFloat16
// has a 7-bit mantissa.
EXPECT_LE(fabs(out[i] - in[i]) / in[i], 1.0 / 128);
}
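// Annotation (not part of this change): the 1/128 bound follows from BFloat16
// keeping only the top 7 explicit mantissa bits of the float32 value, so
// truncation perturbs a normal value by less than one part in 2^7. For
// example, 3.14159f truncates to the BFloat16 value 3.140625, a relative
// error of roughly 3.1e-4, comfortably below 1/128.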
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Conversion, FloatToBFloat16RNEAndBack) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
float in[100];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
for (int i = 0; i < 100; ++i) {
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers)
in[i] = i + 1.25;
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Conversion, FloatToBFloat16RNEAndBack) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
float in[100];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
for (int i = 0; i < 100; ++i) {
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers)
in[i] = i + 1.25;
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
c10::BFloat16 bfloats[100];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
float out[100];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
c10::BFloat16 bfloats[100];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
float out[100];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
for (int i = 0; i < 100; ++i) {
bfloats[i].x = c10::detail::round_to_nearest_even(in[i]);
out[i] = c10::detail::f32_from_bits(bfloats[i].x);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
for (int i = 0; i < 100; ++i) {
bfloats[i].x = c10::detail::round_to_nearest_even(in[i]);
out[i] = c10::detail::f32_from_bits(bfloats[i].x);
// The relative error should be less than 1/(2^7) since BFloat16
// has a 7-bit mantissa.
EXPECT_LE(fabs(out[i] - in[i]) / in[i], 1.0 / 128);
}
// The relative error should be less than 1/(2^7) since BFloat16
// has a 7-bit mantissa.
EXPECT_LE(fabs(out[i] - in[i]) / in[i], 1.0 / 128);
}
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Conversion, NaN) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
float inNaN = float_from_bytes(0, 0xFF, 0x7FFFFF);
EXPECT_TRUE(std::isnan(inNaN));
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Conversion, NaN) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
float inNaN = float_from_bytes(0, 0xFF, 0x7FFFFF);
EXPECT_TRUE(std::isnan(inNaN));
c10::BFloat16 a = c10::BFloat16(inNaN);
float out = c10::detail::f32_from_bits(a.x);
c10::BFloat16 a = c10::BFloat16(inNaN);
float out = c10::detail::f32_from_bits(a.x);
EXPECT_TRUE(std::isnan(out));
}
EXPECT_TRUE(std::isnan(out));
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Conversion, Inf) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
float inInf = float_from_bytes(0, 0xFF, 0);
EXPECT_TRUE(std::isinf(inInf));
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Conversion, Inf) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
float inInf = float_from_bytes(0, 0xFF, 0);
EXPECT_TRUE(std::isinf(inInf));
c10::BFloat16 a = c10::BFloat16(inInf);
float out = c10::detail::f32_from_bits(a.x);
c10::BFloat16 a = c10::BFloat16(inInf);
float out = c10::detail::f32_from_bits(a.x);
EXPECT_TRUE(std::isinf(out));
}
EXPECT_TRUE(std::isinf(out));
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Conversion, SmallestDenormal) {
float in = std::numeric_limits<float>::denorm_min(); // The smallest non-zero subnormal number
c10::BFloat16 a = c10::BFloat16(in);
float out = c10::detail::f32_from_bits(a.x);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Conversion, SmallestDenormal) {
float in = std::numeric_limits<float>::denorm_min(); // The smallest non-zero
// subnormal number
c10::BFloat16 a = c10::BFloat16(in);
float out = c10::detail::f32_from_bits(a.x);
EXPECT_FLOAT_EQ(in, out);
}
EXPECT_FLOAT_EQ(in, out);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Math, Addition) {
// This test verifies that if only the first 7 bits of the float's mantissa
// are changed after addition, there should be no loss in precision.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Math, Addition) {
// This test verifies that if only the first 7 bits of the float's mantissa
// are changed after addition, there should be no loss in precision.
// input bits
// S | Exponent | Mantissa
// 0 | 10000000 | 10010000000000000000000 = 3.125
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
float input = float_from_bytes(0, 0, 0x40480000);
// input bits
// S | Exponent | Mantissa
// 0 | 10000000 | 10010000000000000000000 = 3.125
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
float input = float_from_bytes(0, 0, 0x40480000);
// expected bits
// S | Exponent | Mantissa
// 0 | 10000001 | 10010000000000000000000 = 6.25
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
float expected = float_from_bytes(0, 0, 0x40c80000);
// expected bits
// S | Exponent | Mantissa
// 0 | 10000001 | 10010000000000000000000 = 6.25
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
float expected = float_from_bytes(0, 0, 0x40c80000);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
c10::BFloat16 b;
b.x = c10::detail::bits_from_f32(input);
b = b + b;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
c10::BFloat16 b;
b.x = c10::detail::bits_from_f32(input);
b = b + b;
float res = c10::detail::f32_from_bits(b.x);
EXPECT_EQ(res, expected);
}
float res = c10::detail::f32_from_bits(b.x);
EXPECT_EQ(res, expected);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Math, Subtraction) {
// This test verifies that if only the first 7 bits of the float's mantissa
// are changed after subtraction, there should be no loss in precision.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Math, Subtraction) {
// This test verifies that if only the first 7 bits of the float's mantissa
// are changed after subtraction, there should be no loss in precision.
// input bits
// S | Exponent | Mantissa
// 0 | 10000001 | 11101000000000000000000 = 7.625
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
float input = float_from_bytes(0, 0, 0x40f40000);
// input bits
// S | Exponent | Mantissa
// 0 | 10000001 | 11101000000000000000000 = 7.625
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
float input = float_from_bytes(0, 0, 0x40f40000);
// expected bits
// S | Exponent | Mantissa
// 0 | 10000000 | 01010000000000000000000 = 2.625
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
float expected = float_from_bytes(0, 0, 0x40280000);
// expected bits
// S | Exponent | Mantissa
// 0 | 10000000 | 01010000000000000000000 = 2.625
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
float expected = float_from_bytes(0, 0, 0x40280000);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
c10::BFloat16 b;
b.x = c10::detail::bits_from_f32(input);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
b = b - 5;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
c10::BFloat16 b;
b.x = c10::detail::bits_from_f32(input);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
b = b - 5;
float res = c10::detail::f32_from_bits(b.x);
EXPECT_EQ(res, expected);
}
float res = c10::detail::f32_from_bits(b.x);
EXPECT_EQ(res, expected);
}
float BinaryToFloat(uint32_t bytes) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
float res;
std::memcpy(&res, &bytes, sizeof(res));
return res;
}
float BinaryToFloat(uint32_t bytes) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
float res;
std::memcpy(&res, &bytes, sizeof(res));
return res;
}
struct BFloat16TestParam {
uint32_t input;
uint16_t rne;
};
struct BFloat16TestParam {
uint32_t input;
uint16_t rne;
};
class BFloat16Test : public ::testing::Test,
public ::testing::WithParamInterface<BFloat16TestParam> {
};
class BFloat16Test : public ::testing::Test,
public ::testing::WithParamInterface<BFloat16TestParam> {};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST_P(BFloat16Test, BFloat16RNETest) {
float value = BinaryToFloat(GetParam().input);
uint16_t rounded = c10::detail::round_to_nearest_even(value);
EXPECT_EQ(GetParam().rne, rounded);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST_P(BFloat16Test, BFloat16RNETest) {
float value = BinaryToFloat(GetParam().input);
uint16_t rounded = c10::detail::round_to_nearest_even(value);
EXPECT_EQ(GetParam().rne, rounded);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
INSTANTIATE_TEST_CASE_P(
BFloat16Test_Instantiation, BFloat16Test,
::testing::Values(BFloat16TestParam{0x3F848000, 0x3F84},
BFloat16TestParam{0x3F848010, 0x3F85},
BFloat16TestParam{0x3F850000, 0x3F85},
BFloat16TestParam{0x3F858000, 0x3F86},
BFloat16TestParam{0x3FFF8000, 0x4000}));
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
INSTANTIATE_TEST_CASE_P(
BFloat16Test_Instantiation,
BFloat16Test,
::testing::Values(
BFloat16TestParam{0x3F848000, 0x3F84},
BFloat16TestParam{0x3F848010, 0x3F85},
BFloat16TestParam{0x3F850000, 0x3F85},
BFloat16TestParam{0x3F858000, 0x3F86},
BFloat16TestParam{0x3FFF8000, 0x4000}));
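// Annotation (not part of this change): the parameters above exercise
// round-to-nearest-even on the low 16 bits of the float32 pattern.
// 0x3F848000 sits exactly on the halfway point (low half == 0x8000) and the
// kept half 0x3F84 is already even, so it stays 0x3F84; 0x3F858000 is also a
// tie but 0x3F85 is odd, so it rounds up to 0x3F86; 0x3F848010 lies above the
// halfway point, so it rounds up to 0x3F85.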
} // namespace


@@ -1,4 +1,5 @@
// Warning: this file is included twice in aten/src/ATen/test/cuda_complex_math_test.cu
// Warning: this file is included twice in
// aten/src/ATen/test/cuda_complex_math_test.cu
#include <c10/util/complex.h>
#include <gtest/gtest.h>
@@ -16,152 +17,152 @@
C10_DEFINE_TEST(TestExponential, IPi) {
// exp(i*pi) = -1
{
c10::complex<float> e_i_pi = std::exp(c10::complex<float>(0, float(PI)));
C10_ASSERT_NEAR(e_i_pi.real(), -1, tol);
C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol);
c10::complex<float> e_i_pi = std::exp(c10::complex<float>(0, float(PI)));
C10_ASSERT_NEAR(e_i_pi.real(), -1, tol);
C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol);
}
{
c10::complex<float> e_i_pi = ::exp(c10::complex<float>(0, float(PI)));
C10_ASSERT_NEAR(e_i_pi.real(), -1, tol);
C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol);
c10::complex<float> e_i_pi = ::exp(c10::complex<float>(0, float(PI)));
C10_ASSERT_NEAR(e_i_pi.real(), -1, tol);
C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol);
}
{
c10::complex<double> e_i_pi = std::exp(c10::complex<double>(0, PI));
C10_ASSERT_NEAR(e_i_pi.real(), -1, tol);
C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol);
c10::complex<double> e_i_pi = std::exp(c10::complex<double>(0, PI));
C10_ASSERT_NEAR(e_i_pi.real(), -1, tol);
C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol);
}
{
c10::complex<double> e_i_pi = ::exp(c10::complex<double>(0, PI));
C10_ASSERT_NEAR(e_i_pi.real(), -1, tol);
C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol);
c10::complex<double> e_i_pi = ::exp(c10::complex<double>(0, PI));
C10_ASSERT_NEAR(e_i_pi.real(), -1, tol);
C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol);
}
}
C10_DEFINE_TEST(TestExponential, EulerFormula) {
// exp(ix) = cos(x) + i * sin(x)
{
c10::complex<float> x(0.1, 1.2);
c10::complex<float> e = std::exp(x);
float expected_real = std::exp(x.real()) * std::cos(x.imag());
float expected_imag = std::exp(x.real()) * std::sin(x.imag());
C10_ASSERT_NEAR(e.real(), expected_real, tol);
C10_ASSERT_NEAR(e.imag(), expected_imag, tol);
c10::complex<float> x(0.1, 1.2);
c10::complex<float> e = std::exp(x);
float expected_real = std::exp(x.real()) * std::cos(x.imag());
float expected_imag = std::exp(x.real()) * std::sin(x.imag());
C10_ASSERT_NEAR(e.real(), expected_real, tol);
C10_ASSERT_NEAR(e.imag(), expected_imag, tol);
}
{
c10::complex<float> x(0.1, 1.2);
c10::complex<float> e = ::exp(x);
float expected_real = ::exp(x.real()) * ::cos(x.imag());
float expected_imag = ::exp(x.real()) * ::sin(x.imag());
C10_ASSERT_NEAR(e.real(), expected_real, tol);
C10_ASSERT_NEAR(e.imag(), expected_imag, tol);
c10::complex<float> x(0.1, 1.2);
c10::complex<float> e = ::exp(x);
float expected_real = ::exp(x.real()) * ::cos(x.imag());
float expected_imag = ::exp(x.real()) * ::sin(x.imag());
C10_ASSERT_NEAR(e.real(), expected_real, tol);
C10_ASSERT_NEAR(e.imag(), expected_imag, tol);
}
{
c10::complex<double> x(0.1, 1.2);
c10::complex<double> e = std::exp(x);
float expected_real = std::exp(x.real()) * std::cos(x.imag());
float expected_imag = std::exp(x.real()) * std::sin(x.imag());
C10_ASSERT_NEAR(e.real(), expected_real, tol);
C10_ASSERT_NEAR(e.imag(), expected_imag, tol);
c10::complex<double> x(0.1, 1.2);
c10::complex<double> e = std::exp(x);
float expected_real = std::exp(x.real()) * std::cos(x.imag());
float expected_imag = std::exp(x.real()) * std::sin(x.imag());
C10_ASSERT_NEAR(e.real(), expected_real, tol);
C10_ASSERT_NEAR(e.imag(), expected_imag, tol);
}
{
c10::complex<double> x(0.1, 1.2);
c10::complex<double> e = ::exp(x);
float expected_real = ::exp(x.real()) * ::cos(x.imag());
float expected_imag = ::exp(x.real()) * ::sin(x.imag());
C10_ASSERT_NEAR(e.real(), expected_real, tol);
C10_ASSERT_NEAR(e.imag(), expected_imag, tol);
c10::complex<double> x(0.1, 1.2);
c10::complex<double> e = ::exp(x);
float expected_real = ::exp(x.real()) * ::cos(x.imag());
float expected_imag = ::exp(x.real()) * ::sin(x.imag());
C10_ASSERT_NEAR(e.real(), expected_real, tol);
C10_ASSERT_NEAR(e.imag(), expected_imag, tol);
}
}
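// Annotation (not part of this change): the blocks above check the general
// form of the identity, exp(a + i*b) = e^a * (cos(b) + i*sin(b)); the pure
// imaginary case named in the comment is the special case a = 0.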
C10_DEFINE_TEST(TestLog, Definition) {
// log(x) = log(r) + i*theta
{
c10::complex<float> x(1.2, 3.4);
c10::complex<float> l = std::log(x);
float expected_real = std::log(std::abs(x));
float expected_imag = std::arg(x);
C10_ASSERT_NEAR(l.real(), expected_real, tol);
C10_ASSERT_NEAR(l.imag(), expected_imag, tol);
c10::complex<float> x(1.2, 3.4);
c10::complex<float> l = std::log(x);
float expected_real = std::log(std::abs(x));
float expected_imag = std::arg(x);
C10_ASSERT_NEAR(l.real(), expected_real, tol);
C10_ASSERT_NEAR(l.imag(), expected_imag, tol);
}
{
c10::complex<float> x(1.2, 3.4);
c10::complex<float> l = ::log(x);
float expected_real = ::log(std::abs(x));
float expected_imag = std::arg(x);
C10_ASSERT_NEAR(l.real(), expected_real, tol);
C10_ASSERT_NEAR(l.imag(), expected_imag, tol);
c10::complex<float> x(1.2, 3.4);
c10::complex<float> l = ::log(x);
float expected_real = ::log(std::abs(x));
float expected_imag = std::arg(x);
C10_ASSERT_NEAR(l.real(), expected_real, tol);
C10_ASSERT_NEAR(l.imag(), expected_imag, tol);
}
{
c10::complex<double> x(1.2, 3.4);
c10::complex<double> l = std::log(x);
float expected_real = std::log(std::abs(x));
float expected_imag = std::arg(x);
C10_ASSERT_NEAR(l.real(), expected_real, tol);
C10_ASSERT_NEAR(l.imag(), expected_imag, tol);
c10::complex<double> x(1.2, 3.4);
c10::complex<double> l = std::log(x);
float expected_real = std::log(std::abs(x));
float expected_imag = std::arg(x);
C10_ASSERT_NEAR(l.real(), expected_real, tol);
C10_ASSERT_NEAR(l.imag(), expected_imag, tol);
}
{
c10::complex<double> x(1.2, 3.4);
c10::complex<double> l = ::log(x);
float expected_real = ::log(std::abs(x));
float expected_imag = std::arg(x);
C10_ASSERT_NEAR(l.real(), expected_real, tol);
C10_ASSERT_NEAR(l.imag(), expected_imag, tol);
c10::complex<double> x(1.2, 3.4);
c10::complex<double> l = ::log(x);
float expected_real = ::log(std::abs(x));
float expected_imag = std::arg(x);
C10_ASSERT_NEAR(l.real(), expected_real, tol);
C10_ASSERT_NEAR(l.imag(), expected_imag, tol);
}
}
C10_DEFINE_TEST(TestLog10, Rev) {
// log10(10^x) = x
{
c10::complex<float> x(0.1, 1.2);
c10::complex<float> l = std::log10(std::pow(float(10), x));
C10_ASSERT_NEAR(l.real(), float(0.1), tol);
C10_ASSERT_NEAR(l.imag(), float(1.2), tol);
c10::complex<float> x(0.1, 1.2);
c10::complex<float> l = std::log10(std::pow(float(10), x));
C10_ASSERT_NEAR(l.real(), float(0.1), tol);
C10_ASSERT_NEAR(l.imag(), float(1.2), tol);
}
{
c10::complex<float> x(0.1, 1.2);
c10::complex<float> l = ::log10(::pow(float(10), x));
C10_ASSERT_NEAR(l.real(), float(0.1), tol);
C10_ASSERT_NEAR(l.imag(), float(1.2), tol);
c10::complex<float> x(0.1, 1.2);
c10::complex<float> l = ::log10(::pow(float(10), x));
C10_ASSERT_NEAR(l.real(), float(0.1), tol);
C10_ASSERT_NEAR(l.imag(), float(1.2), tol);
}
{
c10::complex<double> x(0.1, 1.2);
c10::complex<double> l = std::log10(std::pow(double(10), x));
C10_ASSERT_NEAR(l.real(), double(0.1), tol);
C10_ASSERT_NEAR(l.imag(), double(1.2), tol);
c10::complex<double> x(0.1, 1.2);
c10::complex<double> l = std::log10(std::pow(double(10), x));
C10_ASSERT_NEAR(l.real(), double(0.1), tol);
C10_ASSERT_NEAR(l.imag(), double(1.2), tol);
}
{
c10::complex<double> x(0.1, 1.2);
c10::complex<double> l = ::log10(::pow(double(10), x));
C10_ASSERT_NEAR(l.real(), double(0.1), tol);
C10_ASSERT_NEAR(l.imag(), double(1.2), tol);
c10::complex<double> x(0.1, 1.2);
c10::complex<double> l = ::log10(::pow(double(10), x));
C10_ASSERT_NEAR(l.real(), double(0.1), tol);
C10_ASSERT_NEAR(l.imag(), double(1.2), tol);
}
}
C10_DEFINE_TEST(TestLog2, Rev) {
// log2(2^x) = x
{
c10::complex<float> x(0.1, 1.2);
c10::complex<float> l = std::log2(std::pow(float(2), x));
C10_ASSERT_NEAR(l.real(), float(0.1), tol);
C10_ASSERT_NEAR(l.imag(), float(1.2), tol);
c10::complex<float> x(0.1, 1.2);
c10::complex<float> l = std::log2(std::pow(float(2), x));
C10_ASSERT_NEAR(l.real(), float(0.1), tol);
C10_ASSERT_NEAR(l.imag(), float(1.2), tol);
}
{
c10::complex<float> x(0.1, 1.2);
c10::complex<float> l = ::log2(std::pow(float(2), x));
C10_ASSERT_NEAR(l.real(), float(0.1), tol);
C10_ASSERT_NEAR(l.imag(), float(1.2), tol);
c10::complex<float> x(0.1, 1.2);
c10::complex<float> l = ::log2(std::pow(float(2), x));
C10_ASSERT_NEAR(l.real(), float(0.1), tol);
C10_ASSERT_NEAR(l.imag(), float(1.2), tol);
}
{
c10::complex<double> x(0.1, 1.2);
c10::complex<double> l = std::log2(std::pow(double(2), x));
C10_ASSERT_NEAR(l.real(), double(0.1), tol);
C10_ASSERT_NEAR(l.imag(), double(1.2), tol);
c10::complex<double> x(0.1, 1.2);
c10::complex<double> l = std::log2(std::pow(double(2), x));
C10_ASSERT_NEAR(l.real(), double(0.1), tol);
C10_ASSERT_NEAR(l.imag(), double(1.2), tol);
}
{
c10::complex<double> x(0.1, 1.2);
c10::complex<double> l = ::log2(std::pow(double(2), x));
C10_ASSERT_NEAR(l.real(), double(0.1), tol);
C10_ASSERT_NEAR(l.imag(), double(1.2), tol);
c10::complex<double> x(0.1, 1.2);
c10::complex<double> l = ::log2(std::pow(double(2), x));
C10_ASSERT_NEAR(l.real(), double(0.1), tol);
C10_ASSERT_NEAR(l.imag(), double(1.2), tol);
}
}
@@ -170,64 +171,64 @@ C10_DEFINE_TEST(TestLog2, Rev) {
C10_DEFINE_TEST(TestPowSqrt, Equal) {
// x^0.5 = sqrt(x)
{
c10::complex<float> x(0.1, 1.2);
c10::complex<float> y = std::pow(x, float(0.5));
c10::complex<float> z = std::sqrt(x);
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
c10::complex<float> x(0.1, 1.2);
c10::complex<float> y = std::pow(x, float(0.5));
c10::complex<float> z = std::sqrt(x);
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
}
{
c10::complex<float> x(0.1, 1.2);
c10::complex<float> y = ::pow(x, float(0.5));
c10::complex<float> z = ::sqrt(x);
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
c10::complex<float> x(0.1, 1.2);
c10::complex<float> y = ::pow(x, float(0.5));
c10::complex<float> z = ::sqrt(x);
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
}
{
c10::complex<double> x(0.1, 1.2);
c10::complex<double> y = std::pow(x, double(0.5));
c10::complex<double> z = std::sqrt(x);
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
c10::complex<double> x(0.1, 1.2);
c10::complex<double> y = std::pow(x, double(0.5));
c10::complex<double> z = std::sqrt(x);
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
}
{
c10::complex<double> x(0.1, 1.2);
c10::complex<double> y = ::pow(x, double(0.5));
c10::complex<double> z = ::sqrt(x);
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
c10::complex<double> x(0.1, 1.2);
c10::complex<double> y = ::pow(x, double(0.5));
c10::complex<double> z = ::sqrt(x);
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
}
}
C10_DEFINE_TEST(TestPow, Square) {
// x^2 = x * x
{
c10::complex<float> x(0.1, 1.2);
c10::complex<float> y = std::pow(x, float(2));
c10::complex<float> z = x * x;
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
c10::complex<float> x(0.1, 1.2);
c10::complex<float> y = std::pow(x, float(2));
c10::complex<float> z = x * x;
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
}
{
c10::complex<float> x(0.1, 1.2);
c10::complex<float> y = ::pow(x, float(2));
c10::complex<float> z = x * x;
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
c10::complex<float> x(0.1, 1.2);
c10::complex<float> y = ::pow(x, float(2));
c10::complex<float> z = x * x;
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
}
{
c10::complex<double> x(0.1, 1.2);
c10::complex<double> y = std::pow(x, double(2));
c10::complex<double> z = x * x;
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
c10::complex<double> x(0.1, 1.2);
c10::complex<double> y = std::pow(x, double(2));
c10::complex<double> z = x * x;
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
}
{
c10::complex<double> x(0.1, 1.2);
c10::complex<double> y = ::pow(x, double(2));
c10::complex<double> z = x * x;
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
c10::complex<double> x(0.1, 1.2);
c10::complex<double> y = ::pow(x, double(2));
c10::complex<double> z = x * x;
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
}
}
@@ -237,132 +238,132 @@ C10_DEFINE_TEST(TestSinCosSinhCosh, Identity) {
// sin(x + i * y) = sin(x) * cosh(y) + i * cos(x) * sinh(y)
// cos(x + i * y) = cos(x) * cosh(y) - i * sin(x) * sinh(y)
{
c10::complex<float> x(0.1, 1.2);
c10::complex<float> y = std::sin(x);
float expected_real = std::sin(x.real()) * std::cosh(x.imag());
float expected_imag = std::cos(x.real()) * std::sinh(x.imag());
C10_ASSERT_NEAR(y.real(), expected_real, tol);
C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
c10::complex<float> x(0.1, 1.2);
c10::complex<float> y = std::sin(x);
float expected_real = std::sin(x.real()) * std::cosh(x.imag());
float expected_imag = std::cos(x.real()) * std::sinh(x.imag());
C10_ASSERT_NEAR(y.real(), expected_real, tol);
C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
}
{
c10::complex<float> x(0.1, 1.2);
c10::complex<float> y = ::sin(x);
float expected_real = ::sin(x.real()) * ::cosh(x.imag());
float expected_imag = ::cos(x.real()) * ::sinh(x.imag());
C10_ASSERT_NEAR(y.real(), expected_real, tol);
C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
c10::complex<float> x(0.1, 1.2);
c10::complex<float> y = ::sin(x);
float expected_real = ::sin(x.real()) * ::cosh(x.imag());
float expected_imag = ::cos(x.real()) * ::sinh(x.imag());
C10_ASSERT_NEAR(y.real(), expected_real, tol);
C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
}
{
c10::complex<float> x(0.1, 1.2);
c10::complex<float> y = std::cos(x);
float expected_real = std::cos(x.real()) * std::cosh(x.imag());
float expected_imag = - std::sin(x.real()) * std::sinh(x.imag());
C10_ASSERT_NEAR(y.real(), expected_real, tol);
C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
c10::complex<float> x(0.1, 1.2);
c10::complex<float> y = std::cos(x);
float expected_real = std::cos(x.real()) * std::cosh(x.imag());
float expected_imag = -std::sin(x.real()) * std::sinh(x.imag());
C10_ASSERT_NEAR(y.real(), expected_real, tol);
C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
}
{
c10::complex<float> x(0.1, 1.2);
c10::complex<float> y = ::cos(x);
float expected_real = ::cos(x.real()) * ::cosh(x.imag());
float expected_imag = - ::sin(x.real()) * ::sinh(x.imag());
C10_ASSERT_NEAR(y.real(), expected_real, tol);
C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
c10::complex<float> x(0.1, 1.2);
c10::complex<float> y = ::cos(x);
float expected_real = ::cos(x.real()) * ::cosh(x.imag());
float expected_imag = -::sin(x.real()) * ::sinh(x.imag());
C10_ASSERT_NEAR(y.real(), expected_real, tol);
C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
}
{
c10::complex<double> x(0.1, 1.2);
c10::complex<double> y = std::sin(x);
float expected_real = std::sin(x.real()) * std::cosh(x.imag());
float expected_imag = std::cos(x.real()) * std::sinh(x.imag());
C10_ASSERT_NEAR(y.real(), expected_real, tol);
C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
c10::complex<double> x(0.1, 1.2);
c10::complex<double> y = std::sin(x);
float expected_real = std::sin(x.real()) * std::cosh(x.imag());
float expected_imag = std::cos(x.real()) * std::sinh(x.imag());
C10_ASSERT_NEAR(y.real(), expected_real, tol);
C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
}
{
c10::complex<double> x(0.1, 1.2);
c10::complex<double> y = ::sin(x);
float expected_real = ::sin(x.real()) * ::cosh(x.imag());
float expected_imag = ::cos(x.real()) * ::sinh(x.imag());
C10_ASSERT_NEAR(y.real(), expected_real, tol);
C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
c10::complex<double> x(0.1, 1.2);
c10::complex<double> y = ::sin(x);
float expected_real = ::sin(x.real()) * ::cosh(x.imag());
float expected_imag = ::cos(x.real()) * ::sinh(x.imag());
C10_ASSERT_NEAR(y.real(), expected_real, tol);
C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
}
{
c10::complex<double> x(0.1, 1.2);
c10::complex<double> y = std::cos(x);
float expected_real = std::cos(x.real()) * std::cosh(x.imag());
float expected_imag = - std::sin(x.real()) * std::sinh(x.imag());
C10_ASSERT_NEAR(y.real(), expected_real, tol);
C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
c10::complex<double> x(0.1, 1.2);
c10::complex<double> y = std::cos(x);
float expected_real = std::cos(x.real()) * std::cosh(x.imag());
float expected_imag = -std::sin(x.real()) * std::sinh(x.imag());
C10_ASSERT_NEAR(y.real(), expected_real, tol);
C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
}
{
c10::complex<double> x(0.1, 1.2);
c10::complex<double> y = ::cos(x);
float expected_real = ::cos(x.real()) * ::cosh(x.imag());
float expected_imag = - ::sin(x.real()) * ::sinh(x.imag());
C10_ASSERT_NEAR(y.real(), expected_real, tol);
C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
c10::complex<double> x(0.1, 1.2);
c10::complex<double> y = ::cos(x);
float expected_real = ::cos(x.real()) * ::cosh(x.imag());
float expected_imag = -::sin(x.real()) * ::sinh(x.imag());
C10_ASSERT_NEAR(y.real(), expected_real, tol);
C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
}
}
C10_DEFINE_TEST(TestTan, Identity) {
// tan(x) = sin(x) / cos(x)
{
c10::complex<float> x(0.1, 1.2);
c10::complex<float> y = std::tan(x);
c10::complex<float> z = std::sin(x) / std::cos(x);
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
c10::complex<float> x(0.1, 1.2);
c10::complex<float> y = std::tan(x);
c10::complex<float> z = std::sin(x) / std::cos(x);
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
}
{
c10::complex<float> x(0.1, 1.2);
c10::complex<float> y = ::tan(x);
c10::complex<float> z = ::sin(x) / ::cos(x);
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
c10::complex<float> x(0.1, 1.2);
c10::complex<float> y = ::tan(x);
c10::complex<float> z = ::sin(x) / ::cos(x);
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
}
{
c10::complex<double> x(0.1, 1.2);
c10::complex<double> y = std::tan(x);
c10::complex<double> z = std::sin(x) / std::cos(x);
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
c10::complex<double> x(0.1, 1.2);
c10::complex<double> y = std::tan(x);
c10::complex<double> z = std::sin(x) / std::cos(x);
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
}
{
c10::complex<double> x(0.1, 1.2);
c10::complex<double> y = ::tan(x);
c10::complex<double> z = ::sin(x) / ::cos(x);
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
c10::complex<double> x(0.1, 1.2);
c10::complex<double> y = ::tan(x);
c10::complex<double> z = ::sin(x) / ::cos(x);
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
}
}
C10_DEFINE_TEST(TestTanh, Identity) {
// tanh(x) = sinh(x) / cosh(x)
{
c10::complex<float> x(0.1, 1.2);
c10::complex<float> y = std::tanh(x);
c10::complex<float> z = std::sinh(x) / std::cosh(x);
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
c10::complex<float> x(0.1, 1.2);
c10::complex<float> y = std::tanh(x);
c10::complex<float> z = std::sinh(x) / std::cosh(x);
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
}
{
c10::complex<float> x(0.1, 1.2);
c10::complex<float> y = ::tanh(x);
c10::complex<float> z = ::sinh(x) / ::cosh(x);
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
c10::complex<float> x(0.1, 1.2);
c10::complex<float> y = ::tanh(x);
c10::complex<float> z = ::sinh(x) / ::cosh(x);
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
}
{
c10::complex<double> x(0.1, 1.2);
c10::complex<double> y = std::tanh(x);
c10::complex<double> z = std::sinh(x) / std::cosh(x);
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
c10::complex<double> x(0.1, 1.2);
c10::complex<double> y = std::tanh(x);
c10::complex<double> z = std::sinh(x) / std::cosh(x);
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
}
{
c10::complex<double> x(0.1, 1.2);
c10::complex<double> y = ::tanh(x);
c10::complex<double> z = ::sinh(x) / ::cosh(x);
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
c10::complex<double> x(0.1, 1.2);
c10::complex<double> y = ::tanh(x);
c10::complex<double> z = ::sinh(x) / ::cosh(x);
C10_ASSERT_NEAR(y.real(), z.real(), tol);
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
}
}
@ -373,64 +374,64 @@ C10_DEFINE_TEST(TestRevTrigonometric, Rev) {
// acos(cos(x)) = x
// atan(tan(x)) = x
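// Note that these round trips only hold when x lies within the range of the
// principal branch of each inverse function; the test point (0.5, 0.6) used
// below appears to be chosen so that asin, acos and atan all map back to x.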
{
c10::complex<float> x(0.5, 0.6);
c10::complex<float> s = std::sin(x);
c10::complex<float> ss = std::asin(s);
c10::complex<float> c = std::cos(x);
c10::complex<float> cc = std::acos(c);
c10::complex<float> t = std::tan(x);
c10::complex<float> tt = std::atan(t);
C10_ASSERT_NEAR(x.real(), ss.real(), tol);
C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
C10_ASSERT_NEAR(x.real(), cc.real(), tol);
C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
C10_ASSERT_NEAR(x.real(), tt.real(), tol);
C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
c10::complex<float> x(0.5, 0.6);
c10::complex<float> s = std::sin(x);
c10::complex<float> ss = std::asin(s);
c10::complex<float> c = std::cos(x);
c10::complex<float> cc = std::acos(c);
c10::complex<float> t = std::tan(x);
c10::complex<float> tt = std::atan(t);
C10_ASSERT_NEAR(x.real(), ss.real(), tol);
C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
C10_ASSERT_NEAR(x.real(), cc.real(), tol);
C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
C10_ASSERT_NEAR(x.real(), tt.real(), tol);
C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
}
{
c10::complex<float> x(0.5, 0.6);
c10::complex<float> s = ::sin(x);
c10::complex<float> ss = ::asin(s);
c10::complex<float> c = ::cos(x);
c10::complex<float> cc = ::acos(c);
c10::complex<float> t = ::tan(x);
c10::complex<float> tt = ::atan(t);
C10_ASSERT_NEAR(x.real(), ss.real(), tol);
C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
C10_ASSERT_NEAR(x.real(), cc.real(), tol);
C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
C10_ASSERT_NEAR(x.real(), tt.real(), tol);
C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
c10::complex<float> x(0.5, 0.6);
c10::complex<float> s = ::sin(x);
c10::complex<float> ss = ::asin(s);
c10::complex<float> c = ::cos(x);
c10::complex<float> cc = ::acos(c);
c10::complex<float> t = ::tan(x);
c10::complex<float> tt = ::atan(t);
C10_ASSERT_NEAR(x.real(), ss.real(), tol);
C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
C10_ASSERT_NEAR(x.real(), cc.real(), tol);
C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
C10_ASSERT_NEAR(x.real(), tt.real(), tol);
C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
}
{
c10::complex<double> x(0.5, 0.6);
c10::complex<double> s = std::sin(x);
c10::complex<double> ss = std::asin(s);
c10::complex<double> c = std::cos(x);
c10::complex<double> cc = std::acos(c);
c10::complex<double> t = std::tan(x);
c10::complex<double> tt = std::atan(t);
C10_ASSERT_NEAR(x.real(), ss.real(), tol);
C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
C10_ASSERT_NEAR(x.real(), cc.real(), tol);
C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
C10_ASSERT_NEAR(x.real(), tt.real(), tol);
C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
c10::complex<double> x(0.5, 0.6);
c10::complex<double> s = std::sin(x);
c10::complex<double> ss = std::asin(s);
c10::complex<double> c = std::cos(x);
c10::complex<double> cc = std::acos(c);
c10::complex<double> t = std::tan(x);
c10::complex<double> tt = std::atan(t);
C10_ASSERT_NEAR(x.real(), ss.real(), tol);
C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
C10_ASSERT_NEAR(x.real(), cc.real(), tol);
C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
C10_ASSERT_NEAR(x.real(), tt.real(), tol);
C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
}
{
c10::complex<double> x(0.5, 0.6);
c10::complex<double> s = ::sin(x);
c10::complex<double> ss = ::asin(s);
c10::complex<double> c = ::cos(x);
c10::complex<double> cc = ::acos(c);
c10::complex<double> t = ::tan(x);
c10::complex<double> tt = ::atan(t);
C10_ASSERT_NEAR(x.real(), ss.real(), tol);
C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
C10_ASSERT_NEAR(x.real(), cc.real(), tol);
C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
C10_ASSERT_NEAR(x.real(), tt.real(), tol);
C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
c10::complex<double> x(0.5, 0.6);
c10::complex<double> s = ::sin(x);
c10::complex<double> ss = ::asin(s);
c10::complex<double> c = ::cos(x);
c10::complex<double> cc = ::acos(c);
c10::complex<double> t = ::tan(x);
c10::complex<double> tt = ::atan(t);
C10_ASSERT_NEAR(x.real(), ss.real(), tol);
C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
C10_ASSERT_NEAR(x.real(), cc.real(), tol);
C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
C10_ASSERT_NEAR(x.real(), tt.real(), tol);
C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
}
}
@ -441,63 +442,63 @@ C10_DEFINE_TEST(TestRevHyperbolic, Rev) {
// acosh(cosh(x)) = x
// atanh(tanh(x)) = x
{
c10::complex<float> x(0.5, 0.6);
c10::complex<float> s = std::sinh(x);
c10::complex<float> ss = std::asinh(s);
c10::complex<float> c = std::cosh(x);
c10::complex<float> cc = std::acosh(c);
c10::complex<float> t = std::tanh(x);
c10::complex<float> tt = std::atanh(t);
C10_ASSERT_NEAR(x.real(), ss.real(), tol);
C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
C10_ASSERT_NEAR(x.real(), cc.real(), tol);
C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
C10_ASSERT_NEAR(x.real(), tt.real(), tol);
C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
c10::complex<float> x(0.5, 0.6);
c10::complex<float> s = std::sinh(x);
c10::complex<float> ss = std::asinh(s);
c10::complex<float> c = std::cosh(x);
c10::complex<float> cc = std::acosh(c);
c10::complex<float> t = std::tanh(x);
c10::complex<float> tt = std::atanh(t);
C10_ASSERT_NEAR(x.real(), ss.real(), tol);
C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
C10_ASSERT_NEAR(x.real(), cc.real(), tol);
C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
C10_ASSERT_NEAR(x.real(), tt.real(), tol);
C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
}
{
c10::complex<float> x(0.5, 0.6);
c10::complex<float> s = ::sinh(x);
c10::complex<float> ss = ::asinh(s);
c10::complex<float> c = ::cosh(x);
c10::complex<float> cc = ::acosh(c);
c10::complex<float> t = ::tanh(x);
c10::complex<float> tt = ::atanh(t);
C10_ASSERT_NEAR(x.real(), ss.real(), tol);
C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
C10_ASSERT_NEAR(x.real(), cc.real(), tol);
C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
C10_ASSERT_NEAR(x.real(), tt.real(), tol);
C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
c10::complex<float> x(0.5, 0.6);
c10::complex<float> s = ::sinh(x);
c10::complex<float> ss = ::asinh(s);
c10::complex<float> c = ::cosh(x);
c10::complex<float> cc = ::acosh(c);
c10::complex<float> t = ::tanh(x);
c10::complex<float> tt = ::atanh(t);
C10_ASSERT_NEAR(x.real(), ss.real(), tol);
C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
C10_ASSERT_NEAR(x.real(), cc.real(), tol);
C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
C10_ASSERT_NEAR(x.real(), tt.real(), tol);
C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
}
{
c10::complex<double> x(0.5, 0.6);
c10::complex<double> s = std::sinh(x);
c10::complex<double> ss = std::asinh(s);
c10::complex<double> c = std::cosh(x);
c10::complex<double> cc = std::acosh(c);
c10::complex<double> t = std::tanh(x);
c10::complex<double> tt = std::atanh(t);
C10_ASSERT_NEAR(x.real(), ss.real(), tol);
C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
C10_ASSERT_NEAR(x.real(), cc.real(), tol);
C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
C10_ASSERT_NEAR(x.real(), tt.real(), tol);
C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
c10::complex<double> x(0.5, 0.6);
c10::complex<double> s = std::sinh(x);
c10::complex<double> ss = std::asinh(s);
c10::complex<double> c = std::cosh(x);
c10::complex<double> cc = std::acosh(c);
c10::complex<double> t = std::tanh(x);
c10::complex<double> tt = std::atanh(t);
C10_ASSERT_NEAR(x.real(), ss.real(), tol);
C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
C10_ASSERT_NEAR(x.real(), cc.real(), tol);
C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
C10_ASSERT_NEAR(x.real(), tt.real(), tol);
C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
}
{
c10::complex<double> x(0.5, 0.6);
c10::complex<double> s = ::sinh(x);
c10::complex<double> ss = ::asinh(s);
c10::complex<double> c = ::cosh(x);
c10::complex<double> cc = ::acosh(c);
c10::complex<double> t = ::tanh(x);
c10::complex<double> tt = ::atanh(t);
C10_ASSERT_NEAR(x.real(), ss.real(), tol);
C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
C10_ASSERT_NEAR(x.real(), cc.real(), tol);
C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
C10_ASSERT_NEAR(x.real(), tt.real(), tol);
C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
c10::complex<double> x(0.5, 0.6);
c10::complex<double> s = ::sinh(x);
c10::complex<double> ss = ::asinh(s);
c10::complex<double> c = ::cosh(x);
c10::complex<double> cc = ::acosh(c);
c10::complex<double> t = ::tanh(x);
c10::complex<double> tt = ::atanh(t);
C10_ASSERT_NEAR(x.real(), ss.real(), tol);
C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
C10_ASSERT_NEAR(x.real(), cc.real(), tol);
C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
C10_ASSERT_NEAR(x.real(), tt.real(), tol);
C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
}
}

View File

@ -1,10 +1,10 @@
#include <type_traits>
#include <tuple>
#include <sstream>
#include <c10/util/complex.h>
#include <c10/macros/Macros.h>
#include <c10/util/complex.h>
#include <c10/util/hash.h>
#include <gtest/gtest.h>
#include <sstream>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#if (defined(__CUDACC__) || defined(__HIPCC__))
@ -34,71 +34,72 @@ MAYBE_GLOBAL void test_pod() {
TEST(TestMemory, ReinterpretCast) {
{
std::complex<float> z(1, 2);
c10::complex<float> zz = *reinterpret_cast<c10::complex<float>*>(&z);
ASSERT_EQ(zz.real(), float(1));
ASSERT_EQ(zz.imag(), float(2));
std::complex<float> z(1, 2);
c10::complex<float> zz = *reinterpret_cast<c10::complex<float>*>(&z);
ASSERT_EQ(zz.real(), float(1));
ASSERT_EQ(zz.imag(), float(2));
}
{
c10::complex<float> z(3, 4);
std::complex<float> zz = *reinterpret_cast<std::complex<float>*>(&z);
ASSERT_EQ(zz.real(), float(3));
ASSERT_EQ(zz.imag(), float(4));
c10::complex<float> z(3, 4);
std::complex<float> zz = *reinterpret_cast<std::complex<float>*>(&z);
ASSERT_EQ(zz.real(), float(3));
ASSERT_EQ(zz.imag(), float(4));
}
{
std::complex<double> z(1, 2);
c10::complex<double> zz = *reinterpret_cast<c10::complex<double>*>(&z);
ASSERT_EQ(zz.real(), double(1));
ASSERT_EQ(zz.imag(), double(2));
std::complex<double> z(1, 2);
c10::complex<double> zz = *reinterpret_cast<c10::complex<double>*>(&z);
ASSERT_EQ(zz.real(), double(1));
ASSERT_EQ(zz.imag(), double(2));
}
{
c10::complex<double> z(3, 4);
std::complex<double> zz = *reinterpret_cast<std::complex<double>*>(&z);
ASSERT_EQ(zz.real(), double(3));
ASSERT_EQ(zz.imag(), double(4));
c10::complex<double> z(3, 4);
std::complex<double> zz = *reinterpret_cast<std::complex<double>*>(&z);
ASSERT_EQ(zz.real(), double(3));
ASSERT_EQ(zz.imag(), double(4));
}
}
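// These casts round-trip because c10::complex<T> is intended to be
// layout-compatible with std::complex<T> (two Ts, real part followed by
// imaginary part), so the reinterpret_cast simply reinterprets the same
// storage; the thrust::complex checks below rely on the same layout
// assumption.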
#if defined(__CUDACC__) || defined(__HIPCC__)
TEST(TestMemory, ThrustReinterpretCast) {
{
thrust::complex<float> z(1, 2);
c10::complex<float> zz = *reinterpret_cast<c10::complex<float>*>(&z);
ASSERT_EQ(zz.real(), float(1));
ASSERT_EQ(zz.imag(), float(2));
thrust::complex<float> z(1, 2);
c10::complex<float> zz = *reinterpret_cast<c10::complex<float>*>(&z);
ASSERT_EQ(zz.real(), float(1));
ASSERT_EQ(zz.imag(), float(2));
}
{
c10::complex<float> z(3, 4);
thrust::complex<float> zz = *reinterpret_cast<thrust::complex<float>*>(&z);
ASSERT_EQ(zz.real(), float(3));
ASSERT_EQ(zz.imag(), float(4));
c10::complex<float> z(3, 4);
thrust::complex<float> zz = *reinterpret_cast<thrust::complex<float>*>(&z);
ASSERT_EQ(zz.real(), float(3));
ASSERT_EQ(zz.imag(), float(4));
}
{
thrust::complex<double> z(1, 2);
c10::complex<double> zz = *reinterpret_cast<c10::complex<double>*>(&z);
ASSERT_EQ(zz.real(), double(1));
ASSERT_EQ(zz.imag(), double(2));
thrust::complex<double> z(1, 2);
c10::complex<double> zz = *reinterpret_cast<c10::complex<double>*>(&z);
ASSERT_EQ(zz.real(), double(1));
ASSERT_EQ(zz.imag(), double(2));
}
{
c10::complex<double> z(3, 4);
thrust::complex<double> zz = *reinterpret_cast<thrust::complex<double>*>(&z);
ASSERT_EQ(zz.real(), double(3));
ASSERT_EQ(zz.imag(), double(4));
c10::complex<double> z(3, 4);
thrust::complex<double> zz =
*reinterpret_cast<thrust::complex<double>*>(&z);
ASSERT_EQ(zz.real(), double(3));
ASSERT_EQ(zz.imag(), double(4));
}
}
#endif
} // memory
} // namespace memory
namespace constructors {
template<typename scalar_t>
template <typename scalar_t>
C10_HOST_DEVICE void test_construct_from_scalar() {
constexpr scalar_t num1 = scalar_t(1.23);
constexpr scalar_t num2 = scalar_t(4.56);
@ -111,29 +112,48 @@ C10_HOST_DEVICE void test_construct_from_scalar() {
static_assert(c10::complex<scalar_t>().imag() == zero, "");
}
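// Since these checks are static_asserts, they double as a compile-time
// guarantee that the c10::complex constructors exercised here are usable in
// constant expressions.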
template<typename scalar_t, typename other_t>
template <typename scalar_t, typename other_t>
C10_HOST_DEVICE void test_construct_from_other() {
constexpr other_t num1 = other_t(1.23);
constexpr other_t num2 = other_t(4.56);
constexpr scalar_t num3 = scalar_t(num1);
constexpr scalar_t num4 = scalar_t(num2);
static_assert(c10::complex<scalar_t>(c10::complex<other_t>(num1, num2)).real() == num3, "");
static_assert(c10::complex<scalar_t>(c10::complex<other_t>(num1, num2)).imag() == num4, "");
static_assert(
c10::complex<scalar_t>(c10::complex<other_t>(num1, num2)).real() == num3,
"");
static_assert(
c10::complex<scalar_t>(c10::complex<other_t>(num1, num2)).imag() == num4,
"");
}
MAYBE_GLOBAL void test_convert_constructors() {
test_construct_from_scalar<float>();
test_construct_from_scalar<double>();
static_assert(std::is_convertible<c10::complex<float>, c10::complex<float>>::value, "");
static_assert(!std::is_convertible<c10::complex<double>, c10::complex<float>>::value, "");
static_assert(std::is_convertible<c10::complex<float>, c10::complex<double>>::value, "");
static_assert(std::is_convertible<c10::complex<double>, c10::complex<double>>::value, "");
static_assert(
std::is_convertible<c10::complex<float>, c10::complex<float>>::value, "");
static_assert(
!std::is_convertible<c10::complex<double>, c10::complex<float>>::value,
"");
static_assert(
std::is_convertible<c10::complex<float>, c10::complex<double>>::value,
"");
static_assert(
std::is_convertible<c10::complex<double>, c10::complex<double>>::value,
"");
static_assert(std::is_constructible<c10::complex<float>, c10::complex<float>>::value, "");
static_assert(std::is_constructible<c10::complex<double>, c10::complex<float>>::value, "");
static_assert(std::is_constructible<c10::complex<float>, c10::complex<double>>::value, "");
static_assert(std::is_constructible<c10::complex<double>, c10::complex<double>>::value, "");
static_assert(
std::is_constructible<c10::complex<float>, c10::complex<float>>::value,
"");
static_assert(
std::is_constructible<c10::complex<double>, c10::complex<float>>::value,
"");
static_assert(
std::is_constructible<c10::complex<float>, c10::complex<double>>::value,
"");
static_assert(
std::is_constructible<c10::complex<double>, c10::complex<double>>::value,
"");
test_construct_from_other<float, float>();
test_construct_from_other<float, double>();
@ -141,12 +161,16 @@ MAYBE_GLOBAL void test_convert_constructors() {
test_construct_from_other<double, double>();
}
template<typename scalar_t>
template <typename scalar_t>
C10_HOST_DEVICE void test_construct_from_std() {
constexpr scalar_t num1 = scalar_t(1.23);
constexpr scalar_t num2 = scalar_t(4.56);
static_assert(c10::complex<scalar_t>(std::complex<scalar_t>(num1, num2)).real() == num1, "");
static_assert(c10::complex<scalar_t>(std::complex<scalar_t>(num1, num2)).imag() == num2, "");
static_assert(
c10::complex<scalar_t>(std::complex<scalar_t>(num1, num2)).real() == num1,
"");
static_assert(
c10::complex<scalar_t>(std::complex<scalar_t>(num1, num2)).imag() == num2,
"");
}
MAYBE_GLOBAL void test_std_conversion() {
@ -155,12 +179,16 @@ MAYBE_GLOBAL void test_std_conversion() {
}
#if defined(__CUDACC__) || defined(__HIPCC__)
template<typename scalar_t>
template <typename scalar_t>
void test_construct_from_thrust() {
constexpr scalar_t num1 = scalar_t(1.23);
constexpr scalar_t num2 = scalar_t(4.56);
ASSERT_EQ(c10::complex<scalar_t>(thrust::complex<scalar_t>(num1, num2)).real(), num1);
ASSERT_EQ(c10::complex<scalar_t>(thrust::complex<scalar_t>(num1, num2)).imag(), num2);
ASSERT_EQ(
c10::complex<scalar_t>(thrust::complex<scalar_t>(num1, num2)).real(),
num1);
ASSERT_EQ(
c10::complex<scalar_t>(thrust::complex<scalar_t>(num1, num2)).imag(),
num2);
}
TEST(TestConstructors, FromThrust) {
@ -170,7 +198,11 @@ TEST(TestConstructors, FromThrust) {
#endif
TEST(TestConstructors, UnorderedMap) {
std::unordered_map<c10::complex<double>, c10::complex<double>, c10::hash<c10::complex<double>>> m;
std::unordered_map<
c10::complex<double>,
c10::complex<double>,
c10::hash<c10::complex<double>>>
m;
auto key1 = c10::complex<double>(2.5, 3);
auto key2 = c10::complex<double>(2, 0);
auto val1 = c10::complex<double>(2, -3.2);
@ -181,11 +213,11 @@ TEST(TestConstructors, UnorderedMap) {
ASSERT_EQ(m[key2], val2);
}
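// c10::complex<double> works as an unordered_map key here only because
// c10::hash<c10::complex<double>> is passed explicitly as the hasher;
// presumably there is no std::hash specialization for c10::complex, so the
// default hasher would not compile.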
} // constructors
} // namespace constructors
namespace assignment {
template<typename scalar_t>
template <typename scalar_t>
constexpr c10::complex<scalar_t> one() {
c10::complex<scalar_t> result(3, 4);
result = scalar_t(1);
@ -232,7 +264,8 @@ MAYBE_GLOBAL void test_assign_std() {
}
#if defined(__CUDACC__) || defined(__HIPCC__)
C10_HOST_DEVICE std::tuple<c10::complex<double>, c10::complex<float>> one_two_thrust() {
C10_HOST_DEVICE std::tuple<c10::complex<double>, c10::complex<float>>
one_two_thrust() {
thrust::complex<float> src(1, 2);
c10::complex<double> ret0;
c10::complex<float> ret1;
@ -258,7 +291,8 @@ MAYBE_GLOBAL void test_complex_literals() {
static_assert(std::is_same<decltype(0.5_if), c10::complex<float>>::value, "");
static_assert((0.5_if).real() == float(), "");
static_assert((0.5_if).imag() == float(0.5), "");
static_assert(std::is_same<decltype(0.5_id), c10::complex<double>>::value, "");
static_assert(
std::is_same<decltype(0.5_id), c10::complex<double>>::value, "");
static_assert((0.5_id).real() == float(), "");
static_assert((0.5_id).imag() == float(0.5), "");
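// _if and _id are user-defined literal suffixes from the
// c10::complex_literals namespace; they produce purely imaginary
// c10::complex<float> and c10::complex<double> values, which is why real()
// is zero and imag() carries the literal's magnitude in the checks above.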
@ -274,14 +308,14 @@ MAYBE_GLOBAL void test_complex_literals() {
namespace real_imag {
template<typename scalar_t>
template <typename scalar_t>
constexpr c10::complex<scalar_t> zero_one() {
c10::complex<scalar_t> result;
result.imag(scalar_t(1));
return result;
}
template<typename scalar_t>
template <typename scalar_t>
constexpr c10::complex<scalar_t> one_zero() {
c10::complex<scalar_t> result;
result.real(scalar_t(1));
@ -304,35 +338,35 @@ MAYBE_GLOBAL void test_real_imag_modify() {
namespace arithmetic_assign {
template<typename scalar_t>
template <typename scalar_t>
constexpr c10::complex<scalar_t> p(scalar_t value) {
c10::complex<scalar_t> result(scalar_t(2), scalar_t(2));
result += value;
return result;
}
template<typename scalar_t>
template <typename scalar_t>
constexpr c10::complex<scalar_t> m(scalar_t value) {
c10::complex<scalar_t> result(scalar_t(2), scalar_t(2));
result -= value;
return result;
}
template<typename scalar_t>
template <typename scalar_t>
constexpr c10::complex<scalar_t> t(scalar_t value) {
c10::complex<scalar_t> result(scalar_t(2), scalar_t(2));
result *= value;
return result;
}
template<typename scalar_t>
template <typename scalar_t>
constexpr c10::complex<scalar_t> d(scalar_t value) {
c10::complex<scalar_t> result(scalar_t(2), scalar_t(2));
result /= value;
return result;
}
template<typename scalar_t>
template <typename scalar_t>
C10_HOST_DEVICE void test_arithmetic_assign_scalar() {
constexpr c10::complex<scalar_t> x = p(scalar_t(1));
static_assert(x.real() == scalar_t(3), "");
@ -348,35 +382,47 @@ C10_HOST_DEVICE void test_arithmetic_assign_scalar() {
static_assert(t.imag() == scalar_t(1), "");
}
template<typename scalar_t, typename rhs_t>
constexpr c10::complex<scalar_t> p(scalar_t real, scalar_t imag, c10::complex<rhs_t> rhs) {
template <typename scalar_t, typename rhs_t>
constexpr c10::complex<scalar_t> p(
scalar_t real,
scalar_t imag,
c10::complex<rhs_t> rhs) {
c10::complex<scalar_t> result(real, imag);
result += rhs;
return result;
}
template<typename scalar_t, typename rhs_t>
constexpr c10::complex<scalar_t> m(scalar_t real, scalar_t imag, c10::complex<rhs_t> rhs) {
template <typename scalar_t, typename rhs_t>
constexpr c10::complex<scalar_t> m(
scalar_t real,
scalar_t imag,
c10::complex<rhs_t> rhs) {
c10::complex<scalar_t> result(real, imag);
result -= rhs;
return result;
}
template<typename scalar_t, typename rhs_t>
constexpr c10::complex<scalar_t> t(scalar_t real, scalar_t imag, c10::complex<rhs_t> rhs) {
template <typename scalar_t, typename rhs_t>
constexpr c10::complex<scalar_t> t(
scalar_t real,
scalar_t imag,
c10::complex<rhs_t> rhs) {
c10::complex<scalar_t> result(real, imag);
result *= rhs;
return result;
}
template<typename scalar_t, typename rhs_t>
constexpr c10::complex<scalar_t> d(scalar_t real, scalar_t imag, c10::complex<rhs_t> rhs) {
template <typename scalar_t, typename rhs_t>
constexpr c10::complex<scalar_t> d(
scalar_t real,
scalar_t imag,
c10::complex<rhs_t> rhs) {
c10::complex<scalar_t> result(real, imag);
result /= rhs;
return result;
}
template<typename scalar_t>
template <typename scalar_t>
C10_HOST_DEVICE void test_arithmetic_assign_complex() {
using namespace c10::complex_literals;
constexpr c10::complex<scalar_t> x2 = p(scalar_t(2), scalar_t(2), 1.0_if);
@ -429,26 +475,64 @@ MAYBE_GLOBAL void test_arithmetic_assign() {
namespace arithmetic {
template<typename scalar_t>
template <typename scalar_t>
C10_HOST_DEVICE void test_arithmetic_() {
static_assert(c10::complex<scalar_t>(1, 2) == +c10::complex<scalar_t>(1, 2), "");
static_assert(c10::complex<scalar_t>(-1, -2) == -c10::complex<scalar_t>(1, 2), "");
static_assert(
c10::complex<scalar_t>(1, 2) == +c10::complex<scalar_t>(1, 2), "");
static_assert(
c10::complex<scalar_t>(-1, -2) == -c10::complex<scalar_t>(1, 2), "");
static_assert(c10::complex<scalar_t>(1, 2) + c10::complex<scalar_t>(3, 4) == c10::complex<scalar_t>(4, 6), "");
static_assert(c10::complex<scalar_t>(1, 2) + scalar_t(3) == c10::complex<scalar_t>(4, 2), "");
static_assert(scalar_t(3) + c10::complex<scalar_t>(1, 2) == c10::complex<scalar_t>(4, 2), "");
static_assert(
c10::complex<scalar_t>(1, 2) + c10::complex<scalar_t>(3, 4) ==
c10::complex<scalar_t>(4, 6),
"");
static_assert(
c10::complex<scalar_t>(1, 2) + scalar_t(3) ==
c10::complex<scalar_t>(4, 2),
"");
static_assert(
scalar_t(3) + c10::complex<scalar_t>(1, 2) ==
c10::complex<scalar_t>(4, 2),
"");
static_assert(c10::complex<scalar_t>(1, 2) - c10::complex<scalar_t>(3, 4) == c10::complex<scalar_t>(-2, -2), "");
static_assert(c10::complex<scalar_t>(1, 2) - scalar_t(3) == c10::complex<scalar_t>(-2, 2), "");
static_assert(scalar_t(3) - c10::complex<scalar_t>(1, 2) == c10::complex<scalar_t>(2, -2), "");
static_assert(
c10::complex<scalar_t>(1, 2) - c10::complex<scalar_t>(3, 4) ==
c10::complex<scalar_t>(-2, -2),
"");
static_assert(
c10::complex<scalar_t>(1, 2) - scalar_t(3) ==
c10::complex<scalar_t>(-2, 2),
"");
static_assert(
scalar_t(3) - c10::complex<scalar_t>(1, 2) ==
c10::complex<scalar_t>(2, -2),
"");
static_assert(c10::complex<scalar_t>(1, 2) * c10::complex<scalar_t>(3, 4) == c10::complex<scalar_t>(-5, 10), "");
static_assert(c10::complex<scalar_t>(1, 2) * scalar_t(3) == c10::complex<scalar_t>(3, 6), "");
static_assert(scalar_t(3) * c10::complex<scalar_t>(1, 2) == c10::complex<scalar_t>(3, 6), "");
static_assert(
c10::complex<scalar_t>(1, 2) * c10::complex<scalar_t>(3, 4) ==
c10::complex<scalar_t>(-5, 10),
"");
static_assert(
c10::complex<scalar_t>(1, 2) * scalar_t(3) ==
c10::complex<scalar_t>(3, 6),
"");
static_assert(
scalar_t(3) * c10::complex<scalar_t>(1, 2) ==
c10::complex<scalar_t>(3, 6),
"");
static_assert(c10::complex<scalar_t>(-5, 10) / c10::complex<scalar_t>(3, 4) == c10::complex<scalar_t>(1, 2), "");
static_assert(c10::complex<scalar_t>(5, 10) / scalar_t(5) == c10::complex<scalar_t>(1, 2), "");
static_assert(scalar_t(25) / c10::complex<scalar_t>(3, 4) == c10::complex<scalar_t>(3, -4), "");
static_assert(
c10::complex<scalar_t>(-5, 10) / c10::complex<scalar_t>(3, 4) ==
c10::complex<scalar_t>(1, 2),
"");
static_assert(
c10::complex<scalar_t>(5, 10) / scalar_t(5) ==
c10::complex<scalar_t>(1, 2),
"");
static_assert(
scalar_t(25) / c10::complex<scalar_t>(3, 4) ==
c10::complex<scalar_t>(3, -4),
"");
}
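// Spot-checking the division expectations: (1 + 2i) * (3 + 4i) = -5 + 10i,
// so (-5 + 10i) / (3 + 4i) = 1 + 2i, and 25 / (3 + 4i) = 25 * (3 - 4i) / 25
// = 3 - 4i, matching the constants asserted above.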
MAYBE_GLOBAL void test_arithmetic() {
@ -456,7 +540,7 @@ MAYBE_GLOBAL void test_arithmetic() {
test_arithmetic_<double>();
}
template<typename T, typename int_t>
template <typename T, typename int_t>
void test_binary_ops_for_int_type_(T real, T img, int_t num) {
c10::complex<T> c(real, img);
ASSERT_EQ(c + num, c10::complex<T>(real + num, img));
@ -466,10 +550,12 @@ void test_binary_ops_for_int_type_(T real, T img, int_t num) {
ASSERT_EQ(c * num, c10::complex<T>(real * num, img * num));
ASSERT_EQ(num * c, c10::complex<T>(num * real, num * img));
ASSERT_EQ(c / num, c10::complex<T>(real / num, img / num));
ASSERT_EQ(num / c, c10::complex<T>(num * real / std::norm(c), -num * img / std::norm(c)));
ASSERT_EQ(
num / c,
c10::complex<T>(num * real / std::norm(c), -num * img / std::norm(c)));
}
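// The expected value for num / c comes from multiplying the numerator and
// denominator by conj(c): num / c = num * conj(c) / std::norm(c), whose real
// part is num * real / norm(c) and whose imaginary part is
// -num * img / norm(c), exactly as asserted above.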
template<typename T>
template <typename T>
void test_binary_ops_for_all_int_types_(T real, T img, int8_t i) {
test_binary_ops_for_int_type_<T, int8_t>(real, img, i);
test_binary_ops_for_int_type_<T, int16_t>(real, img, i);
@ -486,12 +572,14 @@ TEST(TestArithmeticIntScalar, All) {
namespace equality {
template<typename scalar_t>
template <typename scalar_t>
C10_HOST_DEVICE void test_equality_() {
static_assert(c10::complex<scalar_t>(1, 2) == c10::complex<scalar_t>(1, 2), "");
static_assert(
c10::complex<scalar_t>(1, 2) == c10::complex<scalar_t>(1, 2), "");
static_assert(c10::complex<scalar_t>(1, 0) == scalar_t(1), "");
static_assert(scalar_t(1) == c10::complex<scalar_t>(1, 0), "");
static_assert(c10::complex<scalar_t>(1, 2) != c10::complex<scalar_t>(3, 4), "");
static_assert(
c10::complex<scalar_t>(1, 2) != c10::complex<scalar_t>(3, 4), "");
static_assert(c10::complex<scalar_t>(1, 2) != scalar_t(1), "");
static_assert(scalar_t(1) != c10::complex<scalar_t>(1, 2), "");
}
@ -505,7 +593,7 @@ MAYBE_GLOBAL void test_equality() {
namespace io {
template<typename scalar_t>
template <typename scalar_t>
void test_io_() {
std::stringstream ss;
c10::complex<scalar_t> a(1, 2);
@ -525,14 +613,16 @@ TEST(TestIO, All) {
namespace test_std {
template<typename scalar_t>
template <typename scalar_t>
C10_HOST_DEVICE void test_callable_() {
static_assert(std::real(c10::complex<scalar_t>(1, 2)) == scalar_t(1), "");
static_assert(std::imag(c10::complex<scalar_t>(1, 2)) == scalar_t(2), "");
std::abs(c10::complex<scalar_t>(1, 2));
std::arg(c10::complex<scalar_t>(1, 2));
static_assert(std::norm(c10::complex<scalar_t>(3, 4)) == scalar_t(25), "");
static_assert(std::conj(c10::complex<scalar_t>(3, 4)) == c10::complex<scalar_t>(3, -4), "");
static_assert(
std::conj(c10::complex<scalar_t>(3, 4)) == c10::complex<scalar_t>(3, -4),
"");
c10::polar(float(1), float(PI / 2));
c10::polar(double(1), double(PI / 2));
}
@ -542,11 +632,15 @@ MAYBE_GLOBAL void test_callable() {
test_callable_<double>();
}
template<typename scalar_t>
template <typename scalar_t>
void test_values_() {
ASSERT_EQ(std::abs(c10::complex<scalar_t>(3, 4)), scalar_t(5));
ASSERT_LT(std::abs(std::arg(c10::complex<scalar_t>(0, 1)) - PI / 2), 1e-6);
ASSERT_LT(std::abs(c10::polar(scalar_t(1), scalar_t(PI / 2)) - c10::complex<scalar_t>(0, 1)), 1e-6);
ASSERT_LT(
std::abs(
c10::polar(scalar_t(1), scalar_t(PI / 2)) -
c10::complex<scalar_t>(0, 1)),
1e-6);
}
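// The expected values are the standard textbook ones: |3 + 4i| = 5,
// arg(i) = pi / 2, and polar(1, pi / 2) = i.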
TEST(TestStd, BasicFunctions) {
@ -554,8 +648,11 @@ TEST(TestStd, BasicFunctions) {
test_values_<double>();
// CSQRT edge cases: checks for overflows which are likely to occur
// if square root is computed using polar form
ASSERT_LT(std::abs(std::sqrt(c10::complex<float>(-1e20, -4988429.2)).real()), 3e-4);
ASSERT_LT(std::abs(std::sqrt(c10::complex<double>(-1e60, -4988429.2)).real()), 3e-4);
ASSERT_LT(
std::abs(std::sqrt(c10::complex<float>(-1e20, -4988429.2)).real()), 3e-4);
ASSERT_LT(
std::abs(std::sqrt(c10::complex<double>(-1e60, -4988429.2)).real()),
3e-4);
}
} // namespace test_std

File diff suppressed because it is too large

View File

@ -9,7 +9,7 @@ bool throw_func() {
throw std::runtime_error("I'm throwing...");
}
template<class Functor>
template <class Functor>
inline void expectThrowsEq(Functor&& functor, const char* expectedMessage) {
try {
std::forward<Functor>(functor)();
@ -18,7 +18,7 @@ inline void expectThrowsEq(Functor&& functor, const char* expectedMessage) {
return;
}
ADD_FAILURE() << "Expected to throw exception with message \""
<< expectedMessage << "\" but didn't throw";
<< expectedMessage << "\" but didn't throw";
}
} // namespace
@ -43,30 +43,32 @@ TEST(WarningTest, JustPrintWarning) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(ExceptionTest, ErrorFormatting) {
expectThrowsEq([]() {
TORCH_CHECK(false, "This is invalid");
}, "This is invalid");
expectThrowsEq(
[]() { TORCH_CHECK(false, "This is invalid"); }, "This is invalid");
expectThrowsEq([]() {
try {
TORCH_CHECK(false, "This is invalid");
} catch (Error& e) {
TORCH_RETHROW(e, "While checking X");
}
}, "This is invalid (While checking X)");
expectThrowsEq(
[]() {
try {
TORCH_CHECK(false, "This is invalid");
} catch (Error& e) {
TORCH_RETHROW(e, "While checking X");
}
},
"This is invalid (While checking X)");
expectThrowsEq([]() {
try {
try {
TORCH_CHECK(false, "This is invalid");
} catch (Error& e) {
TORCH_RETHROW(e, "While checking X");
}
} catch (Error& e) {
TORCH_RETHROW(e, "While checking Y");
}
},
R"msg(This is invalid
expectThrowsEq(
[]() {
try {
try {
TORCH_CHECK(false, "This is invalid");
} catch (Error& e) {
TORCH_RETHROW(e, "While checking X");
}
} catch (Error& e) {
TORCH_RETHROW(e, "While checking Y");
}
},
R"msg(This is invalid
While checking X
While checking Y)msg");
}
@ -94,5 +96,6 @@ TEST(ExceptionTest, DontCallArgumentFunctionsTwiceOnFailure) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_ANY_THROW(failInternalAssert());
EXPECT_EQ(assertionArgumentCounter, 2) << "TORCH_INTERNAL_ASSERT called argument twice";
EXPECT_EQ(assertionArgumentCounter, 2)
<< "TORCH_INTERNAL_ASSERT called argument twice";
}

View File

@ -67,7 +67,8 @@ class ChildDestructableMock final : public DestructableMock {
class NullType1 final {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static SomeClass singleton_;
public:
public:
static constexpr SomeClass* singleton() {
return &singleton_;
}
@ -77,7 +78,8 @@ SomeClass NullType1::singleton_;
class NullType2 final {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static SomeClass singleton_;
public:
public:
static constexpr SomeClass* singleton() {
return &singleton_;
}
@ -934,7 +936,8 @@ TEST(IntrusivePtrTest, givenPtr_whenDestructed_thenDestructsObject) {
bool resourcesReleased = false;
bool wasDestructed = false;
{
auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
EXPECT_FALSE(resourcesReleased);
EXPECT_FALSE(wasDestructed);
}
@ -948,7 +951,8 @@ TEST(
givenPtr_whenMoveConstructed_thenDestructsObjectAfterSecondDestructed) {
bool resourcesReleased = false;
bool wasDestructed = false;
auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
{
auto obj2 = std::move(obj);
EXPECT_FALSE(resourcesReleased);
@ -964,7 +968,8 @@ TEST(
givenPtr_whenMoveConstructedToBaseClass_thenDestructsObjectAfterSecondDestructed) {
bool resourcesReleased = false;
bool wasDestructed = false;
auto obj = make_intrusive<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_intrusive<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
{
intrusive_ptr<DestructableMock> obj2 = std::move(obj);
EXPECT_FALSE(resourcesReleased);
@ -981,7 +986,8 @@ TEST(IntrusivePtrTest, givenPtr_whenMoveAssigned_thenDestructsOldObject) {
bool wasDestructed = false;
auto obj = make_intrusive<DestructableMock>(&dummy, &dummy);
{
auto obj2 = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj2 =
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
EXPECT_FALSE(resourcesReleased);
EXPECT_FALSE(wasDestructed);
obj2 = std::move(obj);
@ -999,7 +1005,8 @@ TEST(
bool wasDestructed = false;
auto obj = make_intrusive<ChildDestructableMock>(&dummy, &dummy);
{
auto obj2 = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj2 =
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
EXPECT_FALSE(resourcesReleased);
EXPECT_FALSE(wasDestructed);
obj2 = std::move(obj);
@ -1017,7 +1024,8 @@ TEST(
bool wasDestructed = false;
auto obj = make_intrusive<DestructableMock>(&dummy, &dummy);
{
auto obj2 = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj2 =
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
{
auto copy = obj2;
EXPECT_FALSE(resourcesReleased);
@ -1040,8 +1048,8 @@ TEST(
bool wasDestructed = false;
auto obj = make_intrusive<ChildDestructableMock>(&dummy, &dummy);
{
auto obj2 =
make_intrusive<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
auto obj2 = make_intrusive<ChildDestructableMock>(
&resourcesReleased, &wasDestructed);
{
intrusive_ptr<DestructableMock> copy = obj2;
EXPECT_FALSE(resourcesReleased);
@ -1064,7 +1072,8 @@ TEST(
bool wasDestructed = false;
auto obj = make_intrusive<ChildDestructableMock>(&dummy, &dummy);
{
auto obj2 = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj2 =
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
{
intrusive_ptr<DestructableMock> copy = obj2;
EXPECT_FALSE(resourcesReleased);
@ -1085,7 +1094,8 @@ TEST(
bool dummy = false;
bool resourcesReleased = false;
bool wasDestructed = false;
auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
{
auto obj2 = make_intrusive<DestructableMock>(&dummy, &dummy);
obj2 = std::move(obj);
@ -1103,7 +1113,8 @@ TEST(
bool dummy = false;
bool resourcesReleased = false;
bool wasDestructed = false;
auto obj = make_intrusive<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_intrusive<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
{
auto obj2 = make_intrusive<DestructableMock>(&dummy, &dummy);
obj2 = std::move(obj);
@ -1121,7 +1132,8 @@ TEST(
bool resourcesReleased = false;
bool wasDestructed = false;
{
auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
{
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
intrusive_ptr<DestructableMock> copy = obj;
@ -1142,7 +1154,8 @@ TEST(
bool resourcesReleased = false;
bool wasDestructed = false;
{
auto obj = make_intrusive<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
auto obj = make_intrusive<ChildDestructableMock>(
&resourcesReleased, &wasDestructed);
{
intrusive_ptr<DestructableMock> copy = obj;
EXPECT_FALSE(resourcesReleased);
@ -1162,7 +1175,8 @@ TEST(
bool resourcesReleased = false;
bool wasDestructed = false;
{
auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
intrusive_ptr<DestructableMock> copy = obj;
obj.reset();
EXPECT_FALSE(resourcesReleased);
@ -1179,7 +1193,8 @@ TEST(
bool resourcesReleased = false;
bool wasDestructed = false;
{
auto obj = make_intrusive<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
auto obj = make_intrusive<ChildDestructableMock>(
&resourcesReleased, &wasDestructed);
intrusive_ptr<DestructableMock> copy = obj;
obj.reset();
EXPECT_FALSE(resourcesReleased);
@ -1197,7 +1212,8 @@ TEST(
bool wasDestructed = false;
bool dummy = false;
{
auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
{
intrusive_ptr<DestructableMock> copy =
make_intrusive<DestructableMock>(&dummy, &dummy);
@ -1220,7 +1236,8 @@ TEST(
bool wasDestructed = false;
bool dummy = false;
{
auto obj = make_intrusive<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
auto obj = make_intrusive<ChildDestructableMock>(
&resourcesReleased, &wasDestructed);
{
intrusive_ptr<DestructableMock> copy =
make_intrusive<DestructableMock>(&dummy, &dummy);
@ -1245,7 +1262,8 @@ TEST(
{
auto copy = make_intrusive<DestructableMock>(&dummy, &dummy);
{
auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
copy = obj;
EXPECT_FALSE(resourcesReleased);
EXPECT_FALSE(wasDestructed);
@ -1267,8 +1285,8 @@ TEST(
{
auto copy = make_intrusive<DestructableMock>(&dummy, &dummy);
{
auto obj =
make_intrusive<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
auto obj = make_intrusive<ChildDestructableMock>(
&resourcesReleased, &wasDestructed);
copy = obj;
EXPECT_FALSE(resourcesReleased);
EXPECT_FALSE(wasDestructed);
@ -1287,7 +1305,8 @@ TEST(IntrusivePtrTest, givenPtr_whenCopyAssigned_thenDestructsOldObject) {
bool wasDestructed = false;
auto obj = make_intrusive<DestructableMock>(&dummy, &dummy);
{
auto obj2 = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj2 =
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
EXPECT_FALSE(resourcesReleased);
EXPECT_FALSE(wasDestructed);
obj2 = obj;
@ -1305,7 +1324,8 @@ TEST(
bool wasDestructed = false;
auto obj = make_intrusive<ChildDestructableMock>(&dummy, &dummy);
{
auto obj2 = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj2 =
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
EXPECT_FALSE(resourcesReleased);
EXPECT_FALSE(wasDestructed);
obj2 = obj;
@ -1323,7 +1343,8 @@ TEST(
bool wasDestructed = false;
auto obj = make_intrusive<DestructableMock>(&dummy, &dummy);
{
auto obj2 = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj2 =
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
{
auto copy = obj2;
EXPECT_FALSE(resourcesReleased);
@ -1346,8 +1367,8 @@ TEST(
bool wasDestructed = false;
auto obj = make_intrusive<ChildDestructableMock>(&dummy, &dummy);
{
auto obj2 =
make_intrusive<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
auto obj2 = make_intrusive<ChildDestructableMock>(
&resourcesReleased, &wasDestructed);
{
intrusive_ptr<DestructableMock> copy = obj2;
EXPECT_FALSE(resourcesReleased);
@ -1370,7 +1391,8 @@ TEST(
bool wasDestructed = false;
auto obj = make_intrusive<ChildDestructableMock>(&dummy, &dummy);
{
auto obj2 = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj2 =
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
{
intrusive_ptr<DestructableMock> copy = obj2;
EXPECT_FALSE(resourcesReleased);
@ -1388,7 +1410,8 @@ TEST(
TEST(IntrusivePtrTest, givenPtr_whenCallingReset_thenDestructs) {
bool resourcesReleased = false;
bool wasDestructed = false;
auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
EXPECT_FALSE(resourcesReleased);
EXPECT_FALSE(wasDestructed);
obj.reset();
@ -1402,7 +1425,8 @@ TEST(
givenPtrWithCopy_whenCallingReset_thenDestructsAfterCopyDestructed) {
bool resourcesReleased = false;
bool wasDestructed = false;
auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
{
auto copy = obj;
obj.reset();
@ -1420,7 +1444,8 @@ TEST(
givenPtrWithCopy_whenCallingResetOnCopy_thenDestructsAfterOriginalDestructed) {
bool resourcesReleased = false;
bool wasDestructed = false;
auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
{
auto copy = obj;
copy.reset();
@ -1438,7 +1463,8 @@ TEST(
givenPtrWithMoved_whenCallingReset_thenDestructsAfterMovedDestructed) {
bool resourcesReleased = false;
bool wasDestructed = false;
auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
{
auto moved = std::move(obj);
// NOLINTNEXTLINE(bugprone-use-after-move)
@ -1457,7 +1483,8 @@ TEST(
givenPtrWithMoved_whenCallingResetOnMoved_thenDestructsImmediately) {
bool resourcesReleased = false;
bool wasDestructed = false;
auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
{
auto moved = std::move(obj);
moved.reset();
@ -1800,9 +1827,7 @@ TEST(IntrusivePtrTest, givenPtr_whenNonOwningReclaimed_thenDoesntCrash) {
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(
IntrusivePtrTest,
givenPtr_whenNonOwningReclaimed_thenIsDestructedAtEnd) {
TEST(IntrusivePtrTest, givenPtr_whenNonOwningReclaimed_thenIsDestructedAtEnd) {
bool resourcesReleased = false;
bool wasDestructed = false;
{
@ -1840,7 +1865,9 @@ weak_intrusive_ptr<T> make_weak_only(Args&&... args) {
auto intrusive = make_intrusive<T>(std::forward<Args>(args)...);
return weak_intrusive_ptr<T>(intrusive);
}
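// make_weak_only returns a weak_intrusive_ptr whose only strong owner (the
// local variable 'intrusive') has already gone out of scope, so the pointee
// has released its resources but has not yet been destructed; the
// outstanding weak reference keeps it alive. This is why tests built on
// make_weak_only expect resourcesReleased to be true immediately after
// construction.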
template <class T, class NullType = c10::detail::intrusive_target_default_null_type<T>>
template <
class T,
class NullType = c10::detail::intrusive_target_default_null_type<T>>
weak_intrusive_ptr<T, NullType> make_invalid_weak() {
return weak_intrusive_ptr<T, NullType>(intrusive_ptr<T, NullType>());
}
@ -1903,9 +1930,7 @@ TEST(
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(
WeakIntrusivePtrTest,
vector_insert_weak_intrusive) {
TEST(WeakIntrusivePtrTest, vector_insert_weak_intrusive) {
std::vector<weak_intrusive_ptr<SomeClass>> priorWorks;
std::vector<intrusive_ptr<SomeClass>> wips;
wips.push_back(make_intrusive<SomeClass>());
@ -2139,8 +2164,10 @@ TEST(
TEST(
WeakIntrusivePtrTest,
givenNullPtr_whenMoveAssigningToDifferentNullptr_thenHasNewNullptr) {
weak_intrusive_ptr<SomeClass, NullType1> obj1 = make_invalid_weak<SomeClass, NullType1>();
weak_intrusive_ptr<SomeClass, NullType2> obj2 = make_invalid_weak<SomeClass, NullType2>();
weak_intrusive_ptr<SomeClass, NullType1> obj1 =
make_invalid_weak<SomeClass, NullType1>();
weak_intrusive_ptr<SomeClass, NullType2> obj2 =
make_invalid_weak<SomeClass, NullType2>();
obj2 = std::move(obj1);
EXPECT_NE(NullType1::singleton(), NullType2::singleton());
// NOLINTNEXTLINE(bugprone-use-after-move)
@ -2362,8 +2389,10 @@ TEST(
TEST(
WeakIntrusivePtrTest,
givenNullPtr_whenCopyAssigningToDifferentNullptr_thenHasNewNullptr) {
weak_intrusive_ptr<SomeClass, NullType1> obj1 = make_invalid_weak<SomeClass, NullType1>();
weak_intrusive_ptr<SomeClass, NullType2> obj2 = make_invalid_weak<SomeClass, NullType2>();
weak_intrusive_ptr<SomeClass, NullType1> obj1 =
make_invalid_weak<SomeClass, NullType1>();
weak_intrusive_ptr<SomeClass, NullType2> obj2 =
make_invalid_weak<SomeClass, NullType2>();
obj2 = obj1;
EXPECT_NE(NullType1::singleton(), NullType2::singleton());
EXPECT_TRUE(obj1.expired());
@ -2468,7 +2497,8 @@ TEST(
TEST(
WeakIntrusivePtrTest,
givenNullPtr_whenMoveConstructingToDifferentNullptr_thenHasNewNullptr) {
weak_intrusive_ptr<SomeClass, NullType1> obj1 = make_invalid_weak<SomeClass, NullType1>();
weak_intrusive_ptr<SomeClass, NullType1> obj1 =
make_invalid_weak<SomeClass, NullType1>();
weak_intrusive_ptr<SomeClass, NullType2> obj2 = std::move(obj1);
EXPECT_NE(NullType1::singleton(), NullType2::singleton());
// NOLINTNEXTLINE(bugprone-use-after-move)
@ -2575,7 +2605,8 @@ TEST(
TEST(
WeakIntrusivePtrTest,
givenNullPtr_whenCopyConstructingToDifferentNullptr_thenHasNewNullptr) {
weak_intrusive_ptr<SomeClass, NullType1> obj1 = make_invalid_weak<SomeClass, NullType1>();
weak_intrusive_ptr<SomeClass, NullType1> obj1 =
make_invalid_weak<SomeClass, NullType1>();
weak_intrusive_ptr<SomeClass, NullType2> obj2 = obj1;
EXPECT_NE(NullType1::singleton(), NullType2::singleton());
EXPECT_TRUE(obj1.expired());
@ -3175,7 +3206,8 @@ TEST(
givenPtr_whenLastStrongPointerResets_thenReleasesResources) {
bool resourcesReleased = false;
bool wasDestructed = false;
auto obj = make_weak_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_weak_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
EXPECT_FALSE(resourcesReleased);
EXPECT_FALSE(wasDestructed);
obj.ptr.reset();
@ -3192,7 +3224,8 @@ TEST(
givenPtr_whenDestructedButStillHasStrongPointers_thenDoesntReleaseResources) {
bool resourcesReleased = false;
bool wasDestructed = false;
auto obj = make_weak_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_weak_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
EXPECT_FALSE(resourcesReleased);
EXPECT_FALSE(wasDestructed);
obj.weak.reset();
@ -3208,7 +3241,8 @@ TEST(WeakIntrusivePtrTest, givenPtr_whenDestructed_thenDestructsObject) {
bool resourcesReleased = false;
bool wasDestructed = false;
{
auto obj = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
EXPECT_TRUE(resourcesReleased);
EXPECT_FALSE(wasDestructed);
}
@ -3222,7 +3256,8 @@ TEST(
givenPtr_whenMoveConstructed_thenDestructsObjectAfterSecondDestructed) {
bool resourcesReleased = false;
bool wasDestructed = false;
auto obj = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
{
auto obj2 = std::move(obj);
EXPECT_TRUE(resourcesReleased);
@ -3238,7 +3273,8 @@ TEST(
givenPtr_whenMoveConstructedToBaseClass_thenDestructsObjectAfterSecondDestructed) {
bool resourcesReleased = false;
bool wasDestructed = false;
auto obj = make_weak_only<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_weak_only<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
{
weak_intrusive_ptr<DestructableMock> obj2 = std::move(obj);
EXPECT_TRUE(resourcesReleased);
@ -3255,7 +3291,8 @@ TEST(WeakIntrusivePtrTest, givenPtr_whenMoveAssigned_thenDestructsOldObject) {
bool wasDestructed = false;
auto obj = make_weak_only<DestructableMock>(&dummy, &dummy);
{
auto obj2 = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj2 =
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
EXPECT_TRUE(resourcesReleased);
EXPECT_FALSE(wasDestructed);
obj2 = std::move(obj);
@ -3273,7 +3310,8 @@ TEST(
bool wasDestructed = false;
auto obj = make_weak_only<ChildDestructableMock>(&dummy, &dummy);
{
auto obj2 = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj2 =
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
EXPECT_TRUE(resourcesReleased);
EXPECT_FALSE(wasDestructed);
obj2 = std::move(obj);
@ -3291,7 +3329,8 @@ TEST(
bool wasDestructed = false;
auto obj = make_weak_only<DestructableMock>(&dummy, &dummy);
{
auto obj2 = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj2 =
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
{
auto copy = obj2;
EXPECT_TRUE(resourcesReleased);
@ -3314,8 +3353,8 @@ TEST(
bool wasDestructed = false;
auto obj = make_weak_only<ChildDestructableMock>(&dummy, &dummy);
{
auto obj2 =
make_weak_only<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
auto obj2 = make_weak_only<ChildDestructableMock>(
&resourcesReleased, &wasDestructed);
{
weak_intrusive_ptr<DestructableMock> copy = obj2;
EXPECT_TRUE(resourcesReleased);
@ -3338,7 +3377,8 @@ TEST(
bool wasDestructed = false;
auto obj = make_weak_only<ChildDestructableMock>(&dummy, &dummy);
{
auto obj2 = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj2 =
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
{
weak_intrusive_ptr<DestructableMock> copy = obj2;
EXPECT_TRUE(resourcesReleased);
@ -3359,7 +3399,8 @@ TEST(
bool dummy = false;
bool resourcesReleased = false;
bool wasDestructed = false;
auto obj = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
{
auto obj2 = make_weak_only<DestructableMock>(&dummy, &dummy);
obj2 = std::move(obj);
@ -3377,7 +3418,8 @@ TEST(
bool dummy = false;
bool resourcesReleased = false;
bool wasDestructed = false;
auto obj = make_weak_only<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_weak_only<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
{
auto obj2 = make_weak_only<DestructableMock>(&dummy, &dummy);
obj2 = std::move(obj);
@ -3395,7 +3437,8 @@ TEST(
bool resourcesReleased = false;
bool wasDestructed = false;
{
auto obj = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
{
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
weak_intrusive_ptr<DestructableMock> copy = obj;
@ -3416,7 +3459,8 @@ TEST(
bool resourcesReleased = false;
bool wasDestructed = false;
{
auto obj = make_weak_only<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
auto obj = make_weak_only<ChildDestructableMock>(
&resourcesReleased, &wasDestructed);
{
weak_intrusive_ptr<DestructableMock> copy = obj;
EXPECT_TRUE(resourcesReleased);
@ -3436,7 +3480,8 @@ TEST(
bool resourcesReleased = false;
bool wasDestructed = false;
{
auto obj = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
weak_intrusive_ptr<DestructableMock> copy = obj;
obj.reset();
EXPECT_TRUE(resourcesReleased);
@ -3453,7 +3498,8 @@ TEST(
bool resourcesReleased = false;
bool wasDestructed = false;
{
auto obj = make_weak_only<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
auto obj = make_weak_only<ChildDestructableMock>(
&resourcesReleased, &wasDestructed);
weak_intrusive_ptr<DestructableMock> copy = obj;
obj.reset();
EXPECT_TRUE(resourcesReleased);
@ -3471,7 +3517,8 @@ TEST(
bool wasDestructed = false;
bool dummy = false;
{
auto obj = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
{
weak_intrusive_ptr<DestructableMock> copy =
make_weak_only<DestructableMock>(&dummy, &dummy);
@ -3494,7 +3541,8 @@ TEST(
bool wasDestructed = false;
bool dummy = false;
{
auto obj = make_weak_only<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
auto obj = make_weak_only<ChildDestructableMock>(
&resourcesReleased, &wasDestructed);
{
weak_intrusive_ptr<DestructableMock> copy =
make_weak_only<DestructableMock>(&dummy, &dummy);
@ -3519,7 +3567,8 @@ TEST(
{
auto copy = make_weak_only<DestructableMock>(&dummy, &dummy);
{
auto obj = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
copy = obj;
EXPECT_TRUE(resourcesReleased);
EXPECT_FALSE(wasDestructed);
@ -3541,8 +3590,8 @@ TEST(
{
auto copy = make_weak_only<DestructableMock>(&dummy, &dummy);
{
auto obj =
make_weak_only<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
auto obj = make_weak_only<ChildDestructableMock>(
&resourcesReleased, &wasDestructed);
copy = obj;
EXPECT_TRUE(resourcesReleased);
EXPECT_FALSE(wasDestructed);
@ -3561,7 +3610,8 @@ TEST(WeakIntrusivePtrTest, givenPtr_whenCopyAssigned_thenDestructsOldObject) {
bool wasDestructed = false;
auto obj = make_weak_only<DestructableMock>(&dummy, &dummy);
{
auto obj2 = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj2 =
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
EXPECT_TRUE(resourcesReleased);
EXPECT_FALSE(wasDestructed);
obj2 = obj;
@ -3579,7 +3629,8 @@ TEST(
bool wasDestructed = false;
auto obj = make_weak_only<ChildDestructableMock>(&dummy, &dummy);
{
auto obj2 = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj2 =
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
EXPECT_TRUE(resourcesReleased);
EXPECT_FALSE(wasDestructed);
obj2 = obj;
@ -3597,7 +3648,8 @@ TEST(
bool wasDestructed = false;
auto obj = make_weak_only<DestructableMock>(&dummy, &dummy);
{
auto obj2 = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj2 =
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
{
auto copy = obj2;
EXPECT_TRUE(resourcesReleased);
@ -3620,8 +3672,8 @@ TEST(
bool wasDestructed = false;
auto obj = make_weak_only<ChildDestructableMock>(&dummy, &dummy);
{
auto obj2 =
make_weak_only<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
auto obj2 = make_weak_only<ChildDestructableMock>(
&resourcesReleased, &wasDestructed);
{
weak_intrusive_ptr<DestructableMock> copy = obj2;
EXPECT_TRUE(resourcesReleased);
@ -3644,7 +3696,8 @@ TEST(
bool wasDestructed = false;
auto obj = make_weak_only<ChildDestructableMock>(&dummy, &dummy);
{
auto obj2 = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj2 =
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
{
weak_intrusive_ptr<DestructableMock> copy = obj2;
EXPECT_TRUE(resourcesReleased);
@ -3662,7 +3715,8 @@ TEST(
TEST(WeakIntrusivePtrTest, givenPtr_whenCallingReset_thenDestructs) {
bool resourcesReleased = false;
bool wasDestructed = false;
auto obj = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
EXPECT_TRUE(resourcesReleased);
EXPECT_FALSE(wasDestructed);
obj.reset();
@ -3676,7 +3730,8 @@ TEST(
givenPtrWithCopy_whenCallingReset_thenDestructsAfterCopyDestructed) {
bool resourcesReleased = false;
bool wasDestructed = false;
auto obj = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
{
auto copy = obj;
obj.reset();
@ -3694,7 +3749,8 @@ TEST(
givenPtrWithCopy_whenCallingResetOnCopy_thenDestructsAfterOriginalDestructed) {
bool resourcesReleased = false;
bool wasDestructed = false;
auto obj = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
{
auto copy = obj;
copy.reset();
@ -3712,7 +3768,8 @@ TEST(
givenPtrWithMoved_whenCallingReset_thenDestructsAfterMovedDestructed) {
bool resourcesReleased = false;
bool wasDestructed = false;
auto obj = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
{
auto moved = std::move(obj);
// NOLINTNEXTLINE(bugprone-use-after-move)
@ -3731,7 +3788,8 @@ TEST(
givenPtrWithMoved_whenCallingResetOnMoved_thenDestructsImmediately) {
bool resourcesReleased = false;
bool wasDestructed = false;
auto obj = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
auto obj =
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
{
auto moved = std::move(obj);
moved.reset();
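
All of the weak-only cases above exercise one invariant; a condensed sketch of it, reusing the DestructableMock and make_weak_only helpers defined earlier in this test file (the test name below is hypothetical, and the flag semantics follow what the surrounding assertions check):

// Sketch of the invariant these weak_intrusive_ptr tests keep re-checking:
// with no strong references left, resources are released immediately, but
// destruction of the object waits for the last weak reference.
TEST(WeakIntrusivePtrTest, sketch_weakOnlyLifetime) {
  bool resourcesReleased = false;
  bool wasDestructed = false;
  auto weak =
      make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
  EXPECT_TRUE(resourcesReleased); // payload released: only weak refs remain
  EXPECT_FALSE(wasDestructed); // control block kept alive for the weak ref
  weak.reset();
  EXPECT_TRUE(wasDestructed); // last weak ref gone, object destructed
}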

View File

@ -8,58 +8,58 @@ using namespace ::testing;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(irange_test, range_test) {
std::vector<int> test_vec;
for(const auto i : c10::irange(4, 11)){
test_vec.push_back(i);
}
const std::vector<int> correct = {{4,5,6,7,8,9,10}};
ASSERT_EQ(test_vec, correct);
std::vector<int> test_vec;
for (const auto i : c10::irange(4, 11)) {
test_vec.push_back(i);
}
const std::vector<int> correct = {{4, 5, 6, 7, 8, 9, 10}};
ASSERT_EQ(test_vec, correct);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(irange_test, end_test) {
std::vector<int> test_vec;
for(const auto i : c10::irange(5)){
test_vec.push_back(i);
}
const std::vector<int> correct = {{0, 1, 2, 3, 4}};
ASSERT_EQ(test_vec, correct);
std::vector<int> test_vec;
for (const auto i : c10::irange(5)) {
test_vec.push_back(i);
}
const std::vector<int> correct = {{0, 1, 2, 3, 4}};
ASSERT_EQ(test_vec, correct);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(irange_test, neg_range_test) {
std::vector<int> test_vec;
for(const auto i : c10::irange(-2, 3)){
test_vec.push_back(i);
}
const std::vector<int> correct = {{-2,-1,0,1,2}};
ASSERT_EQ(test_vec, correct);
std::vector<int> test_vec;
for (const auto i : c10::irange(-2, 3)) {
test_vec.push_back(i);
}
const std::vector<int> correct = {{-2, -1, 0, 1, 2}};
ASSERT_EQ(test_vec, correct);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(irange, empty_reverse_range_two_inputs){
std::vector<int> test_vec;
for(const auto i : c10::irange(3, -3)){
test_vec.push_back(i);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
if(i>20){ //Cap the number of elements we add if something goes wrong
break;
}
TEST(irange, empty_reverse_range_two_inputs) {
std::vector<int> test_vec;
for (const auto i : c10::irange(3, -3)) {
test_vec.push_back(i);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
if (i > 20) { // Cap the number of elements we add if something goes wrong
break;
}
const std::vector<int> correct = {};
ASSERT_EQ(test_vec, correct);
}
const std::vector<int> correct = {};
ASSERT_EQ(test_vec, correct);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(irange, empty_reverse_range_one_input){
std::vector<int> test_vec;
for(const auto i : c10::irange(-3)){
test_vec.push_back(i);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
if(i>20){ //Cap the number of elements we add if something goes wrong
break;
}
TEST(irange, empty_reverse_range_one_input) {
std::vector<int> test_vec;
for (const auto i : c10::irange(-3)) {
test_vec.push_back(i);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
if (i > 20) { // Cap the number of elements we add if something goes wrong
break;
}
const std::vector<int> correct = {};
ASSERT_EQ(test_vec, correct);
}
const std::vector<int> correct = {};
ASSERT_EQ(test_vec, correct);
}
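
The assertions above fix c10::irange's contract: irange(begin, end) visits the half-open interval [begin, end), irange(end) starts from zero, and a range whose end is not past its begin is empty. A small sketch of the plain loops these ranges replace, assuming only the <c10/util/irange.h> header this test already uses:

#include <c10/util/irange.h>

#include <vector>

// Sketch: each range-for below is equivalent to the classic for-loop in the
// trailing comment; the last two produce no iterations at all.
std::vector<int> irange_equivalents() {
  std::vector<int> out;
  for (const auto i : c10::irange(4, 11)) { // for (int i = 4; i < 11; ++i)
    out.push_back(i);
  }
  for (const auto i : c10::irange(5)) { // for (int i = 0; i < 5; ++i)
    out.push_back(i);
  }
  for (const auto i : c10::irange(3, -3)) { // empty: end is not past begin
    out.push_back(i);
  }
  for (const auto i : c10::irange(-3)) { // empty for the same reason
    out.push_back(i);
  }
  return out; // {4, ..., 10, 0, ..., 4}
}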

View File

@ -1,8 +1,8 @@
#include <algorithm>
#include <gtest/gtest.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Logging.h>
#include <gtest/gtest.h>
namespace c10_test {
@ -54,15 +54,15 @@ TEST(LoggingTest, TestEnforceEquals) {
namespace {
struct EnforceEqWithCaller {
void test(const char *x) {
void test(const char* x) {
CAFFE_ENFORCE_EQ_WITH_CALLER(1, 1, "variable: ", x, " is a variable");
}
};
}
} // namespace
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LoggingTest, TestEnforceMessageVariables) {
const char *const x = "hello";
const char* const x = "hello";
CAFFE_ENFORCE_EQ(1, 1, "variable: ", x, " is a variable");
EnforceEqWithCaller e;
@ -70,7 +70,9 @@ TEST(LoggingTest, TestEnforceMessageVariables) {
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LoggingTest, EnforceEqualsObjectWithReferenceToTemporaryWithoutUseOutOfScope) {
TEST(
LoggingTest,
EnforceEqualsObjectWithReferenceToTemporaryWithoutUseOutOfScope) {
std::vector<int> x = {1, 2, 3, 4};
// This case is a little tricky. We have a temporary
// std::initializer_list to which our temporary ArrayRef
@ -102,7 +104,7 @@ std::ostream& operator<<(std::ostream& out, const Noncopyable& nc) {
out << "Noncopyable(" << nc.x << ")";
return out;
}
}
} // namespace
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LoggingTest, DoesntCopyComparedObjects) {
@ -132,7 +134,8 @@ TEST(LoggingTest, EnforceShowcase) {
WRAP_AND_PRINT(CAFFE_ENFORCE_EQ(
one * two + three, three * two, "It's a pretty complicated expression"));
WRAP_AND_PRINT(CAFFE_ENFORCE_THAT(std::equal_to<void>(), ==, one * two + three, three * two));
WRAP_AND_PRINT(CAFFE_ENFORCE_THAT(
std::equal_to<void>(), ==, one * two + three, three * two));
}
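
The showcase above boils down to one rule: each CAFFE_ENFORCE_* macro checks a condition and, on failure, throws an exception whose message strings together the failed comparison and any trailing arguments, which is what WRAP_AND_PRINT catches and prints. A minimal hedged sketch of ordinary call-site usage, assuming only the <c10/util/Logging.h> macros this test already includes (the function and its parameter are illustrative, not part of this file):

#include <c10/util/Logging.h>

// Sketch: passes silently when batch_size == 32; otherwise throws with the
// values of both sides plus the extra context appended to the message.
void check_batch_size(int64_t batch_size) {
  CAFFE_ENFORCE_EQ(batch_size, 32, "unexpected batch size: ", batch_size);
}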
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)

View File

@ -17,32 +17,29 @@ class OptionalTest : public ::testing::Test {
template <typename T>
T getSampleValue();
template<>
template <>
bool getSampleValue() {
return true;
}
template<>
template <>
uint64_t getSampleValue() {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
return 42;
}
template<>
template <>
std::string getSampleValue() {
return "hello";
}
using OptionalTypes = ::testing::Types<
// 32-bit scalar optimization.
bool,
// Trivially destructible but not 32-bit scalar.
uint64_t,
// Non-trivial destructor.
std::string
>;
// 32-bit scalar optimization.
bool,
// Trivially destructible but not 32-bit scalar.
uint64_t,
// Non-trivial destructor.
std::string>;
TYPED_TEST_CASE(OptionalTest, OptionalTypes);
@ -71,7 +68,8 @@ TYPED_TEST(OptionalTest, Initialized) {
moveAssign = std::move(moveFrom2);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
std::array<typename TestFixture::optional *, 5> opts = {&opt, &copy, &copyAssign, &move, &moveAssign};
std::array<typename TestFixture::optional*, 5> opts = {
&opt, &copy, &copyAssign, &move, &moveAssign};
for (auto* popt : opts) {
auto& opt = *popt;
EXPECT_TRUE((bool)opt);

View File

@ -1,6 +1,6 @@
#include <vector>
#include <unordered_set>
#include <algorithm>
#include <unordered_set>
#include <vector>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
@ -11,7 +11,8 @@ namespace {
#define ASSERT_EQUAL_PRIM(t1, t2) ASSERT_TRUE(t1 == t2);
using dict_int_int = ska_ordered::order_preserving_flat_hash_map<int64_t, int64_t>;
using dict_int_int =
ska_ordered::order_preserving_flat_hash_map<int64_t, int64_t>;
dict_int_int test_dict(dict_int_int& dict) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
@ -20,7 +21,7 @@ dict_int_int test_dict(dict_int_int& dict) {
}
int64_t i = 0;
for (auto entry: dict) {
for (auto entry : dict) {
TORCH_INTERNAL_ASSERT(entry.first == i && entry.second == i + 1);
++i;
}
@ -28,7 +29,7 @@ dict_int_int test_dict(dict_int_int& dict) {
// erase a few entries by themselves
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
std::unordered_set<int64_t> erase_set = {0, 2, 9, 71};
for (auto erase: erase_set) {
for (auto erase : erase_set) {
dict.erase(erase);
}
@ -55,7 +56,7 @@ dict_int_int test_dict(dict_int_int& dict) {
}
i = 0;
for (auto entry: dict) {
for (auto entry : dict) {
TORCH_INTERNAL_ASSERT(order[i] == entry.first);
TORCH_INTERNAL_ASSERT(dict[order[i]] == entry.second);
TORCH_INTERNAL_ASSERT(entry.second == order[i] + 1);
@ -89,11 +90,11 @@ TEST(OrderedPreservingDictTest, InsertExistingDoesntAffectOrder) {
TORCH_INTERNAL_ASSERT(dict.begin()->first == 1);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(OrderedPreservingDictTest, testRefType) {
std::shared_ptr<int64_t> t;
using dict_references = ska_ordered::order_preserving_flat_hash_map<int64_t, std::shared_ptr<int64_t>>;
using dict_references = ska_ordered::
order_preserving_flat_hash_map<int64_t, std::shared_ptr<int64_t>>;
dict_references dict;
@ -108,7 +109,6 @@ TEST(OrderedPreservingDictTest, testRefType) {
TORCH_INTERNAL_ASSERT(ptr.use_count() == 1);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(OrderedPreservingDictTest, DictCollisions) {
struct BadHash {
@ -171,254 +171,262 @@ TEST(OrderedPreservingDictTest, DictCollisions) {
}
}
// Tests taken from https://github.com/Tessil/ordered-map/blob/master/tests/ordered_map_tests.cpp
// Tests taken from
// https://github.com/Tessil/ordered-map/blob/master/tests/ordered_map_tests.cpp
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(OrderedPreservingDictTest, test_range_insert) {
// insert x values in vector, range insert x-15 values from vector to map, check values
const int nb_values = 1000;
std::vector<std::pair<int, int>> values;
for(int i = 0; i < nb_values; i++) {
// NOLINTNEXTLINE(modernize-use-emplace,performance-inefficient-vector-operation)
values.push_back(std::make_pair(i, i+1));
}
// insert x values in vector, range insert x-15 values from vector to map,
// check values
const int nb_values = 1000;
std::vector<std::pair<int, int>> values;
for (int i = 0; i < nb_values; i++) {
// NOLINTNEXTLINE(modernize-use-emplace,performance-inefficient-vector-operation)
values.push_back(std::make_pair(i, i + 1));
}
dict_int_int map = {{-1, 0}, {-2, 0}};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
map.insert(values.begin() + 10, values.end() - 5);
dict_int_int map = {{-1, 0}, {-2, 0}};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
map.insert(values.begin() + 10, values.end() - 5);
TORCH_INTERNAL_ASSERT(map.size(), 987);
TORCH_INTERNAL_ASSERT(map.size(), 987);
ASSERT_EQUAL_PRIM(map.at(-1), 0);
ASSERT_EQUAL_PRIM(map.at(-1), 0);
ASSERT_EQUAL_PRIM(map.at(-2), 0);
ASSERT_EQUAL_PRIM(map.at(-2), 0);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
for(int i = 10, j = 2; i < nb_values - 5; i++, j++) {
ASSERT_EQUAL_PRIM(map.at(i), i+1);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
for (int i = 10, j = 2; i < nb_values - 5; i++, j++) {
ASSERT_EQUAL_PRIM(map.at(i), i + 1);
}
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(OrderedPreservingDictTest, test_range_erase_all) {
// insert x values, delete all
const std::size_t nb_values = 1000;
dict_int_int map;
for (size_t i = 0; i < nb_values; ++i) {
map[i] = i + 1;
}
auto it = map.erase(map.begin(), map.end());
ASSERT_TRUE(it == map.end());
ASSERT_TRUE(map.empty());
// insert x values, delete all
const std::size_t nb_values = 1000;
dict_int_int map;
for (size_t i = 0; i < nb_values; ++i) {
map[i] = i + 1;
}
auto it = map.erase(map.begin(), map.end());
ASSERT_TRUE(it == map.end());
ASSERT_TRUE(map.empty());
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(OrderedPreservingDictTest, test_range_erase) {
// insert x values, delete all with iterators except 10 first and 780 last values
using HMap = ska_ordered::order_preserving_flat_hash_map<std::string, std::int64_t>;
// insert x values, delete all with iterators except 10 first and 780 last
// values
using HMap =
ska_ordered::order_preserving_flat_hash_map<std::string, std::int64_t>;
const std::size_t nb_values = 1000;
HMap map;
for (size_t i = 0; i < nb_values; ++i) {
map[c10::guts::to_string(i)] = i;
auto begin = map.begin();
for (size_t j = 0; j <= i; ++j, begin++) {
TORCH_INTERNAL_ASSERT(begin->second == j);
}
const std::size_t nb_values = 1000;
HMap map;
for (size_t i = 0; i < nb_values; ++i) {
map[c10::guts::to_string(i)] = i;
auto begin = map.begin();
for (size_t j = 0; j <= i; ++j, begin++) {
TORCH_INTERNAL_ASSERT(begin->second == j);
}
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
auto it_first = std::next(map.begin(), 10);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
auto it_last = std::next(map.begin(), 220);
auto it = map.erase(it_first, it_last);
ASSERT_EQUAL_PRIM(std::distance(it, map.end()), 780);
ASSERT_EQUAL_PRIM(map.size(), 790);
ASSERT_EQUAL_PRIM(std::distance(map.begin(), map.end()), 790);
for (auto& val : map) {
ASSERT_EQUAL_PRIM(map.count(val.first), 1);
}
// Check order
it = map.begin();
for (std::size_t i = 0; i < nb_values; i++) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
auto it_first = std::next(map.begin(), 10);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
auto it_last = std::next(map.begin(), 220);
auto it = map.erase(it_first, it_last);
ASSERT_EQUAL_PRIM(std::distance(it, map.end()), 780);
ASSERT_EQUAL_PRIM(map.size(), 790);
ASSERT_EQUAL_PRIM(std::distance(map.begin(), map.end()), 790);
for(auto& val: map) {
ASSERT_EQUAL_PRIM(map.count(val.first), 1);
}
// Check order
it = map.begin();
for(std::size_t i = 0; i < nb_values; i++) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
if(i >= 10 && i < 220) {
continue;
}
auto exp_it = std::pair<std::string, std::int64_t>(c10::guts::to_string(i), i);
TORCH_INTERNAL_ASSERT(*it == exp_it);
++it;
if (i >= 10 && i < 220) {
continue;
}
auto exp_it =
std::pair<std::string, std::int64_t>(c10::guts::to_string(i), i);
TORCH_INTERNAL_ASSERT(*it == exp_it);
++it;
}
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(OrderedPreservingDictTest, test_move_constructor_empty) {
ska_ordered::order_preserving_flat_hash_map<std::string, int64_t> map(0);
ska_ordered::order_preserving_flat_hash_map<std::string, int64_t> map_move(std::move(map));
ska_ordered::order_preserving_flat_hash_map<std::string, int64_t> map(0);
ska_ordered::order_preserving_flat_hash_map<std::string, int64_t> map_move(
std::move(map));
// NOLINTNEXTLINE(bugprone-use-after-move)
TORCH_INTERNAL_ASSERT(map.empty());
TORCH_INTERNAL_ASSERT(map_move.empty());
// NOLINTNEXTLINE(bugprone-use-after-move)
TORCH_INTERNAL_ASSERT(map.empty());
TORCH_INTERNAL_ASSERT(map_move.empty());
// NOLINTNEXTLINE(clang-analyzer-cplusplus.Move)
TORCH_INTERNAL_ASSERT(map.find("") == map.end());
TORCH_INTERNAL_ASSERT(map_move.find("") == map_move.end());
// NOLINTNEXTLINE(clang-analyzer-cplusplus.Move)
TORCH_INTERNAL_ASSERT(map.find("") == map.end());
TORCH_INTERNAL_ASSERT(map_move.find("") == map_move.end());
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(OrderedPreservingDictTest, test_move_operator_empty) {
ska_ordered::order_preserving_flat_hash_map<std::string, int64_t> map(0);
ska_ordered::order_preserving_flat_hash_map<std::string, int64_t> map_move;
map_move = (std::move(map));
ska_ordered::order_preserving_flat_hash_map<std::string, int64_t> map(0);
ska_ordered::order_preserving_flat_hash_map<std::string, int64_t> map_move;
map_move = (std::move(map));
// NOLINTNEXTLINE(bugprone-use-after-move)
TORCH_INTERNAL_ASSERT(map.empty());
TORCH_INTERNAL_ASSERT(map_move.empty());
// NOLINTNEXTLINE(bugprone-use-after-move)
TORCH_INTERNAL_ASSERT(map.empty());
TORCH_INTERNAL_ASSERT(map_move.empty());
// NOLINTNEXTLINE(clang-analyzer-cplusplus.Move)
TORCH_INTERNAL_ASSERT(map.find("") == map.end());
TORCH_INTERNAL_ASSERT(map_move.find("") == map_move.end());
// NOLINTNEXTLINE(clang-analyzer-cplusplus.Move)
TORCH_INTERNAL_ASSERT(map.find("") == map.end());
TORCH_INTERNAL_ASSERT(map_move.find("") == map_move.end());
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(OrderedPreservingDictTest, test_reassign_moved_object_move_constructor) {
using HMap = ska_ordered::order_preserving_flat_hash_map<std::string, std::string>;
using HMap =
ska_ordered::order_preserving_flat_hash_map<std::string, std::string>;
HMap map = {{"Key1", "Value1"}, {"Key2", "Value2"}, {"Key3", "Value3"}};
HMap map_move(std::move(map));
HMap map = {{"Key1", "Value1"}, {"Key2", "Value2"}, {"Key3", "Value3"}};
HMap map_move(std::move(map));
ASSERT_EQUAL_PRIM(map_move.size(), 3);
// NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
ASSERT_EQUAL_PRIM(map.size(), 0);
ASSERT_EQUAL_PRIM(map_move.size(), 3);
// NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
ASSERT_EQUAL_PRIM(map.size(), 0);
map = {{"Key4", "Value4"}, {"Key5", "Value5"}};
TORCH_INTERNAL_ASSERT(map == (HMap({{"Key4", "Value4"}, {"Key5", "Value5"}})));
map = {{"Key4", "Value4"}, {"Key5", "Value5"}};
TORCH_INTERNAL_ASSERT(
map == (HMap({{"Key4", "Value4"}, {"Key5", "Value5"}})));
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(OrderedPreservingDictTest, test_reassign_moved_object_move_operator) {
using HMap = ska_ordered::order_preserving_flat_hash_map<std::string, std::string>;
using HMap =
ska_ordered::order_preserving_flat_hash_map<std::string, std::string>;
HMap map = {{"Key1", "Value1"}, {"Key2", "Value2"}, {"Key3", "Value3"}};
HMap map_move = std::move(map);
HMap map = {{"Key1", "Value1"}, {"Key2", "Value2"}, {"Key3", "Value3"}};
HMap map_move = std::move(map);
ASSERT_EQUAL_PRIM(map_move.size(), 3);
// NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
ASSERT_EQUAL_PRIM(map.size(), 0);
ASSERT_EQUAL_PRIM(map_move.size(), 3);
// NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
ASSERT_EQUAL_PRIM(map.size(), 0);
map = {{"Key4", "Value4"}, {"Key5", "Value5"}};
TORCH_INTERNAL_ASSERT(map == (HMap({{"Key4", "Value4"}, {"Key5", "Value5"}})));
map = {{"Key4", "Value4"}, {"Key5", "Value5"}};
TORCH_INTERNAL_ASSERT(
map == (HMap({{"Key4", "Value4"}, {"Key5", "Value5"}})));
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(OrderedPreservingDictTest, test_copy_constructor_and_operator) {
using HMap = ska_ordered::order_preserving_flat_hash_map<std::string, std::string>;
using HMap =
ska_ordered::order_preserving_flat_hash_map<std::string, std::string>;
const std::size_t nb_values = 100;
HMap map;
for (size_t i = 0; i < nb_values; ++i) {
map[c10::guts::to_string(i)] = c10::guts::to_string(i);
}
const std::size_t nb_values = 100;
HMap map;
for (size_t i = 0; i < nb_values; ++i) {
map[c10::guts::to_string(i)] = c10::guts::to_string(i);
}
HMap map_copy = map;
HMap map_copy2(map);
HMap map_copy3;
map_copy3[c10::guts::to_string(0)] = c10::guts::to_string(0);
map_copy3 = map;
HMap map_copy = map;
HMap map_copy2(map);
HMap map_copy3;
map_copy3[c10::guts::to_string(0)] = c10::guts::to_string(0);
TORCH_INTERNAL_ASSERT(map == map_copy);
map.clear();
map_copy3 = map;
TORCH_INTERNAL_ASSERT(map == map_copy);
map.clear();
TORCH_INTERNAL_ASSERT(map_copy == map_copy2);
TORCH_INTERNAL_ASSERT(map_copy == map_copy3);
TORCH_INTERNAL_ASSERT(map_copy == map_copy2);
TORCH_INTERNAL_ASSERT(map_copy == map_copy3);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(OrderedPreservingDictTest, test_copy_constructor_empty) {
ska_ordered::order_preserving_flat_hash_map<std::string, int> map(0);
ska_ordered::order_preserving_flat_hash_map<std::string, int> map_copy(map);
ska_ordered::order_preserving_flat_hash_map<std::string, int> map(0);
ska_ordered::order_preserving_flat_hash_map<std::string, int> map_copy(map);
TORCH_INTERNAL_ASSERT(map.empty());
TORCH_INTERNAL_ASSERT(map_copy.empty());
TORCH_INTERNAL_ASSERT(map.empty());
TORCH_INTERNAL_ASSERT(map_copy.empty());
TORCH_INTERNAL_ASSERT(map.find("") == map.end());
TORCH_INTERNAL_ASSERT(map_copy.find("") == map_copy.end());
TORCH_INTERNAL_ASSERT(map.find("") == map.end());
TORCH_INTERNAL_ASSERT(map_copy.find("") == map_copy.end());
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(OrderedPreservingDictTest, test_copy_operator_empty) {
ska_ordered::order_preserving_flat_hash_map<std::string, int> map(0);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
ska_ordered::order_preserving_flat_hash_map<std::string, int> map_copy(16);
map_copy = map;
ska_ordered::order_preserving_flat_hash_map<std::string, int> map(0);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
ska_ordered::order_preserving_flat_hash_map<std::string, int> map_copy(16);
map_copy = map;
TORCH_INTERNAL_ASSERT(map.empty());
TORCH_INTERNAL_ASSERT(map_copy.empty());
TORCH_INTERNAL_ASSERT(map.empty());
TORCH_INTERNAL_ASSERT(map_copy.empty());
TORCH_INTERNAL_ASSERT(map.find("") == map.end());
TORCH_INTERNAL_ASSERT(map_copy.find("") == map_copy.end());
TORCH_INTERNAL_ASSERT(map.find("") == map.end());
TORCH_INTERNAL_ASSERT(map_copy.find("") == map_copy.end());
}
/**
* at
*/
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(OrderedPreservingDictTest, test_at) {
// insert x values, use at for known and unknown values.
const ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map = {{0, 10}, {-2, 20}};
// insert x values, use at for known and unknown values.
const ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>
map = {{0, 10}, {-2, 20}};
ASSERT_EQUAL_PRIM(map.at(0), 10);
ASSERT_EQUAL_PRIM(map.at(-2), 20);
bool thrown = false;
try {
map.at(1);
} catch (...) {
thrown = true;
}
ASSERT_TRUE(thrown);
ASSERT_EQUAL_PRIM(map.at(0), 10);
ASSERT_EQUAL_PRIM(map.at(-2), 20);
bool thrown = false;
try {
map.at(1);
} catch (...) {
thrown = true;
}
ASSERT_TRUE(thrown);
}
/**
* equal_range
*/
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(OrderedPreservingDictTest, test_equal_range) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map = {{0, 10}, {-2, 20}};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map =
{{0, 10}, {-2, 20}};
auto it_pair = map.equal_range(0);
ASSERT_EQUAL_PRIM(std::distance(it_pair.first, it_pair.second), 1);
ASSERT_EQUAL_PRIM(it_pair.first->second, 10);
auto it_pair = map.equal_range(0);
ASSERT_EQUAL_PRIM(std::distance(it_pair.first, it_pair.second), 1);
ASSERT_EQUAL_PRIM(it_pair.first->second, 10);
it_pair = map.equal_range(1);
TORCH_INTERNAL_ASSERT(it_pair.first == it_pair.second);
TORCH_INTERNAL_ASSERT(it_pair.first == map.end());
it_pair = map.equal_range(1);
TORCH_INTERNAL_ASSERT(it_pair.first == it_pair.second);
TORCH_INTERNAL_ASSERT(it_pair.first == map.end());
}
/**
* operator[]
*/
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(OrderedPreservingDictTest, test_access_operator) {
// insert x values, use at for known and unknown values.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map = {{0, 10}, {-2, 20}};
// insert x values, use at for known and unknown values.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map =
{{0, 10}, {-2, 20}};
ASSERT_EQUAL_PRIM(map[0], 10);
ASSERT_EQUAL_PRIM(map[-2], 20);
ASSERT_EQUAL_PRIM(map[2], std::int64_t());
ASSERT_EQUAL_PRIM(map[0], 10);
ASSERT_EQUAL_PRIM(map[-2], 20);
ASSERT_EQUAL_PRIM(map[2], std::int64_t());
ASSERT_EQUAL_PRIM(map.size(), 3);
ASSERT_EQUAL_PRIM(map.size(), 3);
}
/**
@ -426,45 +434,72 @@ TEST(OrderedPreservingDictTest, test_access_operator) {
*/
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(OrderedPreservingDictTest, test_swap) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map = {{1, 10}, {8, 80}, {3, 30}};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map2 = {{4, 40}, {5, 50}};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map =
{{1, 10}, {8, 80}, {3, 30}};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map2 =
{{4, 40}, {5, 50}};
using std::swap;
swap(map, map2);
using std::swap;
swap(map, map2);
TORCH_INTERNAL_ASSERT(map == (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{{4, 40}, {5, 50}}));
TORCH_INTERNAL_ASSERT(map2 == (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{{1, 10}, {8, 80}, {3, 30}}));
TORCH_INTERNAL_ASSERT(
map ==
(ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{
{4, 40}, {5, 50}}));
TORCH_INTERNAL_ASSERT(
map2 ==
(ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{
{1, 10}, {8, 80}, {3, 30}}));
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
map.insert({6, 60});
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
map2.insert({4, 40});
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
map.insert({6, 60});
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
map2.insert({4, 40});
TORCH_INTERNAL_ASSERT(map == (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{{4, 40}, {5, 50}, {6, 60}}));
TORCH_INTERNAL_ASSERT(map2 == (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{{1, 10}, {8, 80}, {3, 30}, {4, 40}}));
TORCH_INTERNAL_ASSERT(
map ==
(ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{
{4, 40}, {5, 50}, {6, 60}}));
TORCH_INTERNAL_ASSERT(
map2 ==
(ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{
{1, 10}, {8, 80}, {3, 30}, {4, 40}}));
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(OrderedPreservingDictTest, test_swap_empty) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map = {{1, 10}, {8, 80}, {3, 30}};
ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map2;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map =
{{1, 10}, {8, 80}, {3, 30}};
ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map2;
using std::swap;
swap(map, map2);
using std::swap;
swap(map, map2);
TORCH_INTERNAL_ASSERT(map == (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{}));
TORCH_INTERNAL_ASSERT(map2 == (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{{1, 10}, {8, 80}, {3, 30}}));
TORCH_INTERNAL_ASSERT(
map ==
(ska_ordered::
order_preserving_flat_hash_map<std::int64_t, std::int64_t>{}));
TORCH_INTERNAL_ASSERT(
map2 ==
(ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{
{1, 10}, {8, 80}, {3, 30}}));
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
map.insert({6, 60});
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
map2.insert({4, 40});
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
map.insert({6, 60});
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
map2.insert({4, 40});
TORCH_INTERNAL_ASSERT(map == (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{{6, 60}}));
TORCH_INTERNAL_ASSERT(map2 == (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{{1, 10}, {8, 80}, {3, 30}, {4, 40}}));
TORCH_INTERNAL_ASSERT(
map ==
(ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{
{6, 60}}));
TORCH_INTERNAL_ASSERT(
map2 ==
(ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{
{1, 10}, {8, 80}, {3, 30}, {4, 40}}));
}
}
} // namespace
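
Collectively, these cases pin down the property that gives order_preserving_flat_hash_map its name: iteration visits entries in insertion order, and erasing keys leaves the relative order of the survivors untouched. A compact sketch of that contract, reusing only names and headers this test already relies on (the function is illustrative):

// Sketch: iteration order equals insertion order, and it survives erases.
void order_contract_sketch() {
  ska_ordered::order_preserving_flat_hash_map<int64_t, int64_t> m;
  m[3] = 30;
  m[1] = 10;
  m[2] = 20;
  m.erase(1);
  const int64_t expected_keys[] = {3, 2}; // relative order unchanged by erase
  size_t idx = 0;
  for (const auto& entry : m) {
    TORCH_INTERNAL_ASSERT(entry.first == expected_keys[idx]);
    TORCH_INTERNAL_ASSERT(entry.second == expected_keys[idx] * 10);
    ++idx;
  }
  TORCH_INTERNAL_ASSERT(idx == 2);
}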

View File

@ -7,9 +7,9 @@ using c10::string_view;
namespace {
namespace testutils {
constexpr bool string_equal(const char* lhs, const char* rhs, size_t size) {
return (size == 0)
? true
: (*lhs != *rhs) ? false : string_equal(lhs + 1, rhs + 1, size - 1);
return (size == 0) ? true
: (*lhs != *rhs) ? false
: string_equal(lhs + 1, rhs + 1, size - 1);
}
static_assert(string_equal("hi", "hi", 2), "");
static_assert(string_equal("", "", 0), "");
@ -124,7 +124,8 @@ TEST(StringViewTest, testCopyAssignment) {
static_assert(5 == (string_view() = "hello").size(), "");
static_assert(
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
string_equal("hello", (string_view() = "hello").data(), 5), "");
string_equal("hello", (string_view() = "hello").data(), 5),
"");
}
#endif
const string_view hello = assign("hello");
@ -233,7 +234,7 @@ static_assert('o' == string_view("hello").back(), "");
namespace test_data {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
static_assert(string_equal("hello", string_view("hello").data(), 5), "");
}
} // namespace test_data
namespace test_size_length {
static_assert(0 == string_view("").size(), "");

View File

@ -1,7 +1,7 @@
#include <c10/util/tempfile.h>
#include <gtest/gtest.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/types.h>
#if !defined(_WIN32)
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)

View File

@ -8,7 +8,7 @@ namespace {
class TypeMetaTestFoo {};
class TypeMetaTestBar {};
}
} // namespace
CAFFE_KNOWN_TYPE(TypeMetaTestFoo);
CAFFE_KNOWN_TYPE(TypeMetaTestBar);
@ -73,7 +73,6 @@ TEST(TypeMetaTest, TypeMeta) {
EXPECT_NE(bar_meta.name().find("TypeMetaTestBar"), c10::string_view::npos);
}
class ClassAllowAssignment {
public:
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
@ -92,7 +91,7 @@ class ClassNoAssignment {
ClassNoAssignment& operator=(const ClassNoAssignment& src) = delete;
int x;
};
}
} // namespace
CAFFE_KNOWN_TYPE(ClassAllowAssignment);
CAFFE_KNOWN_TYPE(ClassNoAssignment);
@ -135,5 +134,5 @@ TEST(TypeMetaTest, Float16IsNotUint16) {
EXPECT_NE(TypeMeta::Id<uint16_t>(), TypeMeta::Id<at::Half>());
}
} // namespace
} // namespace caffe2
} // namespace
} // namespace caffe2

View File

@ -1,18 +1,20 @@
/**
* This file is based on the std::array implementation of libstdc++ at
* https://gcc.gnu.org/onlinedocs/gcc-7.1.0/libstdc++/api/a01056_source.html
*
* Changes:
* - isolate, i.e. remove dependencies on internal libstdc++ stuff
* - use c++17 behavior even in c++11 or c++14
* - remove std::swappable special case because that doesn't work with MSVC
* - constexpr more things
* - add some features like prepend/tail
*
* If using std::array at runtime, feel free to either keep using std::array or use this one - it doesn't really matter.
* For compile time computations, this one here is preferred because std::array in C++11
* misses some constexpr specifiers, forcing these methods to be called at runtime instead of compile time.
*/
* This file is based on the std::array implementation of libstdc++ at
* https://gcc.gnu.org/onlinedocs/gcc-7.1.0/libstdc++/api/a01056_source.html
*
* Changes:
* - isolate, i.e. remove dependencies on internal libstdc++ stuff
* - use c++17 behavior even in c++11 or c++14
* - remove std::swappable special case because that doesn't work with MSVC
* - constexpr more things
* - add some features like prepend/tail
*
* If using std::array at runtime, feel free to either keep using std::array or
* use this one - it doesn't really matter. For compile time computations, this
* one here is preferred because std::array in C++11 misses some constexpr
* specifiers, forcing these methods to be called at runtime instead of compile
* time.
*/
// Copyright (C) 2007-2017 Free Software Foundation, Inc.
//
@ -38,16 +40,17 @@
#pragma once
#include <algorithm>
#include <c10/util/C++17.h>
#include <algorithm>
#include <stdexcept>
#include <string>
#include <utility>
namespace c10 { namespace guts {
namespace c10 {
namespace guts {
namespace detail {
template<typename _Tp, std::size_t _Nm>
template <typename _Tp, std::size_t _Nm>
struct __array_traits final {
using _Type = _Tp[_Nm];
@ -60,7 +63,7 @@ struct __array_traits final {
}
};
template<typename _Tp>
template <typename _Tp>
struct __array_traits<_Tp, 0> final {
struct _Type final {};
@ -76,11 +79,11 @@ struct __array_traits<_Tp, 0> final {
[[noreturn]] inline void __throw_out_of_range(std::string msg) {
throw std::out_of_range(std::move(msg));
}
}
} // namespace detail
template<typename _Tp, std::size_t _Nm>
template <typename _Tp, std::size_t _Nm>
class array final {
public:
public:
using value_type = _Tp;
using pointer = value_type*;
using const_pointer = const value_type*;
@ -93,76 +96,99 @@ public:
using reverse_iterator = std::reverse_iterator<iterator>;
using const_reverse_iterator = std::reverse_iterator<const_iterator>;
private:
private:
using _AT_Type = detail::__array_traits<_Tp, _Nm>;
public: // needs to be public member for aggregate initialization
public: // needs to be public member for aggregate initialization
typename _AT_Type::_Type _M_elems;
public:
public:
// No explicit construct/copy/destroy for aggregate type.
// DR 776.
constexpr void fill(const value_type& __u)
{ std::fill_n(begin(), size(), __u); }
constexpr void fill(const value_type& __u) {
std::fill_n(begin(), size(), __u);
}
constexpr void swap(array& __other)
{ std::swap_ranges(begin(), end(), __other.begin()); }
constexpr void swap(array& __other) {
std::swap_ranges(begin(), end(), __other.begin());
}
// Iterators.
constexpr iterator begin() noexcept
{ return iterator(data()); }
constexpr iterator begin() noexcept {
return iterator(data());
}
constexpr const_iterator begin() const noexcept
{ return const_iterator(data()); }
constexpr const_iterator begin() const noexcept {
return const_iterator(data());
}
constexpr iterator end() noexcept
{ return iterator(data() + _Nm); }
constexpr iterator end() noexcept {
return iterator(data() + _Nm);
}
constexpr const_iterator end() const noexcept
{ return const_iterator(data() + _Nm); }
constexpr const_iterator end() const noexcept {
return const_iterator(data() + _Nm);
}
constexpr reverse_iterator rbegin() noexcept
{ return reverse_iterator(end()); }
constexpr reverse_iterator rbegin() noexcept {
return reverse_iterator(end());
}
constexpr const_reverse_iterator rbegin() const noexcept
{ return const_reverse_iterator(end()); }
constexpr const_reverse_iterator rbegin() const noexcept {
return const_reverse_iterator(end());
}
constexpr reverse_iterator rend() noexcept
{ return reverse_iterator(begin()); }
constexpr reverse_iterator rend() noexcept {
return reverse_iterator(begin());
}
constexpr const_reverse_iterator rend() const noexcept
{ return const_reverse_iterator(begin()); }
constexpr const_reverse_iterator rend() const noexcept {
return const_reverse_iterator(begin());
}
constexpr const_iterator cbegin() const noexcept
{ return const_iterator(data()); }
constexpr const_iterator cbegin() const noexcept {
return const_iterator(data());
}
constexpr const_iterator cend() const noexcept
{ return const_iterator(data() + _Nm); }
constexpr const_iterator cend() const noexcept {
return const_iterator(data() + _Nm);
}
constexpr const_reverse_iterator crbegin() const noexcept
{ return const_reverse_iterator(end()); }
constexpr const_reverse_iterator crbegin() const noexcept {
return const_reverse_iterator(end());
}
constexpr const_reverse_iterator crend() const noexcept
{ return const_reverse_iterator(begin()); }
constexpr const_reverse_iterator crend() const noexcept {
return const_reverse_iterator(begin());
}
// Capacity.
constexpr size_type size() const noexcept { return _Nm; }
constexpr size_type size() const noexcept {
return _Nm;
}
constexpr size_type max_size() const noexcept { return _Nm; }
constexpr size_type max_size() const noexcept {
return _Nm;
}
constexpr bool empty() const noexcept { return size() == 0; }
constexpr bool empty() const noexcept {
return size() == 0;
}
// Element access.
constexpr reference operator[](size_type __n) noexcept
{ return _AT_Type::_S_ref(_M_elems, __n); }
constexpr reference operator[](size_type __n) noexcept {
return _AT_Type::_S_ref(_M_elems, __n);
}
constexpr const_reference operator[](size_type __n) const noexcept
{ return _AT_Type::_S_ref(_M_elems, __n); }
constexpr const_reference operator[](size_type __n) const noexcept {
return _AT_Type::_S_ref(_M_elems, __n);
}
constexpr reference at(size_type __n) {
if (__n >= _Nm) {
detail::__throw_out_of_range(std::string() +
"array::at: __n (which is " + to_string(__n) + ") " +
detail::__throw_out_of_range(
std::string() + "array::at: __n (which is " + to_string(__n) + ") " +
">= _Nm (which is " + to_string(_Nm) + ")");
}
return _AT_Type::_S_ref(_M_elems, __n);
@ -171,101 +197,133 @@ public:
constexpr const_reference at(size_type __n) const {
// Result of conditional expression must be an lvalue so use
// boolean ? lvalue : (throw-expr, lvalue)
return __n < _Nm ? _AT_Type::_S_ref(_M_elems, __n)
: (detail::__throw_out_of_range(std::string() +
"array::at: __n (which is " + to_string(__n) + ") " +
">= _Nm (which is " + to_string(_Nm) + ")"),
_AT_Type::_S_ref(_M_elems, 0));
}
constexpr reference front() noexcept
{ return *begin(); }
constexpr const_reference front() const noexcept
{ return _AT_Type::_S_ref(_M_elems, 0); }
constexpr reference back() noexcept
{ return _Nm ? *(end() - 1) : *end(); }
constexpr const_reference back() const noexcept
{
return _Nm ? _AT_Type::_S_ref(_M_elems, _Nm - 1)
: _AT_Type::_S_ref(_M_elems, 0);
return __n < _Nm
? _AT_Type::_S_ref(_M_elems, __n)
: (detail::__throw_out_of_range(
std::string() + "array::at: __n (which is " + to_string(__n) +
") " + ">= _Nm (which is " + to_string(_Nm) + ")"),
_AT_Type::_S_ref(_M_elems, 0));
}
constexpr pointer data() noexcept
{ return _AT_Type::_S_ptr(_M_elems); }
constexpr reference front() noexcept {
return *begin();
}
constexpr const_pointer data() const noexcept
{ return _AT_Type::_S_ptr(_M_elems); }
constexpr const_reference front() const noexcept {
return _AT_Type::_S_ref(_M_elems, 0);
}
constexpr reference back() noexcept {
return _Nm ? *(end() - 1) : *end();
}
constexpr const_reference back() const noexcept {
return _Nm ? _AT_Type::_S_ref(_M_elems, _Nm - 1)
: _AT_Type::_S_ref(_M_elems, 0);
}
constexpr pointer data() noexcept {
return _AT_Type::_S_ptr(_M_elems);
}
constexpr const_pointer data() const noexcept {
return _AT_Type::_S_ptr(_M_elems);
}
};
#if defined(__cpp_deduction_guides) && __cpp_deduction_guides >= 201606
template<typename _Tp, typename... _Up>
array(_Tp, _Up...) ->
array<std::enable_if_t<(std::is_same<_Tp, _Up>::value && ...), _Tp>, 1 + sizeof...(_Up)>;
template <typename _Tp, typename... _Up>
array(_Tp, _Up...) -> array<
std::enable_if_t<(std::is_same<_Tp, _Up>::value && ...), _Tp>,
1 + sizeof...(_Up)>;
#endif
// Array comparisons.
namespace detail {
template<class T, size_t N>
constexpr inline bool array_equals_(const array<T, N>& lhs, const array<T, N>& rhs, size_t current_index) {
template <class T, size_t N>
constexpr inline bool array_equals_(
const array<T, N>& lhs,
const array<T, N>& rhs,
size_t current_index) {
return (current_index == N)
? true
: (lhs.at(current_index) == rhs.at(current_index) && array_equals_(lhs, rhs, current_index + 1));
? true
: (lhs.at(current_index) == rhs.at(current_index) &&
array_equals_(lhs, rhs, current_index + 1));
}
template<class T, size_t N>
constexpr inline bool array_less_(const array<T, N>& lhs, const array<T, N>& rhs, size_t current_index) {
template <class T, size_t N>
constexpr inline bool array_less_(
const array<T, N>& lhs,
const array<T, N>& rhs,
size_t current_index) {
return (current_index == N)
? false
: (lhs.at(current_index) < rhs.at(current_index) || array_less_(lhs, rhs, current_index + 1));
? false
: (lhs.at(current_index) < rhs.at(current_index) ||
array_less_(lhs, rhs, current_index + 1));
}
} // namespace detail
template <typename _Tp, std::size_t _Nm>
constexpr inline bool operator==(
const array<_Tp, _Nm>& __one,
const array<_Tp, _Nm>& __two) {
return detail::array_equals_(__one, __two, 0);
}
template<typename _Tp, std::size_t _Nm>
constexpr inline bool operator==(const array<_Tp, _Nm>& __one, const array<_Tp, _Nm>& __two)
{ return detail::array_equals_(__one, __two, 0); }
template<typename _Tp, std::size_t _Nm>
constexpr inline bool operator!=(const array<_Tp, _Nm>& __one, const array<_Tp, _Nm>& __two)
{ return !(__one == __two); }
template <typename _Tp, std::size_t _Nm>
constexpr inline bool operator!=(
const array<_Tp, _Nm>& __one,
const array<_Tp, _Nm>& __two) {
return !(__one == __two);
}
template<typename _Tp, std::size_t _Nm>
constexpr inline bool operator<(const array<_Tp, _Nm>& __a, const array<_Tp, _Nm>& __b)
{ return detail::array_less_(__a, __b, 0); }
template <typename _Tp, std::size_t _Nm>
constexpr inline bool operator<(
const array<_Tp, _Nm>& __a,
const array<_Tp, _Nm>& __b) {
return detail::array_less_(__a, __b, 0);
}
template<typename _Tp, std::size_t _Nm>
constexpr inline bool operator>(const array<_Tp, _Nm>& __one, const array<_Tp, _Nm>& __two)
{ return __two < __one; }
template <typename _Tp, std::size_t _Nm>
constexpr inline bool operator>(
const array<_Tp, _Nm>& __one,
const array<_Tp, _Nm>& __two) {
return __two < __one;
}
template<typename _Tp, std::size_t _Nm>
constexpr inline bool operator<=(const array<_Tp, _Nm>& __one, const array<_Tp, _Nm>& __two)
{ return !(__one > __two); }
template <typename _Tp, std::size_t _Nm>
constexpr inline bool operator<=(
const array<_Tp, _Nm>& __one,
const array<_Tp, _Nm>& __two) {
return !(__one > __two);
}
template<typename _Tp, std::size_t _Nm>
constexpr inline bool operator>=(const array<_Tp, _Nm>& __one, const array<_Tp, _Nm>& __two)
{ return !(__one < __two); }
template <typename _Tp, std::size_t _Nm>
constexpr inline bool operator>=(
const array<_Tp, _Nm>& __one,
const array<_Tp, _Nm>& __two) {
return !(__one < __two);
}
// Specialized algorithms.
template<typename _Tp, std::size_t _Nm>
inline void swap(array<_Tp, _Nm>& __one, array<_Tp, _Nm>& __two) noexcept(noexcept(__one.swap(__two)))
{ __one.swap(__two); }
template<std::size_t _Int, typename _Tp, std::size_t _Nm>
constexpr _Tp& get(array<_Tp, _Nm>& __arr) noexcept {
static_assert(_Int < _Nm, "array index is within bounds");
return detail::__array_traits<_Tp, _Nm>::_S_ref(__arr._M_elems, _Int);
template <typename _Tp, std::size_t _Nm>
inline void swap(array<_Tp, _Nm>& __one, array<_Tp, _Nm>& __two) noexcept(
noexcept(__one.swap(__two))) {
__one.swap(__two);
}
template<std::size_t _Int, typename _Tp, std::size_t _Nm>
constexpr _Tp&& get(array<_Tp, _Nm>&& __arr) noexcept
{
template <std::size_t _Int, typename _Tp, std::size_t _Nm>
constexpr _Tp& get(array<_Tp, _Nm>& __arr) noexcept {
static_assert(_Int < _Nm, "array index is within bounds");
return detail::__array_traits<_Tp, _Nm>::_S_ref(__arr._M_elems, _Int);
}
template <std::size_t _Int, typename _Tp, std::size_t _Nm>
constexpr _Tp&& get(array<_Tp, _Nm>&& __arr) noexcept {
static_assert(_Int < _Nm, "array index is within bounds");
return std::move(get<_Int>(__arr));
}
template<std::size_t _Int, typename _Tp, std::size_t _Nm>
constexpr const _Tp& get(const array<_Tp, _Nm>& __arr) noexcept
{
template <std::size_t _Int, typename _Tp, std::size_t _Nm>
constexpr const _Tp& get(const array<_Tp, _Nm>& __arr) noexcept {
static_assert(_Int < _Nm, "array index is within bounds");
return detail::__array_traits<_Tp, _Nm>::_S_ref(__arr._M_elems, _Int);
}
@ -278,27 +336,34 @@ constexpr const _Tp& get(const array<_Tp, _Nm>& __arr) noexcept
* prepend(2, {3, 4}) == {2, 3, 4}
*/
namespace detail {
template<class T, size_t N, size_t... INDEX>
constexpr inline array<T, N-1> tail_(const array<T, N>& arg, std::index_sequence<INDEX...>) {
static_assert(sizeof...(INDEX) == N-1, "invariant");
return {{get<INDEX+1>(arg)...}};
template <class T, size_t N, size_t... INDEX>
constexpr inline array<T, N - 1> tail_(
const array<T, N>& arg,
std::index_sequence<INDEX...>) {
static_assert(sizeof...(INDEX) == N - 1, "invariant");
return {{get<INDEX + 1>(arg)...}};
}
}
template<class T, size_t N>
constexpr inline array<T, N-1> tail(const array<T, N>& arg) {
static_assert(N > 0, "Can only call tail() on an array with at least one element");
return detail::tail_(arg, std::make_index_sequence<N-1>());
} // namespace detail
template <class T, size_t N>
constexpr inline array<T, N - 1> tail(const array<T, N>& arg) {
static_assert(
N > 0, "Can only call tail() on an array with at least one element");
return detail::tail_(arg, std::make_index_sequence<N - 1>());
}
namespace detail {
template<class T, size_t N, size_t... INDEX>
constexpr inline array<T, N+1> prepend_(T&& head, const array<T, N>& tail, std::index_sequence<INDEX...>) {
template <class T, size_t N, size_t... INDEX>
constexpr inline array<T, N + 1> prepend_(
T&& head,
const array<T, N>& tail,
std::index_sequence<INDEX...>) {
return {{std::forward<T>(head), get<INDEX>(tail)...}};
}
}
template<class T, size_t N>
constexpr inline array<T, N+1> prepend(T&& head, const array<T, N>& tail) {
return detail::prepend_(std::forward<T>(head), tail, std::make_index_sequence<N>());
} // namespace detail
template <class T, size_t N>
constexpr inline array<T, N + 1> prepend(T&& head, const array<T, N>& tail) {
return detail::prepend_(
std::forward<T>(head), tail, std::make_index_sequence<N>());
}
/**
@ -309,15 +374,18 @@ constexpr inline array<T, N+1> prepend(T&& head, const array<T, N>& tail) {
*/
namespace detail {
template<class T, size_t N, size_t... INDEX>
constexpr array<T, N> to_array_(const T (&arr)[N], std::index_sequence<INDEX...>) {
template <class T, size_t N, size_t... INDEX>
constexpr array<T, N> to_array_(
const T (&arr)[N],
std::index_sequence<INDEX...>) {
return {{arr[INDEX]...}};
}
}
} // namespace detail
template<class T, size_t N>
template <class T, size_t N>
constexpr array<T, N> to_array(const T (&arr)[N]) {
return detail::to_array_(arr, std::make_index_sequence<N>());
}
}}
} // namespace guts
} // namespace c10
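
As the header comment explains, the point of this re-implementation is that everything stays constexpr, so helpers like tail, prepend, and to_array can run entirely at compile time. A small usage sketch under that assumption (the example namespace is hypothetical, added only for illustration):

namespace c10 {
namespace guts {
namespace example { // hypothetical namespace, for illustration only

// Every declaration below is evaluated at compile time; with the C++11
// std::array (which misses several constexpr specifiers) these would have to
// run at runtime instead.
constexpr auto arr = to_array({1, 2, 3}); // array<int, 3>{{1, 2, 3}}
constexpr auto rest = tail(arr); // array<int, 2>{{2, 3}}
constexpr auto again = prepend(1, rest); // array<int, 3>{{1, 2, 3}}
static_assert(again == arr, "prepend undoes tail");
static_assert(get<1>(rest) == 3, "tail dropped the first element");

} // namespace example
} // namespace guts
} // namespace c10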

View File

@ -15,10 +15,10 @@
#pragma once
#include <c10/util/SmallVector.h>
#include <c10/util/C++17.h>
#include <c10/util/Exception.h>
#include <c10/util/Deprecated.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <array>
#include <iterator>
@ -81,12 +81,15 @@ class ArrayRef final {
: Data(Vec.data()), Length(Vec.size()) {}
/// Construct an ArrayRef from a std::vector.
// The enable_if stuff here makes sure that this isn't used for std::vector<bool>,
// because ArrayRef can't work on a std::vector<bool> bitfield.
// The enable_if stuff here makes sure that this isn't used for
// std::vector<bool>, because ArrayRef can't work on a std::vector<bool>
// bitfield.
template <typename A>
/* implicit */ ArrayRef(const std::vector<T, A>& Vec)
: Data(Vec.data()), Length(Vec.size()) {
static_assert(!std::is_same<T, bool>::value, "ArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.");
static_assert(
!std::is_same<T, bool>::value,
"ArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.");
}
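// Usage note (sketch): this is why an ArrayRef<int> can wrap a
// std::vector<int> (contiguous int* storage via data()), while an
// ArrayRef<bool> over a std::vector<bool> is rejected at compile time --
// vector<bool> is a packed bitfield with no bool* for Data to point at.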
/// Construct an ArrayRef from a std::array
@ -100,7 +103,9 @@ class ArrayRef final {
/// Construct an ArrayRef from a std::initializer_list.
/* implicit */ constexpr ArrayRef(const std::initializer_list<T>& Vec)
: Data(std::begin(Vec) == std::end(Vec) ? static_cast<T*>(nullptr) : std::begin(Vec)),
: Data(
std::begin(Vec) == std::end(Vec) ? static_cast<T*>(nullptr)
: std::begin(Vec)),
Length(Vec.size()) {}
/// @}
@ -146,7 +151,8 @@ class ArrayRef final {
/// front - Get the first element.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& front() const {
TORCH_CHECK(!empty(), "ArrayRef: attempted to access front() of empty list");
TORCH_CHECK(
!empty(), "ArrayRef: attempted to access front() of empty list");
return Data[0];
}
@ -162,7 +168,8 @@ class ArrayRef final {
}
/// slice(n, m) - Take M elements of the array starting at element N
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef<T> slice(size_t N, size_t M) const {
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef<T> slice(size_t N, size_t M)
const {
TORCH_CHECK(
N + M <= size(),
"ArrayRef: invalid slice, N = ",
@ -224,10 +231,10 @@ class ArrayRef final {
};
template <typename T>
std::ostream& operator<<(std::ostream & out, ArrayRef<T> list) {
std::ostream& operator<<(std::ostream& out, ArrayRef<T> list) {
int i = 0;
out << "[";
for(auto e : list) {
for (auto e : list) {
if (i++ > 0)
out << ", ";
out << e;
