diff --git a/c10/benchmark/intrusive_ptr_benchmark.cpp b/c10/benchmark/intrusive_ptr_benchmark.cpp
index 4336c4e1094..75237ecdedd 100644
--- a/c10/benchmark/intrusive_ptr_benchmark.cpp
+++ b/c10/benchmark/intrusive_ptr_benchmark.cpp
@@ -18,7 +18,6 @@ class Foo : public intrusive_ptr_target {
   int param;
 };
 
-
 class Bar : public std::enable_shared_from_this<Bar> {
  public:
   Bar(int param_) : param(param_) {}
@@ -48,7 +47,7 @@ BENCHMARK(BM_SharedPtrCtorDtor);
 static void BM_IntrusivePtrArray(benchmark::State& state) {
   intrusive_ptr<Foo> var = make_intrusive<Foo>(0);
   const size_t kLength = state.range(0);
-  std::vector<intrusive_ptr<Foo> > vararray(kLength);
+  std::vector<intrusive_ptr<Foo>> vararray(kLength);
   while (state.KeepRunning()) {
     for (const auto i : c10::irange(kLength)) {
       vararray[i] = var;
@@ -64,7 +63,7 @@ BENCHMARK(BM_IntrusivePtrArray)->RangeMultiplier(2)->Range(16, 4096);
 static void BM_SharedPtrArray(benchmark::State& state) {
   std::shared_ptr<Bar> var = std::make_shared<Bar>(0);
   const size_t kLength = state.range(0);
-  std::vector<std::shared_ptr<Bar> > vararray(kLength);
+  std::vector<std::shared_ptr<Bar>> vararray(kLength);
   while (state.KeepRunning()) {
     for (const auto i : c10::irange(kLength)) {
       vararray[i] = var;
@@ -78,5 +77,4 @@ static void BM_SharedPtrArray(benchmark::State& state) {
 BENCHMARK(BM_SharedPtrArray)->RangeMultiplier(2)->Range(16, 4096);
 
 } // namespace
-
 BENCHMARK_MAIN();
diff --git a/c10/core/Allocator.cpp b/c10/core/Allocator.cpp
index 406fa1a0cc8..8ee23dc7374 100644
--- a/c10/core/Allocator.cpp
+++ b/c10/core/Allocator.cpp
@@ -12,10 +12,11 @@ at::DataPtr InefficientStdFunctionContext::makeDataPtr(
     void* ptr,
     const std::function<void(void*)>& deleter,
     Device device) {
-  return {ptr,
-          new InefficientStdFunctionContext({ptr, deleter}),
-          &deleteInefficientStdFunctionContext,
-          device};
+  return {
+      ptr,
+      new InefficientStdFunctionContext({ptr, deleter}),
+      &deleteInefficientStdFunctionContext,
+      device};
 }
 
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
diff --git a/c10/core/Allocator.h b/c10/core/Allocator.h
index 1554f8dbd6a..338c70980fa 100644
--- a/c10/core/Allocator.h
+++ b/c10/core/Allocator.h
@@ -94,7 +94,9 @@ class C10_API DataPtr {
    * be; be sure to read the source code of the Allocator
    * in question to confirm this.
   */
-  C10_NODISCARD bool compare_exchange_deleter(DeleterFnPtr expected_deleter, DeleterFnPtr new_deleter) {
+  C10_NODISCARD bool compare_exchange_deleter(
+      DeleterFnPtr expected_deleter,
+      DeleterFnPtr new_deleter) {
     return ptr_.compare_exchange_deleter(expected_deleter, new_deleter);
   }
   Device device() const {
@@ -215,8 +217,8 @@ struct AllocatorRegisterer {
   }
 };
 
-#define REGISTER_ALLOCATOR(t, f) \
-  namespace { \
+#define REGISTER_ALLOCATOR(t, f)                  \
+  namespace {                                     \
   static AllocatorRegisterer<t> g_allocator_d(f); \
   }
 
@@ -227,12 +229,18 @@ struct C10_API MemoryReportingInfoBase : public c10::DebugInfoBase {
   virtual ~MemoryReportingInfoBase() {}
 
   // Negative alloc_size corresponds to freeing of the memory
-  virtual void reportMemoryUsage(void* ptr, int64_t alloc_size, Device device) = 0;
+  virtual void reportMemoryUsage(
+      void* ptr,
+      int64_t alloc_size,
+      Device device) = 0;
 
   virtual bool memoryProfilingEnabled() const = 0;
 };
 
 C10_API bool memoryProfilingEnabled();
-C10_API void reportMemoryUsageToProfiler(void* ptr, int64_t alloc_size, Device device);
+C10_API void reportMemoryUsageToProfiler(
+    void* ptr,
+    int64_t alloc_size,
+    Device device);
 
 } // namespace c10
diff --git a/c10/core/Backend.h b/c10/core/Backend.h
index d4228200d1e..4f8274123b7 100644
--- a/c10/core/Backend.h
+++ b/c10/core/Backend.h
@@ -253,7 +253,7 @@ static inline bool isSparse(Backend b) {
 }
 
 static inline bool isSparseCsr(Backend b) {
-  switch(b) {
+  switch (b) {
     case Backend::SparseCsrCPU:
     case Backend::SparseCsrCUDA:
       return true;
diff --git a/c10/core/CompileTimeFunctionPointer.h b/c10/core/CompileTimeFunctionPointer.h
index d9424329469..6314e3e7708 100644
--- a/c10/core/CompileTimeFunctionPointer.h
+++ b/c10/core/CompileTimeFunctionPointer.h
@@ -29,9 +29,11 @@ namespace c10 {
  * }
  * EXPECT_EQ(3, execute(TORCH_FN(add), 1, 2));
  */
-template<class FuncType_, FuncType_* func_ptr_>
+template <class FuncType_, FuncType_* func_ptr_>
 struct CompileTimeFunctionPointer final {
-  static_assert(guts::is_function_type<FuncType_>::value, "TORCH_FN can only wrap function types.");
+  static_assert(
+      guts::is_function_type<FuncType_>::value,
+      "TORCH_FN can only wrap function types.");
   using FuncType = FuncType_;
 
   static constexpr FuncType* func_ptr() {
@@ -39,11 +41,16 @@ struct CompileTimeFunctionPointer final {
   }
 };
 
-template<class T> struct is_compile_time_function_pointer : std::false_type {};
-template<class FuncType, FuncType* func_ptr>
-struct is_compile_time_function_pointer<CompileTimeFunctionPointer<FuncType, func_ptr>> : std::true_type {};
+template <class T>
+struct is_compile_time_function_pointer : std::false_type {};
+template <class FuncType, FuncType* func_ptr>
+struct is_compile_time_function_pointer<
+    CompileTimeFunctionPointer<FuncType, func_ptr>> : std::true_type {};
 
-}
+} // namespace c10
 
-#define TORCH_FN_TYPE(func) ::c10::CompileTimeFunctionPointer<std::remove_pointer_t<std::remove_reference_t<decltype(func)>>, func>
+#define TORCH_FN_TYPE(func)                                           \
+  ::c10::CompileTimeFunctionPointer<                                  \
+      std::remove_pointer_t<std::remove_reference_t<decltype(func)>>, \
+      func>
 #define TORCH_FN(func) TORCH_FN_TYPE(func)()
diff --git a/c10/core/CopyBytes.cpp b/c10/core/CopyBytes.cpp
index 520391c6852..b652c35f4ca 100644
--- a/c10/core/CopyBytes.cpp
+++ b/c10/core/CopyBytes.cpp
@@ -47,4 +47,4 @@ void CopyBytes(
   ptr(nbytes, src, src_device, dst, dst_device);
 }
 
-}
+} // namespace c10
diff --git a/c10/core/DefaultDtype.cpp b/c10/core/DefaultDtype.cpp
index 3cfbf80e82f..ad76a551250 100644
--- a/c10/core/DefaultDtype.cpp
+++ b/c10/core/DefaultDtype.cpp
@@ -1,5 +1,5 @@
-#include
 #include
+#include
 
 namespace c10 {
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 static auto default_dtype = caffe2::TypeMeta::Make<float>();
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 static auto default_dtype_as_scalartype = default_dtype.toScalarType();
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -static auto default_complex_dtype = caffe2::TypeMeta::Make>(); +static auto default_complex_dtype = + caffe2::TypeMeta::Make>(); void set_default_dtype(caffe2::TypeMeta dtype) { default_dtype = dtype; diff --git a/c10/core/DefaultDtype.h b/c10/core/DefaultDtype.h index d0a17474bda..f2f95c0da15 100644 --- a/c10/core/DefaultDtype.h +++ b/c10/core/DefaultDtype.h @@ -1,7 +1,7 @@ #pragma once -#include #include +#include namespace caffe2 { class TypeMeta; diff --git a/c10/core/DefaultTensorOptions.h b/c10/core/DefaultTensorOptions.h index 7bf96525c11..36af2657847 100644 --- a/c10/core/DefaultTensorOptions.h +++ b/c10/core/DefaultTensorOptions.h @@ -13,19 +13,27 @@ struct TensorOptions; struct DefaultTensorOptions { DefaultTensorOptions() = default; - caffe2::TypeMeta dtype() const noexcept { return dtype_; } - Device device() const noexcept { return device_; } - Layout layout() const noexcept { return layout_; } - bool requires_grad() const noexcept { return requires_grad_; } + caffe2::TypeMeta dtype() const noexcept { + return dtype_; + } + Device device() const noexcept { + return device_; + } + Layout layout() const noexcept { + return layout_; + } + bool requires_grad() const noexcept { + return requires_grad_; + } // Defined in TensorOptions.h inline DefaultTensorOptions& merge(const TensorOptions& options); private: caffe2::TypeMeta dtype_ = caffe2::TypeMeta::Make(); // 64-bit - Device device_ = at::kCPU; // 32-bit - Layout layout_ = at::kStrided; // 8-bit - bool requires_grad_ = false; // 8-bit + Device device_ = at::kCPU; // 32-bit + Layout layout_ = at::kStrided; // 8-bit + bool requires_grad_ = false; // 8-bit }; inline const DefaultTensorOptions& getDefaultTensorOptions() { diff --git a/c10/core/Device.cpp b/c10/core/Device.cpp index 4ce305aee62..e537be23d32 100644 --- a/c10/core/Device.cpp +++ b/c10/core/Device.cpp @@ -6,25 +6,24 @@ #include #include #include +#include #include #include #include -#include // Check if compiler has working std::regex implementation // // Test below is adapted from https://stackoverflow.com/a/41186162 #if defined(_MSVC_LANG) && _MSVC_LANG >= 201103L - // Compiler has working regex. MSVC has erroneous __cplusplus. +// Compiler has working regex. MSVC has erroneous __cplusplus. #elif __cplusplus >= 201103L && \ (!defined(__GLIBCXX__) || (__cplusplus >= 201402L) || \ - (defined(_GLIBCXX_REGEX_DFS_QUANTIFIERS_LIMIT) || \ - defined(_GLIBCXX_REGEX_STATE_LIMIT) || \ - (defined(_GLIBCXX_RELEASE) && \ - _GLIBCXX_RELEASE > 4))) - // Compiler has working regex. + (defined(_GLIBCXX_REGEX_DFS_QUANTIFIERS_LIMIT) || \ + defined(_GLIBCXX_REGEX_STATE_LIMIT) || \ + (defined(_GLIBCXX_RELEASE) && _GLIBCXX_RELEASE > 4))) +// Compiler has working regex. 
#else - static_assert(false, "Compiler does not have proper regex support."); +static_assert(false, "Compiler does not have proper regex support."); #endif namespace c10 { @@ -58,7 +57,8 @@ DeviceType parse_type(const std::string& device_string) { if (device != types.end()) { return device->second; } - TORCH_CHECK(false, + TORCH_CHECK( + false, "Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, mlc, xla, vulkan, meta device type at start of device string: ", device_string); } @@ -71,16 +71,22 @@ Device::Device(const std::string& device_string) : Device(Type::CPU) { static const std::regex regex("([a-zA-Z_]+)(?::([1-9]\\d*|0))?"); std::smatch match; TORCH_CHECK( - std::regex_match(device_string, match, regex), - "Invalid device string: '", device_string, "'"); + std::regex_match(device_string, match, regex), + "Invalid device string: '", + device_string, + "'"); type_ = parse_type(match[1].str()); if (match[2].matched) { try { index_ = c10::stoi(match[2].str()); - } catch (const std::exception &) { - TORCH_CHECK(false, - "Could not parse device index '", match[2].str(), - "' in device string '", device_string, "'"); + } catch (const std::exception&) { + TORCH_CHECK( + false, + "Could not parse device index '", + match[2].str(), + "' in device string '", + device_string, + "'"); } } validate(); diff --git a/c10/core/Device.h b/c10/core/Device.h index 3ce6cd0ef88..3f904b389e4 100644 --- a/c10/core/Device.h +++ b/c10/core/Device.h @@ -107,16 +107,18 @@ struct C10_API Device final { // performance in micro-benchmarks. // This is safe to do, because backends that use the DeviceIndex // have a later check when we actually try to switch to that device. - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(index_ == -1 || index_ >= 0, - "Device index must be -1 or non-negative, got ", (int)index_); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!is_cpu() || index_ <= 0, - "CPU device index must be -1 or zero, got ", (int)index_); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + index_ == -1 || index_ >= 0, + "Device index must be -1 or non-negative, got ", + (int)index_); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + !is_cpu() || index_ <= 0, + "CPU device index must be -1 or zero, got ", + (int)index_); } }; -C10_API std::ostream& operator<<( - std::ostream& stream, - const Device& device); +C10_API std::ostream& operator<<(std::ostream& stream, const Device& device); } // namespace c10 @@ -136,10 +138,11 @@ struct hash { // half of the resulting integer. // // Technically, by C/C++ integer promotion rules, we only need one of the - // uint32_t casts to the result type, but we put in both for explicitness's sake. - uint32_t bits = - static_cast(static_cast(d.type())) << 16 - | static_cast(static_cast(d.index())); + // uint32_t casts to the result type, but we put in both for explicitness's + // sake. + uint32_t bits = static_cast(static_cast(d.type())) + << 16 | + static_cast(static_cast(d.index())); return std::hash{}(bits); } }; diff --git a/c10/core/DeviceGuard.h b/c10/core/DeviceGuard.h index 852d6366ebd..ed627f03171 100644 --- a/c10/core/DeviceGuard.h +++ b/c10/core/DeviceGuard.h @@ -17,7 +17,7 @@ namespace c10 { /// want to setup a guard (i.e., are looking for the moral equivalent /// of optional), see OptionalDeviceGuard. class DeviceGuard { -public: + public: /// No default constructor; see Note [Omitted default constructor from RAII] explicit DeviceGuard() = delete; @@ -25,7 +25,10 @@ public: explicit DeviceGuard(Device device) : guard_(device) {} /// This constructor is for testing only. 
- explicit DeviceGuard(Device device, const impl::DeviceGuardImplInterface* impl) : guard_(device, impl) {} + explicit DeviceGuard( + Device device, + const impl::DeviceGuardImplInterface* impl) + : guard_(device, impl) {} /// Copy is disallowed DeviceGuard(const DeviceGuard&) = delete; @@ -48,7 +51,9 @@ public: } /// This method is for testing only. - void reset_device(at::Device device, const impl::DeviceGuardImplInterface* impl) { + void reset_device( + at::Device device, + const impl::DeviceGuardImplInterface* impl) { guard_.reset_device(device, impl); } @@ -69,7 +74,7 @@ public: return guard_.current_device(); } -private: + private: impl::InlineDeviceGuard guard_; }; @@ -79,8 +84,8 @@ private: * Morally, a OptionalDeviceGuard is equivalent to optional, but * with extra constructors and methods as appropriate. * - * Besides its obvious use (optionally applying a DeviceGuard), OptionalDeviceGuard - * is often also used for the following idiom: + * Besides its obvious use (optionally applying a DeviceGuard), + * OptionalDeviceGuard is often also used for the following idiom: * * OptionalDeviceGuard g; * for (const auto& t : tensors) { @@ -117,7 +122,7 @@ private: * DeviceGuard will still reset the device to original_device_. */ class OptionalDeviceGuard { -public: + public: /// Create an uninitialized guard. Set the guard later using reset_device. explicit OptionalDeviceGuard() : guard_() {} @@ -129,7 +134,10 @@ public: explicit OptionalDeviceGuard(optional device) : guard_(device) {} /// Constructor for testing only. - explicit OptionalDeviceGuard(Device device, const impl::DeviceGuardImplInterface* impl) : guard_(device, impl) {} + explicit OptionalDeviceGuard( + Device device, + const impl::DeviceGuardImplInterface* impl) + : guard_(device, impl) {} /// Copy is disallowed OptionalDeviceGuard(const OptionalDeviceGuard&) = delete; @@ -149,7 +157,9 @@ public: } /// For testing only - void reset_device(at::Device device, const impl::DeviceGuardImplInterface* impl) { + void reset_device( + at::Device device, + const impl::DeviceGuardImplInterface* impl) { guard_.reset_device(device, impl); } @@ -164,7 +174,7 @@ public: return guard_.current_device(); } -private: + private: impl::InlineOptionalDeviceGuard guard_; }; @@ -173,7 +183,8 @@ private: // Design note: in principle, we could avoid these wrappers using: // // using DeviceGuard = impl::InlineDeviceGuard; -// using OptionalDeviceGuard = impl::InlineOptionalDeviceGuard; +// using OptionalDeviceGuard = +// impl::InlineOptionalDeviceGuard; // // But the error messages are worse, and our users can't just look at the // header file to find out what's going on. Furthermore, for specializations diff --git a/c10/core/DeviceType.cpp b/c10/core/DeviceType.cpp index 3cdb53e9928..6a38cdb87d9 100644 --- a/c10/core/DeviceType.cpp +++ b/c10/core/DeviceType.cpp @@ -38,7 +38,8 @@ std::string DeviceTypeName(DeviceType d, bool lower_case) { case DeviceType::Meta: return lower_case ? "meta" : "META"; default: - TORCH_CHECK(false, + TORCH_CHECK( + false, "Unknown device: ", static_cast(d), ". 
If you have recently updated the caffe2.proto file to add a new " diff --git a/c10/core/DeviceType.h b/c10/core/DeviceType.h index 8ba366abda4..df80aff6c99 100644 --- a/c10/core/DeviceType.h +++ b/c10/core/DeviceType.h @@ -7,8 +7,8 @@ #include -#include #include +#include namespace c10 { @@ -51,7 +51,8 @@ constexpr DeviceType kXPU = DeviceType::XPU; constexpr int COMPILE_TIME_MAX_DEVICE_TYPES = static_cast(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES); -static_assert(COMPILE_TIME_MAX_DEVICE_TYPES <= 16, +static_assert( + COMPILE_TIME_MAX_DEVICE_TYPES <= 16, "Hey! You seem to be adding a lot of new DeviceTypes. The intent was " "for this constant to reflect the actual number of DeviceTypes we support " "in PyTorch; it's important that this number is not too large as we " @@ -61,9 +62,7 @@ static_assert(COMPILE_TIME_MAX_DEVICE_TYPES <= 16, "types registration, please be aware that you are affecting code that " "this number is small. Try auditing uses of this constant."); -C10_API std::string DeviceTypeName( - DeviceType d, - bool lower_case = false); +C10_API std::string DeviceTypeName(DeviceType d, bool lower_case = false); C10_API bool isValidDeviceType(DeviceType d); @@ -72,7 +71,8 @@ C10_API std::ostream& operator<<(std::ostream& stream, DeviceType type); } // namespace c10 namespace std { -template <> struct hash { +template <> +struct hash { std::size_t operator()(c10::DeviceType k) const { return std::hash()(static_cast(k)); } @@ -80,5 +80,5 @@ template <> struct hash { } // namespace std namespace torch { - using c10::DeviceType; +using c10::DeviceType; } diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h index 75b4c36d40b..a82edf1e26e 100644 --- a/c10/core/DispatchKey.h +++ b/c10/core/DispatchKey.h @@ -1,11 +1,11 @@ #pragma once -#include -#include -#include #include #include #include +#include +#include +#include namespace c10 { @@ -58,7 +58,7 @@ enum class DispatchKey : uint8_t { HIP, // NB: I think this is not actually used, due to Note [Masquerading as // CUDA] FPGA, // Xilinx support lives out of tree at - // https://gitlab.com/pytorch-complex/vitis_kernels + // https://gitlab.com/pytorch-complex/vitis_kernels MSNPU, // unused externally, but tested at // test/cpp_extensions/msnpu_extension.cpp XLA, // lives out of tree at https://github.com/pytorch/xla @@ -177,7 +177,8 @@ enum class DispatchKey : uint8_t { // But this work is currently blocked since it adds an extra dispatch // for all ops and it's non-trivial overhead at model level(a few percents). // Thus our current approach takes advantage of the fact every kernel go - // through VariableType kernel first and pulls the `at::AutoDispatchBelowInplaceOrView` guard of functional ops + // through VariableType kernel first and pulls the + // `at::AutoDispatchBelowInplaceOrView` guard of functional ops // up to the `VariableType` kernel. Thus we only add the extra dispatch // to view/inplace ops to minimize its perf impact to real models. 
InplaceOrView, @@ -213,7 +214,8 @@ enum class DispatchKey : uint8_t { AutogradXLA, AutogradXPU, AutogradMLC, - AutogradNestedTensor, // lives out of tree at https://github.com/pytorch/nestedtensor + AutogradNestedTensor, // lives out of tree at + // https://github.com/pytorch/nestedtensor // Here are some reserved pre-autograd keys for user-defined backends, see // Note [Private use DispatchKey] AutogradPrivateUse1, @@ -224,7 +226,7 @@ enum class DispatchKey : uint8_t { // Autocasting precedes VariableTypeId, to ensure casts are autograd-exposed // and inputs are saved for backward in the post-autocast type. - //AutocastCPU, + // AutocastCPU, AutocastCUDA, // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ WRAPPERS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // @@ -278,9 +280,10 @@ enum class DispatchKey : uint8_t { // See Note [Alias Dispatch Key : Autograd] Autograd, - CompositeImplicitAutograd, // registered at build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp + CompositeImplicitAutograd, // registered at + // build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp CompositeExplicitAutograd, // registered at - // build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp + // build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp // Define an alias key to represent end of alias dispatch keys. // If you add new alias keys after Autograd, please also update it here. @@ -316,15 +319,15 @@ enum class DispatchKey : uint8_t { // We provide two classes of private user tensor id: regular DispatchKeys // and Autograd DispatchKeys. DispatchKeys serve the role of ordinary "backend" // DispatchKeys; if you were adding support for a new type of accelerator, you -// would use a backend DispatchKey, and ideally automatically reuse AutogradOther -// definitions already defined in PyTorch. AutogradPrivateUse DispatchKeys serve -// as "wrapper" DispatchKeys: they are only necessary for tensors that compose -// multiple internal tensors, and for cases when the built-in autograd formulas -// for operators are not appropriate. +// would use a backend DispatchKey, and ideally automatically reuse +// AutogradOther definitions already defined in PyTorch. AutogradPrivateUse +// DispatchKeys serve as "wrapper" DispatchKeys: they are only necessary for +// tensors that compose multiple internal tensors, and for cases when the +// built-in autograd formulas for operators are not appropriate. static_assert( - static_cast(DispatchKey::NumDispatchKeys) < 64, - "DispatchKey is used as index into 64-bit bitmask; you must have less than 64 entries"); + static_cast(DispatchKey::NumDispatchKeys) < 64, + "DispatchKey is used as index into 64-bit bitmask; you must have less than 64 entries"); C10_API const char* toString(DispatchKey); C10_API std::ostream& operator<<(std::ostream&, DispatchKey); @@ -345,10 +348,10 @@ inline bool isAliasDispatchKey(DispatchKey k) { } // namespace c10 namespace torch { - // Expose the constant, but not the TYPE (DispatchKey is an implementation - // detail!) - using c10::kAutograd; -} +// Expose the constant, but not the TYPE (DispatchKey is an implementation +// detail!) +using c10::kAutograd; +} // namespace torch // NB: You really shouldn't use this instance; this enum is guaranteed // to be pretty small so a regular array should be acceptable. 
@@ -362,4 +365,4 @@ struct hash { return static_cast(x); } }; -} +} // namespace std diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp index 0ea9c1cfdd0..0799ebaf5a6 100644 --- a/c10/core/DispatchKeySet.cpp +++ b/c10/core/DispatchKeySet.cpp @@ -3,9 +3,10 @@ namespace c10 { // backend_dispatch_keyset should include all runtime backend keys. -// Alias key DispatchKey::CompositeExplicitAutograd maps to backend_dispatch_keyset -// NestedTensor has been explicitly removed due to incompatibility with some -// kernels, such as structured kernels, that use the DefaultBackend key. +// Alias key DispatchKey::CompositeExplicitAutograd maps to +// backend_dispatch_keyset NestedTensor has been explicitly removed due to +// incompatibility with some kernels, such as structured kernels, that use the +// DefaultBackend key. constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends | DispatchKeySet({ DispatchKey::CPU, @@ -23,9 +24,11 @@ bool isBackendDispatchKey(DispatchKey t) { return t != DispatchKey::Undefined && backend_dispatch_keyset.has(t); } -// math_dispatch_keyset contains all keys in backend_dispatch_keyset and autograd_dispatch_keyset -// Alias key DispatchKey::CompositeImplicitAutograd maps to math_dispatch_keyset. -constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset | autograd_dispatch_keyset; +// math_dispatch_keyset contains all keys in backend_dispatch_keyset and +// autograd_dispatch_keyset Alias key DispatchKey::CompositeImplicitAutograd +// maps to math_dispatch_keyset. +constexpr DispatchKeySet math_dispatch_keyset = + backend_dispatch_keyset | autograd_dispatch_keyset; DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) { TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined); @@ -41,8 +44,8 @@ DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) { } } -// for a given autograd key, return the (guaranteed nonempty) set of associated backend keys. -// for a non-autograd key, return the empty keyset. +// for a given autograd key, return the (guaranteed nonempty) set of associated +// backend keys. for a non-autograd key, return the empty keyset. DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) { switch (t) { case DispatchKey::AutogradCPU: @@ -72,7 +75,7 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) { DispatchKeySet getAutocastRelatedKeySetFromBackend(DispatchKey t) { switch (t) { - //case DispatchKey::CPU: + // case DispatchKey::CPU: // return DispatchKeySet(DispatchKey::AutocastCPU); case DispatchKey::CUDA: return DispatchKeySet(DispatchKey::AutocastCUDA); @@ -82,8 +85,8 @@ DispatchKeySet getAutocastRelatedKeySetFromBackend(DispatchKey t) { } DispatchKeySet getAutogradRelatedKeySetFromBackend(DispatchKey t) { - return DispatchKeySet({ - DispatchKey::InplaceOrView, getAutogradKeyFromBackend(t)}); + return DispatchKeySet( + {DispatchKey::InplaceOrView, getAutogradKeyFromBackend(t)}); } bool isIncludedInAlias(DispatchKey k, DispatchKey alias) { @@ -116,4 +119,4 @@ std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) { return os; } -} +} // namespace c10 diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index ec80e5d61e2..e78ce18aa76 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -1,9 +1,9 @@ #pragma once #include -#include #include #include +#include #include namespace c10 { @@ -33,30 +33,29 @@ namespace c10 { // // An undefined tensor is one with an empty tensor type set. 
class DispatchKeySet final { -public: + public: enum Full { FULL }; enum FullAfter { FULL_AFTER }; enum Raw { RAW }; // NB: default constructor representation as zero is MANDATORY as // use of DispatchKeySet in TLS requires this. - constexpr DispatchKeySet() - : repr_(0) {} + constexpr DispatchKeySet() : repr_(0) {} constexpr DispatchKeySet(Full) - : repr_(std::numeric_limits::max()) {} + : repr_(std::numeric_limits::max()) {} constexpr DispatchKeySet(FullAfter, DispatchKey t) - // LSB after t are OK, but not t itself. - : repr_((1ULL << (static_cast(t) - 1)) - 1) {} + // LSB after t are OK, but not t itself. + : repr_((1ULL << (static_cast(t) - 1)) - 1) {} // Public version of DispatchKeySet(uint64_t) API; external users // must be explicit when they do this! - constexpr DispatchKeySet(Raw, uint64_t x) - : repr_(x) {} + constexpr DispatchKeySet(Raw, uint64_t x) : repr_(x) {} explicit constexpr DispatchKeySet(DispatchKey t) - : repr_(t == DispatchKey::Undefined - ? 0 - : 1ULL << (static_cast(t) - 1)) {} + : repr_( + t == DispatchKey::Undefined + ? 0 + : 1ULL << (static_cast(t) - 1)) {} explicit constexpr DispatchKeySet(std::initializer_list ks) - : repr_(0) { + : repr_(0) { for (auto k : ks) { repr_ |= DispatchKeySet(k).repr_; } @@ -105,7 +104,9 @@ public: bool empty() const { return repr_ == 0; } - uint64_t raw_repr() { return repr_; } + uint64_t raw_repr() { + return repr_; + } // Return the type id in this set with the highest priority (i.e., // is the largest in the DispatchKey enum). Intuitively, this // type id is the one that should handle dispatch (assuming there @@ -119,18 +120,20 @@ public: } DispatchKey highestPriorityBackendTypeId() const { - return (*this & ((1ULL << static_cast(DispatchKey::EndOfBackendKeys)) - 1)) - .highestPriorityTypeId(); + return (*this & + ((1ULL << static_cast(DispatchKey::EndOfBackendKeys)) - 1)) + .highestPriorityTypeId(); } -private: + + private: constexpr DispatchKeySet(uint64_t repr) : repr_(repr) {} uint64_t repr_ = 0; -public: + public: // STL iterator for DispatchKeySet. Iterates through all DispatchKeys in the // set. The iterator is only invalidated by the destruction of the underlying - // DispatchKeySet as the iterator stores a pointer to the raw representation of - // the DispatchKeySet. + // DispatchKeySet as the iterator stores a pointer to the raw representation + // of the DispatchKeySet. class iterator { public: using self_type = iterator; @@ -138,13 +141,15 @@ public: using value_type = DispatchKey; using difference_type = ptrdiff_t; - explicit iterator(const uint64_t *data_ptr, uint8_t i=0) : data_ptr_(data_ptr), i_(i) { + explicit iterator(const uint64_t* data_ptr, uint8_t i = 0) + : data_ptr_(data_ptr), i_(i) { // Go to the first key in the set ++(*this); } self_type& operator++() { - TORCH_INTERNAL_ASSERT(i_ <= static_cast(DispatchKey::NumDispatchKeys)); + TORCH_INTERNAL_ASSERT( + i_ <= static_cast(DispatchKey::NumDispatchKeys)); // Create a masked version of the set representation to ignore previous // keys that we've iterated through. 
@@ -153,7 +158,7 @@ public: // If there are no keys, set to end iterator value if (firstKeyIndex == std::numeric_limits::max() || - i_ == static_cast(DispatchKey::NumDispatchKeys)) { + i_ == static_cast(DispatchKey::NumDispatchKeys)) { i_ = static_cast(DispatchKey::NumDispatchKeys); return *this; } @@ -163,29 +168,38 @@ public: } self_type operator++(int) { - self_type previous_iterator = *this; + self_type previous_iterator = *this; ++(*this); return previous_iterator; } - bool operator==(const self_type& rhs) const { return i_ == rhs.i_; } - bool operator!=(const self_type& rhs) const { return i_ != rhs.i_; } - DispatchKey operator*() const { return static_cast (i_); } + bool operator==(const self_type& rhs) const { + return i_ == rhs.i_; + } + bool operator!=(const self_type& rhs) const { + return i_ != rhs.i_; + } + DispatchKey operator*() const { + return static_cast(i_); + } private: - const uint64_t *data_ptr_; + const uint64_t* data_ptr_; uint8_t i_; }; public: // Returns iterator to the first key in the set. If no keys are in the // set, then will return the end iterator. - iterator begin() const { return iterator(&repr_); } + iterator begin() const { + return iterator(&repr_); + } // We do not need to iterate beyond NumDispatchKeys so we will treat this as // the end iterator. NumDispatchKeys will always be strictly less than 64. - iterator end() const { return iterator(&repr_, static_cast(DispatchKey::NumDispatchKeys)); } - + iterator end() const { + return iterator(&repr_, static_cast(DispatchKey::NumDispatchKeys)); + } }; C10_API std::string toString(DispatchKeySet); @@ -208,7 +222,7 @@ constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({ }); constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({ - //DispatchKey::AutocastCPU, + // DispatchKey::AutocastCPU, DispatchKey::AutocastCUDA, }); @@ -224,40 +238,36 @@ constexpr DispatchKeySet default_excluded_set = DispatchKeySet({ }); constexpr DispatchKeySet autograd_dispatch_keyset_with_InplaceOrView = - autograd_dispatch_keyset | DispatchKeySet(DispatchKey::InplaceOrView); + autograd_dispatch_keyset | DispatchKeySet(DispatchKey::InplaceOrView); // backend dispatch keys that map to DispatchKey::AutogradOther // NB: keys in this set also get associated with CompositeImplicitAutograd -constexpr DispatchKeySet autogradother_backends = DispatchKeySet({ - DispatchKey::HIP, - DispatchKey::FPGA, - DispatchKey::MSNPU, - DispatchKey::Vulkan, - DispatchKey::Metal, - DispatchKey::QuantizedCPU, - DispatchKey::QuantizedCUDA, - DispatchKey::CustomRNGKeyId, - DispatchKey::MkldnnCPU, - DispatchKey::SparseCPU, - DispatchKey::SparseCUDA, - DispatchKey::SparseHIP, - DispatchKey::SparseCsrCPU, - DispatchKey::SparseCsrCUDA, - DispatchKey::Meta -}); +constexpr DispatchKeySet autogradother_backends = DispatchKeySet( + {DispatchKey::HIP, + DispatchKey::FPGA, + DispatchKey::MSNPU, + DispatchKey::Vulkan, + DispatchKey::Metal, + DispatchKey::QuantizedCPU, + DispatchKey::QuantizedCUDA, + DispatchKey::CustomRNGKeyId, + DispatchKey::MkldnnCPU, + DispatchKey::SparseCPU, + DispatchKey::SparseCUDA, + DispatchKey::SparseHIP, + DispatchKey::SparseCsrCPU, + DispatchKey::SparseCsrCUDA, + DispatchKey::Meta}); // The set of dispatch keys that come after autograd -// n.b. this relies on the fact that AutogradOther is currently the lowest Autograd key -constexpr DispatchKeySet after_autograd_keyset = DispatchKeySet( - DispatchKeySet::FULL_AFTER, - c10::DispatchKey::AutogradOther -); +// n.b. 
this relies on the fact that AutogradOther is currently the lowest +// Autograd key +constexpr DispatchKeySet after_autograd_keyset = + DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::AutogradOther); // The set of dispatch keys that come after InplaceOrView -constexpr DispatchKeySet after_InplaceOrView_keyset = DispatchKeySet( - DispatchKeySet::FULL_AFTER, - c10::DispatchKey::InplaceOrView -); +constexpr DispatchKeySet after_InplaceOrView_keyset = + DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::InplaceOrView); // true if t is a backend dispatch key C10_API bool isBackendDispatchKey(DispatchKey t); @@ -265,8 +275,8 @@ C10_API bool isBackendDispatchKey(DispatchKey t); // Resolve alias dispatch key to DispatchKeySet if applicable C10_API DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t); -// Returns a DispatchKeySet of all backend keys mapped to Autograd dispatch key t, -// DispatchKeySet is empty if t is not alias of DispatchKey::Autograd. +// Returns a DispatchKeySet of all backend keys mapped to Autograd dispatch key +// t, DispatchKeySet is empty if t is not alias of DispatchKey::Autograd. C10_API DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t); // Returns a DispatchKeySet of autograd related keys mapped to backend. @@ -291,27 +301,32 @@ static inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) { // top of existing "backend" keys like CPU/CUDA, you need to add it // here. At the moment, autograd keys and InplaceOrView key need this // treatment; - return (s - autograd_dispatch_keyset_with_InplaceOrView - autocast_dispatch_keyset).highestPriorityTypeId(); + return (s - autograd_dispatch_keyset_with_InplaceOrView - + autocast_dispatch_keyset) + .highestPriorityTypeId(); } -template +template using is_not_DispatchKeySet = guts::negation>; -// Given a function type, constructs a function_traits type that drops the first parameter -// type if the first parameter is of type DispatchKeySet. -// NB: DispatchKeySet is currently explicitly hidden from JIT (mainly to avoid pushing unnecessary -// arguments on the stack - see Note [ Plumbing Keys Through the Dispatcher] for details). -// If at any point in the future we need to expose this type to JIT, revisit the usage of this type alias. +// Given a function type, constructs a function_traits type that drops the first +// parameter type if the first parameter is of type DispatchKeySet. NB: +// DispatchKeySet is currently explicitly hidden from JIT (mainly to avoid +// pushing unnecessary arguments on the stack - see Note [ Plumbing Keys Through +// the Dispatcher] for details). If at any point in the future we need to expose +// this type to JIT, revisit the usage of this type alias. 
template using remove_DispatchKeySet_arg_from_func = guts::make_function_traits_t< - typename guts::infer_function_traits_t::return_type, - typename std::conditional_t< - std::is_same< - DispatchKeySet, - typename guts::typelist::head_with_default_t::parameter_types> - >::value, - guts::typelist::drop_if_nonempty_t::parameter_types, 1>, - typename guts::infer_function_traits_t::parameter_types - > ->; -} + typename guts::infer_function_traits_t::return_type, + typename std::conditional_t< + std::is_same< + DispatchKeySet, + typename guts::typelist::head_with_default_t< + void, + typename guts::infer_function_traits_t< + FuncType>::parameter_types>>::value, + guts::typelist::drop_if_nonempty_t< + typename guts::infer_function_traits_t::parameter_types, + 1>, + typename guts::infer_function_traits_t::parameter_types>>; +} // namespace c10 diff --git a/c10/core/Event.h b/c10/core/Event.h index 70d03f91bb5..eb10b8ab2ba 100644 --- a/c10/core/Event.h +++ b/c10/core/Event.h @@ -41,18 +41,18 @@ struct Event final { // Constructors Event() = delete; Event( - const DeviceType _device_type, - const EventFlag _flag = EventFlag::PYTORCH_DEFAULT) - : impl_{_device_type, _flag} { } + const DeviceType _device_type, + const EventFlag _flag = EventFlag::PYTORCH_DEFAULT) + : impl_{_device_type, _flag} {} // Copy constructor and copy assignment operator (deleted) Event(const Event&) = delete; Event& operator=(const Event&) = delete; // Move constructor and move assignment operator - Event(Event&& other) : impl_{std::move(other.impl_)} { } + Event(Event&& other) : impl_{std::move(other.impl_)} {} Event& operator=(Event&& other) { - impl_.swap(std::move(other.impl_)); + impl_.swap(std::move(other.impl_)); return *this; } @@ -60,54 +60,62 @@ struct Event final { ~Event() = default; // Getters - DeviceType device_type() const noexcept { return impl_.device_type(); } - DeviceIndex device_index() const noexcept { return impl_.device_index(); } - EventFlag flag() const noexcept { return impl_.flag(); } - bool was_marked_for_recording() const noexcept { return impl_.was_marked_for_recording(); } + DeviceType device_type() const noexcept { + return impl_.device_type(); + } + DeviceIndex device_index() const noexcept { + return impl_.device_index(); + } + EventFlag flag() const noexcept { + return impl_.flag(); + } + bool was_marked_for_recording() const noexcept { + return impl_.was_marked_for_recording(); + } -/** - * Calls record() if and only if record() has never been called for this event. - * Note: because Event is not thread-safe recordOnce() may call record() - * multiple times if called from multiple threads. - */ + /** + * Calls record() if and only if record() has never been called for this + * event. Note: because Event is not thread-safe recordOnce() may call + * record() multiple times if called from multiple threads. + */ void recordOnce(const Stream& stream) { impl_.recordOnce(stream); } -/** - * Increments the event's version and enqueues a job with this version - * in the stream's work queue. When the stream process that job - * it nofifies all streams waiting on / blocked by that version of the - * event to continue and marks that version as recorded. - * */ + /** + * Increments the event's version and enqueues a job with this version + * in the stream's work queue. When the stream process that job + * it nofifies all streams waiting on / blocked by that version of the + * event to continue and marks that version as recorded. 
+ * */ void record(const Stream& stream) { impl_.record(stream); } -/** - * Does nothing if the event has not been scheduled to be recorded. - * If the event was previously enqueued to be recorded, a command - * to wait for the version of the event that exists at the time of this call - * is inserted in the stream's work queue. - * When the stream reaches this command it will stop processing - * additional commands until that version of the event is marked as recorded. - */ + /** + * Does nothing if the event has not been scheduled to be recorded. + * If the event was previously enqueued to be recorded, a command + * to wait for the version of the event that exists at the time of this call + * is inserted in the stream's work queue. + * When the stream reaches this command it will stop processing + * additional commands until that version of the event is marked as recorded. + */ void block(const Stream& stream) const { impl_.block(stream); } -/** - * Returns true if (and only if) - * (1) the event has never been scheduled to be recorded - * (2) the current version is marked as recorded. - * Returns false otherwise. - */ + /** + * Returns true if (and only if) + * (1) the event has never been scheduled to be recorded + * (2) the current version is marked as recorded. + * Returns false otherwise. + */ bool query() const { return impl_.query(); } -private: + private: impl::InlineEvent impl_; }; -} // c10 +} // namespace c10 diff --git a/c10/core/GeneratorImpl.cpp b/c10/core/GeneratorImpl.cpp index 68ae9bc8012..78d30da67e3 100644 --- a/c10/core/GeneratorImpl.cpp +++ b/c10/core/GeneratorImpl.cpp @@ -13,7 +13,7 @@ namespace c10 { * GeneratorImpl class implementation */ GeneratorImpl::GeneratorImpl(Device device_in, DispatchKeySet key_set) - : device_{device_in}, key_set_(key_set) {} + : device_{device_in}, key_set_(key_set) {} /** * Clone this generator. Note that clone() is the only @@ -40,14 +40,15 @@ namespace detail { * FIXME: use std::random_device with entropy information */ #ifndef _WIN32 -static uint64_t readURandomLong() -{ +static uint64_t readURandomLong() { int randDev = open("/dev/urandom", O_RDONLY); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) uint64_t randValue; TORCH_CHECK(randDev >= 0, "Unable to open /dev/urandom"); ssize_t readBytes = read(randDev, &randValue, sizeof(randValue)); - TORCH_CHECK(readBytes >= (ssize_t) sizeof(randValue), "Unable to read from /dev/urandom"); + TORCH_CHECK( + readBytes >= (ssize_t)sizeof(randValue), + "Unable to read from /dev/urandom"); close(randDev); return randValue; } @@ -58,11 +59,11 @@ static uint64_t readURandomLong() * /dev/urandom or the current time. For CUDA, gets random from * std::random_device and adds a transformation on it. * - * FIXME: The behavior in this function is from legacy code (THRandom_seed/THCRandom_seed) - * and is probably not the right thing to do, even though our tests pass. - * Figure out if tests get perturbed - * - when the same algorithm is used for all backends. Note that the current behavior is - * different for CPU, CUDA and Windows CPU. + * FIXME: The behavior in this function is from legacy code + * (THRandom_seed/THCRandom_seed) and is probably not the right thing to do, + * even though our tests pass. Figure out if tests get perturbed + * - when the same algorithm is used for all backends. Note that the current + * behavior is different for CPU, CUDA and Windows CPU. 
* - when using C++11 std objects, such as std::random_device * - when constructing a 64 bit seed properly, rather than static casting * a 32 bit number to 64 bit. @@ -71,11 +72,13 @@ uint64_t getNonDeterministicRandom(bool is_cuda) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) uint64_t s; if (!is_cuda) { - #ifdef _WIN32 - s = (uint64_t)std::chrono::high_resolution_clock::now().time_since_epoch().count(); - #else - s = readURandomLong(); - #endif +#ifdef _WIN32 + s = (uint64_t)std::chrono::high_resolution_clock::now() + .time_since_epoch() + .count(); +#else + s = readURandomLong(); +#endif } else { std::random_device rd; // limit to 53 bits to ensure unique representation in double diff --git a/c10/core/GeneratorImpl.h b/c10/core/GeneratorImpl.h index 84e620e93a7..389bd627140 100644 --- a/c10/core/GeneratorImpl.h +++ b/c10/core/GeneratorImpl.h @@ -1,52 +1,54 @@ #pragma once #include -#include -#include #include +#include +#include #include #include -#include -#include -#include #include #include -#include #include +#include +#include +#include +#include /** * Note [Generator] * ~~~~~~~~~~~~~~~~ - * A Pseudo Random Number Generator (PRNG) is an engine that uses an algorithm to - * generate a seemingly random sequence of numbers, that may be later be used in creating - * a random distribution. Such an engine almost always maintains a state and requires a - * seed to start off the creation of random numbers. Often times, users have - * found it beneficial to be able to explicitly create, retain, and destroy - * PRNG states and also be able to have control over the seed value. + * A Pseudo Random Number Generator (PRNG) is an engine that uses an algorithm + * to generate a seemingly random sequence of numbers, that may be later be used + * in creating a random distribution. Such an engine almost always maintains a + * state and requires a seed to start off the creation of random numbers. Often + * times, users have found it beneficial to be able to explicitly create, + * retain, and destroy PRNG states and also be able to have control over the + * seed value. * - * A Generator in ATen gives users the ability to read, write and modify a PRNG engine. - * For instance, it does so by letting users seed a PRNG engine, fork the state of the - * engine, etc. + * A Generator in ATen gives users the ability to read, write and modify a PRNG + * engine. For instance, it does so by letting users seed a PRNG engine, fork + * the state of the engine, etc. * * By default, there is one generator per device, and a device's generator is - * lazily created. A user can use the torch.Generator() api to create their own generator. - * Currently torch.Generator() can only create a CPUGeneratorImpl. + * lazily created. A user can use the torch.Generator() api to create their own + * generator. Currently torch.Generator() can only create a CPUGeneratorImpl. */ /** * Note [Acquire lock when using random generators] * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - * Generator and its derived classes are NOT thread-safe. Please note that most of the - * places where we have inserted locking for generators are historically based, and we - * haven't actually checked that everything is truly thread safe (and it probably isn't). - * Please use the public mutex_ when using any methods from these classes, except for the - * read-only methods. You can learn about the usage by looking into the unittests - * (aten/src/ATen/cpu_generator_test.cpp) and other places where we have used lock_guard. 
+ * Generator and its derived classes are NOT thread-safe. Please note that most + * of the places where we have inserted locking for generators are historically + * based, and we haven't actually checked that everything is truly thread safe + * (and it probably isn't). Please use the public mutex_ when using any methods + * from these classes, except for the read-only methods. You can learn about the + * usage by looking into the unittests (aten/src/ATen/cpu_generator_test.cpp) + * and other places where we have used lock_guard. * - * TODO: Look into changing the threading semantics of Generators in ATen (e.g., making - * them non-thread safe and instead making the generator state splittable, to accommodate - * forks into other threads). + * TODO: Look into changing the threading semantics of Generators in ATen (e.g., + * making them non-thread safe and instead making the generator state + * splittable, to accommodate forks into other threads). */ namespace c10 { @@ -79,7 +81,9 @@ struct C10_API GeneratorImpl : public c10::intrusive_ptr_target { // See Note [Acquire lock when using random generators] std::mutex mutex_; - DispatchKeySet key_set() const { return key_set_; } + DispatchKeySet key_set() const { + return key_set_; + } inline void set_pyobj(PyObject* pyobj) noexcept { pyobj_ = pyobj; @@ -89,12 +93,12 @@ struct C10_API GeneratorImpl : public c10::intrusive_ptr_target { return pyobj_; } - protected: - Device device_; - DispatchKeySet key_set_; - PyObject* pyobj_ = nullptr; + protected: + Device device_; + DispatchKeySet key_set_; + PyObject* pyobj_ = nullptr; - virtual GeneratorImpl* clone_impl() const = 0; + virtual GeneratorImpl* clone_impl() const = 0; }; namespace detail { diff --git a/c10/core/GradMode.h b/c10/core/GradMode.h index 1773db201d6..1168bb1ae67 100644 --- a/c10/core/GradMode.h +++ b/c10/core/GradMode.h @@ -27,4 +27,4 @@ struct TORCH_API NoGradGuard : public AutoGradMode { NoGradGuard() : AutoGradMode(/*enabled=*/false) {} }; -} +} // namespace c10 diff --git a/c10/core/InferenceMode.cpp b/c10/core/InferenceMode.cpp index 3f55c6a6d1f..c4292859920 100644 --- a/c10/core/InferenceMode.cpp +++ b/c10/core/InferenceMode.cpp @@ -6,7 +6,8 @@ namespace c10 { thread_local bool InferenceMode_enabled = false; // Invariant: -// is_enabled() == !c10::impl::tls_is_dispatch_key_included(DispatchKey::InplaceOrView); +// is_enabled() == +// !c10::impl::tls_is_dispatch_key_included(DispatchKey::InplaceOrView); // InferenceMode::is_enabled() is in perf critical path (TensorImpl constructor) // so it worths a separate TLS to skip the DispatchKeySet check. bool InferenceMode::is_enabled() { diff --git a/c10/core/InferenceMode.h b/c10/core/InferenceMode.h index 1df2dde8dcd..685f2e82461 100644 --- a/c10/core/InferenceMode.h +++ b/c10/core/InferenceMode.h @@ -1,8 +1,8 @@ #pragma once #include -#include #include +#include namespace c10 { @@ -10,13 +10,14 @@ namespace c10 { // construction, and sets it back to the original value upon destruction. struct TORCH_API InferenceMode { // Note [Expected TLS state in InferenceMode]: - // InferenceMode: InplaceOrView not in raw_local_dispatch_key_set.included(), + // InferenceMode: InplaceOrView not in + // raw_local_dispatch_key_set.included(), // Autograd in raw_local_dispatch_key_set.excluded() // GradMode is disabled. // NormalMode: InplaceOrView in raw_local_dispatch_key_set.included(), // Autograd not in raw_local_dispatch_key_set.excluded() - // GradMode is enabled by default unless toggled manually through - // other APIs, e.g. NoGradGuard. 
+ // GradMode is enabled by default unless toggled manually + // through other APIs, e.g. NoGradGuard. // // Invariant: // - InplaceOrView is never in the excluded set @@ -25,8 +26,8 @@ struct TORCH_API InferenceMode { // // 1. Why do we put InplaceOrView in included set outside InferenceMode? // - // Inplace update to inference tensor outside InferenceMode is not allowed. - // See Note [Inplace update inference tensor] for more details. + // Inplace update to inference tensor outside InferenceMode is not + // allowed. See Note [Inplace update inference tensor] for more details. // Without going through InplaceOrView kernel, we cannot throw error // for `inference_tensor.add_(1)` case. // @@ -48,14 +49,17 @@ struct TORCH_API InferenceMode { // version of NoGradGuard. All runtime checks using GradMode::is_enabled() // are applicable to InferenceMode as well, e.g. // `tensorTypeInCurrentExecutionContext` in interpreter.cpp. - InferenceMode(bool enabled=true): prev_mode(InferenceMode::is_enabled()), - prev_keyset(c10::impl::tls_local_dispatch_key_set()), - grad_mode(at::AutoGradMode(!enabled)) { + InferenceMode(bool enabled = true) + : prev_mode(InferenceMode::is_enabled()), + prev_keyset(c10::impl::tls_local_dispatch_key_set()), + grad_mode(at::AutoGradMode(!enabled)) { set_enabled(enabled); - DispatchKeySet included = enabled ? prev_keyset.included_.remove(c10::DispatchKey::InplaceOrView) - : prev_keyset.included_.add(c10::DispatchKey::InplaceOrView); - DispatchKeySet excluded = enabled ? (prev_keyset.excluded_ | c10::autograd_dispatch_keyset) - : (prev_keyset.excluded_ - c10::autograd_dispatch_keyset); + DispatchKeySet included = enabled + ? prev_keyset.included_.remove(c10::DispatchKey::InplaceOrView) + : prev_keyset.included_.add(c10::DispatchKey::InplaceOrView); + DispatchKeySet excluded = enabled + ? (prev_keyset.excluded_ | c10::autograd_dispatch_keyset) + : (prev_keyset.excluded_ - c10::autograd_dispatch_keyset); c10::impl::PODLocalDispatchKeySet cur_keyset; cur_keyset.set_included(included); cur_keyset.set_excluded(excluded); @@ -71,9 +75,9 @@ struct TORCH_API InferenceMode { // ThreadLocalState.cpp. static void set_enabled(bool enabled); - private: - bool prev_mode; - c10::impl::LocalDispatchKeySet prev_keyset; - at::AutoGradMode grad_mode; + private: + bool prev_mode; + c10::impl::LocalDispatchKeySet prev_keyset; + at::AutoGradMode grad_mode; }; } // namespace c10 diff --git a/c10/core/MemoryFormat.h b/c10/core/MemoryFormat.h index 19b93b9f590..ba4e056e1e6 100644 --- a/c10/core/MemoryFormat.h +++ b/c10/core/MemoryFormat.h @@ -1,8 +1,8 @@ #pragma once #include -#include #include +#include #include @@ -17,14 +17,20 @@ // should be in channels_last format // // Contiguous: -// Regardless of input tensors format, the output should be contiguous Tensor. +// Regardless of input tensors format, the output should be contiguous +// Tensor. // // ChannelsLast: -// Regardless of input tensors format, the output should be in channels_last format. - +// Regardless of input tensors format, the output should be in channels_last +// format. 
namespace c10 { -enum class MemoryFormat : int8_t { Contiguous, Preserve, ChannelsLast, ChannelsLast3d }; +enum class MemoryFormat : int8_t { + Contiguous, + Preserve, + ChannelsLast, + ChannelsLast3d +}; // If you are seeing this, it means that this call site was not checked if // the memory format could be preserved, and it was switched to old default @@ -52,7 +58,8 @@ inline std::ostream& operator<<( } } -// Note: Hardcoded the channel last stride indices here to get better performance +// Note: Hardcoded the channel last stride indices here to get better +// performance inline std::vector get_channels_last_strides_2d(IntArrayRef sizes) { std::vector strides(sizes.size()); switch (sizes.size()) { @@ -68,7 +75,8 @@ inline std::vector get_channels_last_strides_2d(IntArrayRef sizes) { strides[1] = strides[2] * sizes[2]; return strides; default: - TORCH_INTERNAL_ASSERT(false, "ChannelsLast2d doesn't support size ", sizes.size()); + TORCH_INTERNAL_ASSERT( + false, "ChannelsLast2d doesn't support size ", sizes.size()); } } @@ -89,7 +97,8 @@ inline std::vector get_channels_last_strides_3d(IntArrayRef sizes) { strides[1] = strides[2] * sizes[2]; return strides; default: - TORCH_INTERNAL_ASSERT(false, "ChannelsLast3d doesn't support size ", sizes.size()); + TORCH_INTERNAL_ASSERT( + false, "ChannelsLast3d doesn't support size ", sizes.size()); } } @@ -100,12 +109,16 @@ inline std::vector get_channels_last_strides_3d(IntArrayRef sizes) { // will be a constant array and we can access it using constant index number, // the compiler will fully unroll the loop on strides indices to gain a better // performance. -// 2. No error check in helper function, caller ensures the correctness of the input -// 3. All helper functions have similar comments, only 1st helper function is commented here. -inline bool is_channels_last_strides_2d_s4(const IntArrayRef sizes, const IntArrayRef strides) { +// 2. No error check in helper function, caller ensures the correctness of the +// input +// 3. All helper functions have similar comments, only 1st helper function is +// commented here. +inline bool is_channels_last_strides_2d_s4( + const IntArrayRef sizes, + const IntArrayRef strides) { int64_t min = 0; // special case for trivial C dimension. default to NCHW - if (strides[1]==0) { + if (strides[1] == 0) { return false; } // loop strides indices @@ -121,8 +134,9 @@ inline bool is_channels_last_strides_2d_s4(const IntArrayRef sizes, const IntArr // N111 tensor with identical strides for size 1 dimension; // Two cases could lead us here: // a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1]) - // b. N11W contiguous Tensor sliced on the W-dimension. ([N,1,1,1]@[W,W,W,W]) - if (d==0 && min==strides[1]) { + // b. N11W contiguous Tensor sliced on the W-dimension. + // ([N,1,1,1]@[W,W,W,W]) + if (d == 0 && min == strides[1]) { return false; } // This is necessary to: @@ -140,7 +154,9 @@ inline bool is_channels_last_strides_2d_s4(const IntArrayRef sizes, const IntArr return true; } -inline bool is_channels_last_strides_3d_s5(const IntArrayRef sizes, const IntArrayRef strides) { +inline bool is_channels_last_strides_3d_s5( + const IntArrayRef sizes, + const IntArrayRef strides) { int64_t min = 0; if (strides[1] == 0) { return false; @@ -209,10 +225,13 @@ inline bool is_channels_last_strides_3d_s5(const IntArrayRef sizes, const IntArr // issues in our tests. // // We use Channels Last 2d as an example above. -// This is a general problem for all the is_channels_last_strides_xd implementation. 
-// Please check the helper functions (is_channels_last_strides_*d_s*) for more details. +// This is a general problem for all the is_channels_last_strides_xd +// implementation. Please check the helper functions +// (is_channels_last_strides_*d_s*) for more details. -inline bool is_channels_last_strides_2d(const IntArrayRef sizes, const IntArrayRef strides) { +inline bool is_channels_last_strides_2d( + const IntArrayRef sizes, + const IntArrayRef strides) { switch (sizes.size()) { case 4: return is_channels_last_strides_2d_s4(sizes, strides); @@ -224,7 +243,9 @@ inline bool is_channels_last_strides_2d(const IntArrayRef sizes, const IntArrayR } } -inline bool is_channels_last_strides_3d(const IntArrayRef sizes, const IntArrayRef strides) { +inline bool is_channels_last_strides_3d( + const IntArrayRef sizes, + const IntArrayRef strides) { switch (sizes.size()) { case 5: return is_channels_last_strides_3d_s5(sizes, strides); diff --git a/c10/core/QEngine.h b/c10/core/QEngine.h index e69e24ed0ca..ac092193d92 100644 --- a/c10/core/QEngine.h +++ b/c10/core/QEngine.h @@ -31,9 +31,7 @@ inline std::string toString(QEngine qengine) { return "QNNPACK"; default: TORCH_CHECK( - false, - "Unrecognized Quantized Engine: ", - static_cast(qengine)); + false, "Unrecognized Quantized Engine: ", static_cast(qengine)); } } diff --git a/c10/core/QScheme.h b/c10/core/QScheme.h index 0a1246830e1..957618d74fc 100644 --- a/c10/core/QScheme.h +++ b/c10/core/QScheme.h @@ -24,12 +24,13 @@ constexpr auto kPerTensorAffine = QScheme::PER_TENSOR_AFFINE; constexpr auto kPerChannelAffine = QScheme::PER_CHANNEL_AFFINE; constexpr auto kPerTensorSymmetric = QScheme::PER_TENSOR_SYMMETRIC; constexpr auto kPerChannelSymmetric = QScheme::PER_CHANNEL_SYMMETRIC; -constexpr auto kPerChannelAffineFloatQParams = QScheme::PER_CHANNEL_AFFINE_FLOAT_QPARAMS; +constexpr auto kPerChannelAffineFloatQParams = + QScheme::PER_CHANNEL_AFFINE_FLOAT_QPARAMS; constexpr int COMPILE_TIME_NUM_QSCHEMES = - static_cast(QScheme::COMPILE_TIME_NUM_QSCHEMES); + static_cast(QScheme::COMPILE_TIME_NUM_QSCHEMES); inline std::string toString(QScheme qscheme) { - switch(qscheme) { + switch (qscheme) { case kPerTensorAffine: return "per_tensor_affine"; case kPerChannelAffine: diff --git a/c10/core/Scalar.cpp b/c10/core/Scalar.cpp index 203b544924e..dd1f95813c3 100644 --- a/c10/core/Scalar.cpp +++ b/c10/core/Scalar.cpp @@ -3,7 +3,9 @@ namespace c10 { Scalar Scalar::operator-() const { - TORCH_CHECK(!isBoolean(), "torch boolean negative, the `-` operator, is not supported."); + TORCH_CHECK( + !isBoolean(), + "torch boolean negative, the `-` operator, is not supported."); if (isFloatingPoint()) { return Scalar(-v.d); } else if (isComplex()) { @@ -31,4 +33,4 @@ Scalar Scalar::log() const { } } -} // namespace c10 +} // namespace c10 diff --git a/c10/core/Scalar.h b/c10/core/Scalar.h index 368228e8202..802bf17e041 100644 --- a/c10/core/Scalar.h +++ b/c10/core/Scalar.h @@ -4,8 +4,8 @@ #include #include #include -#include #include +#include #include #include @@ -26,8 +26,8 @@ class C10_API Scalar { public: Scalar() : Scalar(int64_t(0)) {} -#define DEFINE_IMPLICIT_CTOR(type, name) \ - Scalar(type vv) : Scalar(vv, true) { } +#define DEFINE_IMPLICIT_CTOR(type, name) \ + Scalar(type vv) : Scalar(vv, true) {} AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, DEFINE_IMPLICIT_CTOR) AT_FORALL_COMPLEX_TYPES(DEFINE_IMPLICIT_CTOR) @@ -45,18 +45,18 @@ class C10_API Scalar { v.i = convert(vv); } -#define DEFINE_ACCESSOR(type, name) \ - type to##name() const { \ - if (Tag::HAS_d == tag) { \ 
- return checked_convert(v.d, #type); \ - } else if (Tag::HAS_z == tag) { \ - return checked_convert>( \ - v.z, #type); \ - } if (Tag::HAS_b == tag) { \ - return checked_convert(v.i, #type); \ - } else { \ - return checked_convert(v.i, #type); \ - } \ +#define DEFINE_ACCESSOR(type, name) \ + type to##name() const { \ + if (Tag::HAS_d == tag) { \ + return checked_convert(v.d, #type); \ + } else if (Tag::HAS_z == tag) { \ + return checked_convert>(v.z, #type); \ + } \ + if (Tag::HAS_b == tag) { \ + return checked_convert(v.i, #type); \ + } else { \ + return checked_convert(v.i, #type); \ + } \ } // TODO: Support ComplexHalf accessor @@ -71,7 +71,8 @@ class C10_API Scalar { return Tag::HAS_d == tag; } - C10_DEPRECATED_MESSAGE("isIntegral is deprecated. Please use the overload with 'includeBool' parameter instead.") + C10_DEPRECATED_MESSAGE( + "isIntegral is deprecated. Please use the overload with 'includeBool' parameter instead.") bool isIntegral() const { return Tag::HAS_i == tag; } @@ -90,7 +91,9 @@ class C10_API Scalar { Scalar conj() const; Scalar log() const; - template::value, int>::type = 0> + template < + typename T, + typename std::enable_if::value, int>::type = 0> bool equal(T num) const { if (isComplex()) { auto val = v.z; @@ -105,7 +108,9 @@ class C10_API Scalar { } } - template::value, int>::type = 0> + template < + typename T, + typename std::enable_if::value, int>::type = 0> bool equal(T num) const { if (isComplex()) { return v.z == num; @@ -142,26 +147,30 @@ class C10_API Scalar { } private: - template::value && ! std::is_same::value, bool>::type* = - nullptr> - Scalar(T vv, bool) : tag(Tag::HAS_i) { - v.i = convert(vv); - } + template < + typename T, + typename std::enable_if< + std::is_integral::value && !std::is_same::value, + bool>::type* = nullptr> + Scalar(T vv, bool) : tag(Tag::HAS_i) { + v.i = convert(vv); + } - template::value && !c10::is_complex::value, bool>::type* = - nullptr> - Scalar(T vv, bool) : tag(Tag::HAS_d) { - v.d = convert(vv); - } + template < + typename T, + typename std::enable_if< + !std::is_integral::value && !c10::is_complex::value, + bool>::type* = nullptr> + Scalar(T vv, bool) : tag(Tag::HAS_d) { + v.d = convert(vv); + } - template::value, bool>::type* = - nullptr> - Scalar(T vv, bool) : tag(Tag::HAS_z) { - v.z = convert(vv); - } + template < + typename T, + typename std::enable_if::value, bool>::type* = nullptr> + Scalar(T vv, bool) : tag(Tag::HAS_z) { + v.z = convert(vv); + } // We can't set v in the initializer list using the // syntax v{ .member = ... 
} because it doesn't work on MSVC @@ -172,7 +181,7 @@ class C10_API Scalar { double d; int64_t i; c10::complex z; - v_t(){} // default constructor + v_t() {} // default constructor } v; }; @@ -182,10 +191,10 @@ inline T Scalar::to() const { throw std::runtime_error("to() cast to unexpected type."); } -#define DEFINE_TO(T, name) \ - template <> \ - inline T Scalar::to() const { \ - return to##name(); \ +#define DEFINE_TO(T, name) \ + template <> \ + inline T Scalar::to() const { \ + return to##name(); \ } AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_TO) #undef DEFINE_TO diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index 76af2eb0f46..3cabbbeed5b 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -1,14 +1,14 @@ #pragma once #include -#include +#include #include +#include +#include #include #include -#include -#include #include -#include +#include #include #include @@ -44,7 +44,6 @@ namespace c10 { _(at::BFloat16, BFloat16) /* 15 */ \ _(c10::quint4x2, QUInt4x2) /* 16 */ - // If you want to support ComplexHalf for real, add ComplexHalf // into this macro (and change the name). But beware: convert() // doesn't work for all the conversions you need... @@ -62,7 +61,6 @@ namespace c10 { _(bool, Bool) \ _(at::BFloat16, BFloat16) - enum class ScalarType : int8_t { #define DEFINE_ENUM(_1, n) n, AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ENUM) @@ -71,7 +69,8 @@ enum class ScalarType : int8_t { NumOptions }; -constexpr uint16_t NumScalarTypes = static_cast(ScalarType::NumOptions); +constexpr uint16_t NumScalarTypes = + static_cast(ScalarType::NumOptions); namespace impl { @@ -80,20 +79,20 @@ namespace impl { template struct ScalarTypeToCPPType; -#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type) \ -template<> \ -struct ScalarTypeToCPPType { \ - using type = cpp_type; \ - \ - /* This is a workaround for the CUDA bug which prevents */ \ - /* ::detail::ScalarTypeToCType::type being used directly due to */ \ - /* ambiguous reference which can't to be resolved. For some reason it */ \ - /* cant pick between at::detail and at::cuda::detail. */ \ - /* For repro example, please see: */ \ - /* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */ \ - /* TODO: remove once the bug is fixed. */ \ - static type t; \ -}; +#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type) \ + template <> \ + struct ScalarTypeToCPPType { \ + using type = cpp_type; \ + \ + /* This is a workaround for the CUDA bug which prevents */ \ + /* ::detail::ScalarTypeToCType::type being used directly due to */ \ + /* ambiguous reference which can't to be resolved. For some reason it */ \ + /* cant pick between at::detail and at::cuda::detail. */ \ + /* For repro example, please see: */ \ + /* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */ \ + /* TODO: remove once the bug is fixed. 
*/ \ + static type t; \ + }; AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType) @@ -104,12 +103,12 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType) template struct CppTypeToScalarType; -#define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type) \ - template<> \ - struct CppTypeToScalarType: \ - std::integral_constant \ - {}; +#define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type) \ + template <> \ + struct CppTypeToScalarType \ + : std:: \ + integral_constant { \ + }; AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType) @@ -131,47 +130,59 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType) _(float, Float) \ _(double, Double) -#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \ - _(uint8_t, Byte) \ - _(int8_t, Char) \ - _(int16_t, Short) \ - _(int, Int) \ - _(int64_t, Long) \ - _(float, Float) \ - _(double, Double) \ - _(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE>::t), SCALARTYPE) +#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(decltype( \ + ::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE>::t), \ + SCALARTYPE) -#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \ - _(uint8_t, Byte) \ - _(int8_t, Char) \ - _(int16_t, Short) \ - _(int, Int) \ - _(int64_t, Long) \ - _(float, Float) \ - _(double, Double) \ - _(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE1>::t), SCALARTYPE1) \ - _(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE2>::t), SCALARTYPE2) +#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(decltype( \ + ::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE1>::t), \ + SCALARTYPE1) \ + _(decltype( \ + ::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE2>::t), \ + SCALARTYPE2) -#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \ - _(uint8_t, Byte) \ - _(int8_t, Char) \ - _(int16_t, Short) \ - _(int, Int) \ - _(int64_t, Long) \ - _(float, Float) \ - _(double, Double) \ - _(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE1>::t), SCALARTYPE1) \ - _(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE2>::t), SCALARTYPE2) \ - _(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE3>::t), SCALARTYPE3) +#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(decltype( \ + ::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE1>::t), \ + SCALARTYPE1) \ + _(decltype( \ + ::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE2>::t), \ + SCALARTYPE2) \ + _(decltype( \ + ::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE3>::t), \ + SCALARTYPE3) -#define AT_FORALL_QINT_TYPES(_) \ - _(c10::qint8, QInt8) \ - _(c10::quint8, QUInt8) \ - _(c10::qint32, QInt32) \ +#define AT_FORALL_QINT_TYPES(_) \ + _(c10::qint8, QInt8) \ + _(c10::quint8, QUInt8) \ + _(c10::qint32, QInt32) \ _(c10::quint4x2, QUInt4x2) -#define AT_FORALL_COMPLEX_TYPES(_) \ - _(c10::complex, ComplexFloat) \ 
+#define AT_FORALL_COMPLEX_TYPES(_) \ + _(c10::complex, ComplexFloat) \ _(c10::complex, ComplexDouble) #define DEFINE_CONSTANT(_, name) \ @@ -206,7 +217,8 @@ static inline size_t elementSize(ScalarType t) { #undef CASE_ELEMENTSIZE_CASE } -C10_DEPRECATED_MESSAGE("isIntegralType is deprecated. Please use the overload with 'includeBool' parameter instead.") +C10_DEPRECATED_MESSAGE( + "isIntegralType is deprecated. Please use the overload with 'includeBool' parameter instead.") static inline bool isIntegralType(ScalarType t) { return ( t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int || @@ -214,9 +226,9 @@ static inline bool isIntegralType(ScalarType t) { } static inline bool isIntegralType(ScalarType t, bool includeBool) { - bool isIntegral = ( - t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int || - t == ScalarType::Long || t == ScalarType::Short); + bool isIntegral = + (t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int || + t == ScalarType::Long || t == ScalarType::Short); return includeBool ? isIntegral || (t == ScalarType::Bool) : isIntegral; } @@ -235,7 +247,8 @@ static inline bool isComplexType(ScalarType t) { static inline bool isQIntType(ScalarType t) { // Don't forget to extend this when adding new QInt types - return t == ScalarType:: QInt8 || t == ScalarType::QUInt8 || t == ScalarType::QInt32 || t == ScalarType::QUInt4x2; + return t == ScalarType::QInt8 || t == ScalarType::QUInt8 || + t == ScalarType::QInt32 || t == ScalarType::QUInt4x2; } static inline ScalarType toQIntType(ScalarType t) { @@ -268,20 +281,20 @@ static inline ScalarType toUnderlying(ScalarType t) { static inline bool isSignedType(ScalarType t) { TORCH_CHECK(!isQIntType(t), "isSignedType not supported for quantized types"); - #define CASE_SIGNED(ctype, name) \ - case ScalarType::name: \ - return std::numeric_limits::is_signed; +#define CASE_SIGNED(ctype, name) \ + case ScalarType::name: \ + return std::numeric_limits::is_signed; switch (t) { case ScalarType::ComplexHalf: case ScalarType::ComplexFloat: case ScalarType::ComplexDouble: return true; - AT_FORALL_SCALAR_TYPES_AND3(Half, Bool, BFloat16, CASE_SIGNED) + AT_FORALL_SCALAR_TYPES_AND3(Half, Bool, BFloat16, CASE_SIGNED) default: TORCH_CHECK(false, "Unknown ScalarType"); } - #undef CASE_SIGNED +#undef CASE_SIGNED } static inline bool isUnderlying(ScalarType type, ScalarType qtype) { @@ -323,7 +336,8 @@ static inline ScalarType toComplexType(ScalarType t) { // see tensor_attributes.rst for detailed explanation and examples // of casting rules. static inline bool canCast(const ScalarType from, const ScalarType to) { - // We disallow complex -> non complex, e.g., float_tensor *= complex is disallowed. + // We disallow complex -> non complex, e.g., float_tensor *= complex is + // disallowed. if (isComplexType(from) && !isComplexType(to)) { return false; } @@ -333,13 +347,14 @@ static inline bool canCast(const ScalarType from, const ScalarType to) { } // Treat bool as a distinct "category," to be consistent with type promotion - // rules (e.g. `bool_tensor + 5 -> int64_tensor`). If `5` was in the same category - // as `bool_tensor`, we would not promote. - // Differing categories implies `bool_tensor += 5` is disallowed. + // rules (e.g. `bool_tensor + 5 -> int64_tensor`). If `5` was in the same + // category as `bool_tensor`, we would not promote. Differing categories + // implies `bool_tensor += 5` is disallowed. 
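An aside on the casting rules reflowed in this hunk (illustrative only, not part of the patch): the category behaviour the comments describe is observable through c10::canCast and c10::promoteTypes, both defined in this header. A minimal sketch, assuming a translation unit compiled against c10; the expected outputs in the comments follow the rules and the lookup table in this hunk.

#include <c10/core/ScalarType.h>
#include <iostream>

int main() {
  using c10::ScalarType;

  // bool is its own category: combined with an integral type it promotes
  // to that integral type (row b1 / column i8 of the lookup table below).
  std::cout << c10::toString(c10::promoteTypes(ScalarType::Bool, ScalarType::Long))
            << "\n"; // Long

  std::cout << std::boolalpha;
  // complex -> non-complex is disallowed ...
  std::cout << c10::canCast(ScalarType::ComplexFloat, ScalarType::Float) << "\n"; // false
  // ... but widening a real type into a complex result is allowed.
  std::cout << c10::canCast(ScalarType::Float, ScalarType::ComplexFloat) << "\n"; // true
  // floating point cannot land in an integral result (e.g. int_tensor += 0.5).
  std::cout << c10::canCast(ScalarType::Float, ScalarType::Int) << "\n"; // false
  // differing categories: bool_tensor += 5 is disallowed.
  std::cout << c10::canCast(ScalarType::Long, ScalarType::Bool) << "\n"; // false
}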
// // NB: numpy distinguishes "unsigned" as a category to get the desired // `bool_tensor + 5 -> int64_tensor` behavior. We don't, because: - // * We don't want the performance hit of checking the runtime sign of Scalars. + // * We don't want the performance hit of checking the runtime sign of + // Scalars. // * `uint8_tensor + 5 -> int64_tensor` would be undesirable. if (from != ScalarType::Bool && to == ScalarType::Bool) { return false; @@ -373,7 +388,8 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) { } if (isQIntType(a) || isQIntType(b)) { - TORCH_CHECK(false, + TORCH_CHECK( + false, "promoteTypes with quantized numbers is not handled yet; figure out what the correct rules should be, offending types: ", toString(a), " ", @@ -385,23 +401,23 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) { // corrent values for the type promotions in complex type cases. static constexpr ScalarType _promoteTypesLookup[static_cast( ScalarType::NumOptions)][static_cast(ScalarType::NumOptions)] = { - /* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 q1 q2 q3 bf*/ - /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, ud, c4, c8, u1, ud, ud, ud, bf}, - /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, ud, c4, c8, i1, ud, ud, ud, bf}, - /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, ud, c4, c8, i2, ud, ud, ud, bf}, - /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, ud, c4, c8, i4, ud, ud, ud, bf}, - /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, ud, c4, c8, i8, ud, ud, ud, bf}, - /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, ud, c4, c8, f2, ud, ud, ud, f4}, - /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, ud, c4, c8, f4, ud, ud, ud, f4}, - /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, ud, c8, c8, f8, ud, ud, ud, f8}, - /* c2 */ {ud, ud, ud, ud, ud, ud, ud, ud, c2, c4, c8, ud, ud, ud, ud, ud}, - /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, ud, ud, ud, c4}, - /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, ud, ud, ud, c8}, - /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, ud, c4, c8, b1, ud, ud, ud, bf}, - /* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, - /* q2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, - /* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, - /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, ud, c4, c8, bf, ud, ud, ud, bf}, + /* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 q1 q2 q3 bf*/ + /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, ud, c4, c8, u1, ud, ud, ud, bf}, + /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, ud, c4, c8, i1, ud, ud, ud, bf}, + /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, ud, c4, c8, i2, ud, ud, ud, bf}, + /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, ud, c4, c8, i4, ud, ud, ud, bf}, + /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, ud, c4, c8, i8, ud, ud, ud, bf}, + /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, ud, c4, c8, f2, ud, ud, ud, f4}, + /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, ud, c4, c8, f4, ud, ud, ud, f4}, + /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, ud, c8, c8, f8, ud, ud, ud, f8}, + /* c2 */ {ud, ud, ud, ud, ud, ud, ud, ud, c2, c4, c8, ud, ud, ud, ud, ud}, + /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, ud, ud, ud, c4}, + /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, ud, ud, ud, c8}, + /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, ud, c4, c8, b1, ud, ud, ud, bf}, + /* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, + /* q2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, + /* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, + /* bf */ {bf, bf, 
bf, bf, bf, f4, f4, f8, ud, c4, c8, bf, ud, ud, ud, bf}, }; return _promoteTypesLookup[static_cast(a)][static_cast(b)]; } diff --git a/c10/core/ScalarTypeToTypeMeta.h b/c10/core/ScalarTypeToTypeMeta.h index b6e7f6cf199..6d4946b29bc 100644 --- a/c10/core/ScalarTypeToTypeMeta.h +++ b/c10/core/ScalarTypeToTypeMeta.h @@ -26,7 +26,8 @@ static inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) { /** * typeMetaToScalarType(), lifted to optional */ -static inline optional optTypeMetaToScalarType(optional type_meta) { +static inline optional optTypeMetaToScalarType( + optional type_meta) { if (!type_meta.has_value()) { return c10::nullopt; } diff --git a/c10/core/Storage.cpp b/c10/core/Storage.cpp index 9ba79a92372..1361c8186fa 100644 --- a/c10/core/Storage.cpp +++ b/c10/core/Storage.cpp @@ -1,5 +1,3 @@ #include -namespace c10 { - -} // namespace c10 +namespace c10 {} // namespace c10 diff --git a/c10/core/Storage.h b/c10/core/Storage.h index e2c36713cf2..f8df22b55e6 100644 --- a/c10/core/Storage.h +++ b/c10/core/Storage.h @@ -9,7 +9,8 @@ struct C10_API Storage { struct use_byte_size_t {}; Storage() {} - Storage(c10::intrusive_ptr ptr) : storage_impl_(std::move(ptr)) {} + Storage(c10::intrusive_ptr ptr) + : storage_impl_(std::move(ptr)) {} // Allocates memory buffer using given allocator and creates a storage with it Storage( @@ -53,10 +54,14 @@ struct C10_API Storage { } template - T* data() const { return storage_impl_->data(); } + T* data() const { + return storage_impl_->data(); + } template - T* unsafe_data() const { return storage_impl_->unsafe_data(); } + T* unsafe_data() const { + return storage_impl_->unsafe_data(); + } // TODO: remove later void set_nbytes(size_t size_bytes) const { @@ -134,7 +139,8 @@ struct C10_API Storage { size_t capacity, DeleterFnPtr d = nullptr) { if (!storage_impl_.unique()) { - TORCH_CHECK(false, + TORCH_CHECK( + false, "UniqueStorageShareExternalPointer can only be called when use_count == 1"); } storage_impl_->UniqueStorageShareExternalPointer(src, capacity, d); @@ -144,7 +150,8 @@ struct C10_API Storage { at::DataPtr&& data_ptr, size_t capacity) { if (!storage_impl_.unique()) { - TORCH_CHECK(false, + TORCH_CHECK( + false, "UniqueStorageShareExternalPointer can only be called when use_count == 1"); } storage_impl_->UniqueStorageShareExternalPointer( diff --git a/c10/core/StorageImpl.h b/c10/core/StorageImpl.h index 96897e57513..ff29b68dc4d 100644 --- a/c10/core/StorageImpl.h +++ b/c10/core/StorageImpl.h @@ -20,13 +20,13 @@ namespace c10 { // but a lot of things won't work correctly, including: // // - An ordinary deleter on such a storage is wrong, because normal deleters -// assume unique ownership, but if you have two storages at the same data, that -// implies there is some sort of shared ownership. So your deleter would have to -// actually be internally doing some sort of refcount thing +// assume unique ownership, but if you have two storages at the same data, +// that implies there is some sort of shared ownership. 
So your deleter would +// have to actually be internally doing some sort of refcount thing // - Deepcopy in Python side relies on storage equality and not data pointer // equality; so if there are two separate storages pointing to the same data, -// the data will actually get duplicated in that case (one data ptr before, two -// data ptrs after) +// the data will actually get duplicated in that case (one data ptr before, +// two data ptrs after) // - Version counts won't work correctly, because we do all VC tracking at the // level of storages (unless you explicitly disconnect the VC with detach); // mutation because data pointers are the same are totally untracked diff --git a/c10/core/Stream.h b/c10/core/Stream.h index 62d5261534e..c149dc260e0 100644 --- a/c10/core/Stream.h +++ b/c10/core/Stream.h @@ -55,10 +55,11 @@ using StreamId = int32_t; * wrapper classes which provide this functionality, e.g., CUDAStream. */ class Stream final { -private: + private: Device device_; StreamId id_; -public: + + public: enum Unsafe { UNSAFE }; enum Default { DEFAULT }; @@ -69,16 +70,13 @@ public: /// we don't require backends to give any guarantees about non-zero /// StreamIds; they are welcome to allocate in whatever way they like. explicit Stream(Unsafe, Device device, StreamId id) - : device_(device) - , id_(id) {} + : device_(device), id_(id) {} /// Construct the default stream of a Device. The default stream is /// NOT the same as the current stream; default stream is a fixed stream /// that never changes, whereas the current stream may be changed by /// StreamGuard. - explicit Stream(Default, Device device) - : device_(device) - , id_(0) {} + explicit Stream(Default, Device device) : device_(device), id_(0) {} bool operator==(const Stream& other) const noexcept { return this->device_ == other.device_ && this->id_ == other.id_; @@ -87,10 +85,18 @@ public: return !(*this == other); } - Device device() const noexcept { return device_; } - DeviceType device_type() const noexcept { return device_.type(); } - DeviceIndex device_index() const noexcept { return device_.index(); } - StreamId id() const noexcept { return id_; } + Device device() const noexcept { + return device_; + } + DeviceType device_type() const noexcept { + return device_.type(); + } + DeviceIndex device_index() const noexcept { + return device_.index(); + } + StreamId id() const noexcept { + return id_; + } // Enqueues a wait instruction in the stream's work queue. 
// This instruction is a no-op unless the event is marked @@ -116,10 +122,10 @@ public: static_assert(sizeof(StreamId) == 4, "DeviceIndex is not 32-bit"); // Concat these together into a 64-bit integer // See Note [Hazard when concatenating signed integers] - uint64_t bits = - static_cast(static_cast(device_type())) << 48 - | static_cast(static_cast(device_index())) << 32 - | static_cast(static_cast(id())); + uint64_t bits = static_cast(static_cast(device_type())) + << 48 | + static_cast(static_cast(device_index())) << 32 | + static_cast(static_cast(id())); return bits; } @@ -145,10 +151,10 @@ C10_API std::ostream& operator<<(std::ostream& stream, const Stream& s); } // namespace c10 namespace std { - template <> - struct hash { - size_t operator()(c10::Stream s) const noexcept { - return std::hash{}(s.pack()); - } - }; +template <> +struct hash { + size_t operator()(c10::Stream s) const noexcept { + return std::hash{}(s.pack()); + } +}; } // namespace std diff --git a/c10/core/StreamGuard.h b/c10/core/StreamGuard.h index c47925cd552..8a4116f80f0 100644 --- a/c10/core/StreamGuard.h +++ b/c10/core/StreamGuard.h @@ -47,7 +47,9 @@ struct StreamGuard { /// WARNING: reset_stream does NOT preserve previously set streams on /// different devices. If you need to set streams on multiple devices /// on , use MultiStreamGuard instead. - void reset_stream(Stream stream) { guard_.reset_stream(stream); } + void reset_stream(Stream stream) { + guard_.reset_stream(stream); + } /// Returns the stream that was set at the time the guard was constructed. Stream original_stream() const { @@ -62,13 +64,17 @@ struct StreamGuard { /// Returns the most recent device that was set using this device guard, /// either from construction, or via set_device/reset_device/set_index. - Device current_device() const { return guard_.current_device(); } + Device current_device() const { + return guard_.current_device(); + } /// Returns the device that was set at the most recent reset_stream(), /// or otherwise the device at construction time. - Device original_device() const { return guard_.original_device(); } + Device original_device() const { + return guard_.original_device(); + } -private: + private: c10::impl::InlineStreamGuard guard_; }; @@ -88,7 +94,8 @@ struct OptionalStreamGuard { /// Set the current device to the device associated with the passed stream, /// and set the current stream on that device to the passed stream, /// if the passed stream is not nullopt. - explicit OptionalStreamGuard(optional stream_opt) : guard_(stream_opt) {} + explicit OptionalStreamGuard(optional stream_opt) + : guard_(stream_opt) {} /// Copy is disallowed OptionalStreamGuard(const OptionalStreamGuard&) = delete; @@ -105,21 +112,30 @@ struct OptionalStreamGuard { /// set the current device to the device associated with the passed stream, /// and set the current stream on that device to the passed stream. /// Initializes the guard if it was not previously initialized. - void reset_stream(Stream stream) { guard_.reset_stream(stream); } + void reset_stream(Stream stream) { + guard_.reset_stream(stream); + } /// Returns the stream that was set at the time the guard was most recently /// initialized, or nullopt if the guard is uninitialized. 
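A small sketch (illustrative only, not part of the patch) of the Stream packing and hashing reformatted above. It assumes a program linked against c10; the CPU device is chosen only because it needs no backend state.

#include <c10/core/Device.h>
#include <c10/core/Stream.h>
#include <cstdint>
#include <functional>
#include <iostream>

int main() {
  // The Stream(DEFAULT, device) constructor above always yields stream id 0.
  c10::Device cpu(c10::DeviceType::CPU);
  c10::Stream s(c10::Stream::DEFAULT, cpu);

  // pack() concatenates device type, device index and stream id into one
  // 64-bit value; the std::hash specialization above simply hashes that value.
  uint64_t bits = s.pack();
  std::size_t h = std::hash<c10::Stream>{}(s);

  std::cout << "id=" << s.id() << " packed=" << bits << " hash=" << h << "\n";
}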
- optional original_stream() const { return guard_.original_stream(); } + optional original_stream() const { + return guard_.original_stream(); + } /// Returns the most recent stream that was set using this stream guard, - /// either from construction, or via reset_stream, if the guard is initialized, - /// or nullopt if the guard is uninitialized. - optional current_stream() const { return guard_.current_stream(); } + /// either from construction, or via reset_stream, if the guard is + /// initialized, or nullopt if the guard is uninitialized. + optional current_stream() const { + return guard_.current_stream(); + } - /// Restore the original device and stream, resetting this guard to uninitialized state. - void reset() { guard_.reset(); } + /// Restore the original device and stream, resetting this guard to + /// uninitialized state. + void reset() { + guard_.reset(); + } -private: + private: c10::impl::InlineOptionalStreamGuard guard_; }; @@ -142,7 +158,7 @@ struct MultiStreamGuard { // See Note [Move assignment for RAII guards is tricky] MultiStreamGuard& operator=(MultiStreamGuard&& other) = delete; -private: + private: c10::impl::InlineMultiStreamGuard guard_; }; diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index 0e746b13706..98a264f5b4b 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -1,10 +1,10 @@ #include #include +#include #include #include #include -#include // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) C10_DEFINE_bool( @@ -21,7 +21,7 @@ C10_DEFINE_int64( namespace c10 { -const char * const TensorImpl::err_msg_tensor_metadata_change_not_allowed = +const char* const TensorImpl::err_msg_tensor_metadata_change_not_allowed = "is not allowed on a Tensor created from .data or .detach().\n" "If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)\n" "without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.\n" @@ -32,7 +32,8 @@ const char * const TensorImpl::err_msg_tensor_metadata_change_not_allowed = " x.set_(y)"; at::Tensor& TensorImpl::mutable_grad() { - if (!autograd_meta_) autograd_meta_ = impl::GetAutogradMetaFactory()->make(); + if (!autograd_meta_) + autograd_meta_ = impl::GetAutogradMetaFactory()->make(); return autograd_meta_->mutable_grad(); } @@ -43,18 +44,26 @@ const at::Tensor& TensorImpl::grad() const { // is not so easy to fix right now because the mutable counterpart of // this function must keep working so that "x.grad() = ..." keeps working // (part of public API). 
- if (!autograd_meta_) return impl::GetAutogradMetaFactory()->undefined_tensor(); + if (!autograd_meta_) + return impl::GetAutogradMetaFactory()->undefined_tensor(); return autograd_meta_->grad(); } -const at::Tensor& TensorImpl::_fw_grad(uint64_t level, const at::Tensor& self) const { +const at::Tensor& TensorImpl::_fw_grad(uint64_t level, const at::Tensor& self) + const { // See TensorImpl::grad() above for explanation about the line below - if (!autograd_meta_) return impl::GetAutogradMetaFactory()->undefined_tensor(); + if (!autograd_meta_) + return impl::GetAutogradMetaFactory()->undefined_tensor(); return autograd_meta_->fw_grad(level, self); } -void TensorImpl::_set_fw_grad(const at::Tensor& new_grad, const at::Tensor& self, uint64_t level, bool is_inplace_op) { - if (!autograd_meta_) autograd_meta_ = impl::GetAutogradMetaFactory()->make(); +void TensorImpl::_set_fw_grad( + const at::Tensor& new_grad, + const at::Tensor& self, + uint64_t level, + bool is_inplace_op) { + if (!autograd_meta_) + autograd_meta_ = impl::GetAutogradMetaFactory()->make(); autograd_meta_->set_fw_grad(new_grad, self, level, is_inplace_op); } @@ -63,7 +72,11 @@ TensorImpl::TensorImpl( DispatchKeySet key_set, const caffe2::TypeMeta data_type) // Use std::forward to suppress static analyzer false positive. - : TensorImpl(std::forward(storage), key_set, data_type, storage.device()) {} + : TensorImpl( + std::forward(storage), + key_set, + data_type, + storage.device()) {} // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) TensorImpl::TensorImpl( @@ -84,23 +97,29 @@ TensorImpl::TensorImpl( } } -TensorImpl::TensorImpl(DispatchKeySet key_set, const caffe2::TypeMeta data_type, c10::optional device_opt) +TensorImpl::TensorImpl( + DispatchKeySet key_set, + const caffe2::TypeMeta data_type, + c10::optional device_opt) // NOLINTNEXTLINE(performance-move-const-arg) : TensorImpl({}, key_set, data_type, std::move(device_opt)) {} // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2::TypeMeta data_type, - c10::optional device_opt) +TensorImpl::TensorImpl( + Storage&& storage, + DispatchKeySet key_set, + const caffe2::TypeMeta data_type, + c10::optional device_opt) : storage_(std::move(storage)), storage_offset_(0), numel_(0), data_type_(data_type), device_opt_(device_opt) { - init_bitfields(); if (!key_set.empty()) { - TORCH_INTERNAL_ASSERT(data_type == ScalarType::Undefined || device_opt_.has_value()); + TORCH_INTERNAL_ASSERT( + data_type == ScalarType::Undefined || device_opt_.has_value()); // UndefinedTensorImpl is a singleton, so we skip logging it C10_LOG_API_USAGE_ONCE("tensor.create"); } @@ -115,12 +134,13 @@ TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2:: // Inference tensor doesn't have autograd related keys. if (inference_mode) { - // See Note [Expected TLS state in InferenceMode] for why we exclude Autograd & InplaceOrView keys. - // Normally key_set only contains backend keys but we do the substraction - // here to make sure. + // See Note [Expected TLS state in InferenceMode] for why we exclude + // Autograd & InplaceOrView keys. Normally key_set only contains backend + // keys but we do the substraction here to make sure. key_set_ = key_set - c10::autograd_dispatch_keyset_with_InplaceOrView; } else { - // TODO: Ideally we only add AutogradBackend key when the tensor requires grad. + // TODO: Ideally we only add AutogradBackend key when the tensor requires + // grad. 
// See Note [Dream: skip VariableType kernel when requires_grad=false] key_set_ = key_set | getAutogradRelatedKeySetFromBackend(k); } @@ -130,8 +150,8 @@ TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2:: version_counter_ = VariableVersion(/*version=*/0); } - // we would also like to check that non-cpu devices have an index, but some Caffe2 operators create - // Storages with default devices. + // we would also like to check that non-cpu devices have an index, but some + // Caffe2 operators create Storages with default devices. } #ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY @@ -151,15 +171,14 @@ void TensorImpl::HandleResize() { if (reserved_) { // If tensor is reserved then don't claim its memeory unless nbytes() // is smaller than new size - reset_tensor = storage_.nbytes() < - (storage_offset_ + numel_) * data_type_.itemsize(); + reset_tensor = + storage_.nbytes() < (storage_offset_ + numel_) * data_type_.itemsize(); } else { reset_tensor = storage_.nbytes() < - (storage_offset_ + numel_) * data_type_.itemsize() || - !FLAGS_caffe2_keep_on_shrink || - storage_.nbytes() - - (storage_offset_ + numel_) * data_type_.itemsize() > - static_cast(FLAGS_caffe2_max_keep_on_shrink_memory); + (storage_offset_ + numel_) * data_type_.itemsize() || + !FLAGS_caffe2_keep_on_shrink || + storage_.nbytes() - (storage_offset_ + numel_) * data_type_.itemsize() > + static_cast(FLAGS_caffe2_max_keep_on_shrink_memory); } if (reset_tensor && storage_initialized()) { @@ -190,20 +209,19 @@ bool TensorImpl::compute_channels_last_contiguous_2d() const { // Please don't combine these code, constant array is used here to let // compiler fully unroll the loop to get better performance switch (sizes_and_strides_.size()) { - case 4: - { - int64_t expected = 1; - for (auto& d : {1, 3, 2, 0}) { - const auto size_d = sizes_and_strides_.size_at_unchecked(d); - if (size_d != 1) { - if (sizes_and_strides_.stride_at_unchecked(d) != expected) { - return false; - } - expected *= size_d; + case 4: { + int64_t expected = 1; + for (auto& d : {1, 3, 2, 0}) { + const auto size_d = sizes_and_strides_.size_at_unchecked(d); + if (size_d != 1) { + if (sizes_and_strides_.stride_at_unchecked(d) != expected) { + return false; } + expected *= size_d; } - return true; } + return true; + } // NOLINTNEXTLINE(bugprone-branch-clone) case 3: // TODO dim == 3 case will be enabled once it is fully tested @@ -218,20 +236,19 @@ bool TensorImpl::compute_channels_last_contiguous_3d() const { // compiler fully unroll the loop to get better performance switch (sizes_and_strides_.size()) { // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - case 5: - { - int64_t expected = 1; - for (auto& d : {1, 4, 3, 2, 0}) { - const auto size_d = sizes_and_strides_.size_at_unchecked(d); - if (size_d != 1) { - if (sizes_and_strides_.stride_at_unchecked(d) != expected) { - return false; - } - expected *= size_d; + case 5: { + int64_t expected = 1; + for (auto& d : {1, 4, 3, 2, 0}) { + const auto size_d = sizes_and_strides_.size_at_unchecked(d); + if (size_d != 1) { + if (sizes_and_strides_.stride_at_unchecked(d) != expected) { + return false; } + expected *= size_d; } - return true; } + return true; + } // NOLINTNEXTLINE(bugprone-branch-clone) case 4: // TODO dim == 4 case will be enabled once it is fully tested @@ -242,34 +259,38 @@ bool TensorImpl::compute_channels_last_contiguous_3d() const { } bool TensorImpl::compute_strides_like_channels_last_2d() const { - return is_channels_last_strides_2d(TensorImpl::sizes(), 
TensorImpl::strides()); + return is_channels_last_strides_2d( + TensorImpl::sizes(), TensorImpl::strides()); } bool TensorImpl::compute_strides_like_channels_last_3d() const { - return is_channels_last_strides_3d(TensorImpl::sizes(), TensorImpl::strides()); + return is_channels_last_strides_3d( + TensorImpl::sizes(), TensorImpl::strides()); } bool TensorImpl::compute_non_overlapping_and_dense() const { if (dim() == 1) { - return sizes_and_strides_.size_at_unchecked(0) < 2 || sizes_and_strides_.stride_at_unchecked(0) == 1; + return sizes_and_strides_.size_at_unchecked(0) < 2 || + sizes_and_strides_.stride_at_unchecked(0) == 1; } // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - SmallVector perm; + SmallVector perm; perm.resize(dim()); - for (int64_t i = 0; i < dim(); i ++) { + for (int64_t i = 0; i < dim(); i++) { perm[i] = i; } // Sort by strides, leaving 0 and 1 sized dims at the end of the array std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) { - if (sizes_and_strides_.size_at_unchecked(a) < 2) { - return false; - } else if (sizes_and_strides_.size_at_unchecked(b) < 2) { - return true; - } - return sizes_and_strides_.stride_at_unchecked(a) < sizes_and_strides_.stride_at_unchecked(b); + if (sizes_and_strides_.size_at_unchecked(a) < 2) { + return false; + } else if (sizes_and_strides_.size_at_unchecked(b) < 2) { + return true; + } + return sizes_and_strides_.stride_at_unchecked(a) < + sizes_and_strides_.stride_at_unchecked(b); }); auto require_stride = 1; - for (int64_t i = 0; i < dim(); i ++) { + for (int64_t i = 0; i < dim(); i++) { const auto size_perm_i = sizes_and_strides_.size_at_unchecked(perm[i]); if (size_perm_i < 2) { return true; @@ -312,16 +333,23 @@ bool TensorImpl::has_storage() const { #endif void TensorImpl::throw_storage_access_error() const { - TORCH_CHECK_NOT_IMPLEMENTED(false, "Cannot access storage of ", tensorimpl_type_name()); + TORCH_CHECK_NOT_IMPLEMENTED( + false, "Cannot access storage of ", tensorimpl_type_name()); } -bool TensorImpl::is_contiguous_nondefault_policy_impl(at::MemoryFormat memory_format) const { - if (has_contiguity_ == static_cast(HasContiguityPolicy::ContiguityNotSupported)) { +bool TensorImpl::is_contiguous_nondefault_policy_impl( + at::MemoryFormat memory_format) const { + if (has_contiguity_ == + static_cast(HasContiguityPolicy::ContiguityNotSupported)) { TORCH_CHECK_NOT_IMPLEMENTED( - false, "Tensors of type ", tensorimpl_type_name(), + false, + "Tensors of type ", + tensorimpl_type_name(), " do not have is_contiguous"); } else { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(has_contiguity_ == static_cast(HasContiguityPolicy::CustomBehavior)); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + has_contiguity_ == + static_cast(HasContiguityPolicy::CustomBehavior)); return is_contiguous_custom(memory_format); } } @@ -343,10 +371,11 @@ at::DataPtr PlacementDeleteContext::makeDataPtr( size_t size, at::Device device) { auto* ptr = data_ptr.get(); - return {ptr, - new PlacementDeleteContext(std::move(data_ptr), placement_dtor, size), - &deletePlacementDeleteContext, - device}; + return { + ptr, + new PlacementDeleteContext(std::move(data_ptr), placement_dtor, size), + &deletePlacementDeleteContext, + device}; } // NOLINTNEXTLINE(modernize-use-equals-default) @@ -359,10 +388,14 @@ AutogradMetaInterface::~AutogradMetaInterface() {} // used in C++ frontend. Forbidding it inside InferenceMode will force users // to delete these setter code in their code which is not ideal. 
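A sketch of the user-visible effect of the check reformatted in the next hunk, where set_requires_grad rejects inference tensors outside InferenceMode. This is illustrative only and assumes the full ATen library is linked, since TensorImpl is not constructed directly by users; at::ones and Tensor::requires_grad_ are the ordinary public entry points.

#include <ATen/ATen.h>
#include <c10/core/InferenceMode.h>
#include <c10/util/Exception.h>
#include <iostream>

int main() {
  at::Tensor t;
  {
    c10::InferenceMode guard; // tensors created here are inference tensors
    t = at::ones({2, 2});
  }
  try {
    // Outside InferenceMode this reaches the TORCH_CHECK below and throws.
    t.requires_grad_(true);
  } catch (const c10::Error&) {
    std::cout << "requires_grad=True rejected on inference tensor\n";
  }
}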
void TensorImpl::set_requires_grad(bool requires_grad) { - TORCH_CHECK(!(requires_grad && is_inference_tensor() && !c10::InferenceMode::is_enabled()), - "Setting requires_grad=True on inference tensor outside InferenceMode is not allowed."); - if (!requires_grad && !autograd_meta_) return; - if (!autograd_meta_) autograd_meta_ = impl::GetAutogradMetaFactory()->make(); + TORCH_CHECK( + !(requires_grad && is_inference_tensor() && + !c10::InferenceMode::is_enabled()), + "Setting requires_grad=True on inference tensor outside InferenceMode is not allowed."); + if (!requires_grad && !autograd_meta_) + return; + if (!autograd_meta_) + autograd_meta_ = impl::GetAutogradMetaFactory()->make(); // NB: In principle, setting requires_grad to false could result in // the AutogradMeta becoming equal to a default constructed state, // in which case we could apply the nullptr AutogradMeta optimization @@ -376,11 +409,13 @@ void TensorImpl::set_requires_grad(bool requires_grad) { } bool TensorImpl::requires_grad() const { - if (!autograd_meta_) return false; + if (!autograd_meta_) + return false; return autograd_meta_->requires_grad(); } -void TensorImpl::set_autograd_meta(std::unique_ptr autograd_meta) { +void TensorImpl::set_autograd_meta( + std::unique_ptr autograd_meta) { // NB: autograd_meta may be null! That just means it's the default // constructor autograd_meta_ = std::move(autograd_meta); @@ -396,7 +431,9 @@ c10::intrusive_ptr TensorImpl::shallow_copy_and_detach( bool allow_tensor_metadata_change) const { auto impl = c10::make_intrusive( // No need to populate Storage; copy_tensor_metadata will do it for us. - key_set_, data_type_, device_opt_); + key_set_, + data_type_, + device_opt_); copy_tensor_metadata( /*src_impl=*/this, /*dest_impl=*/impl.get(), @@ -412,7 +449,9 @@ c10::intrusive_ptr TensorImpl::shallow_copy_and_detach( bool allow_tensor_metadata_change) const { auto impl = c10::make_intrusive( // No need to populate Storage; copy_tensor_metadata will do it for us. 
- key_set_, data_type_, device_opt_); + key_set_, + data_type_, + device_opt_); copy_tensor_metadata( /*src_impl=*/this, /*dest_impl=*/impl.get(), @@ -435,15 +474,19 @@ void TensorImpl::copy_tensor_metadata_except_version_counter( dest_impl->key_set_ = src_impl->key_set_; dest_impl->is_contiguous_ = src_impl->is_contiguous_; dest_impl->has_contiguity_ = src_impl->has_contiguity_; - dest_impl->is_channels_last_contiguous_ = src_impl->is_channels_last_contiguous_; - dest_impl->is_channels_last_3d_contiguous_ = src_impl->is_channels_last_3d_contiguous_; + dest_impl->is_channels_last_contiguous_ = + src_impl->is_channels_last_contiguous_; + dest_impl->is_channels_last_3d_contiguous_ = + src_impl->is_channels_last_3d_contiguous_; dest_impl->is_channels_last_ = src_impl->is_channels_last_; dest_impl->is_channels_last_3d_ = src_impl->is_channels_last_3d_; - dest_impl->is_non_overlapping_and_dense_ = src_impl->is_non_overlapping_and_dense_; + dest_impl->is_non_overlapping_and_dense_ = + src_impl->is_non_overlapping_and_dense_; dest_impl->is_wrapped_number_ = src_impl->is_wrapped_number_; dest_impl->reserved_ = src_impl->reserved_; dest_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change); - dest_impl->storage_access_should_throw_ = src_impl->storage_access_should_throw_; + dest_impl->storage_access_should_throw_ = + src_impl->storage_access_should_throw_; if (src_impl->named_tensor_meta_ != nullptr) { dest_impl->named_tensor_meta_ = src_impl->named_tensor_meta_->clone(); } @@ -454,9 +497,11 @@ void TensorImpl::copy_tensor_metadata( TensorImpl* dest_impl, const c10::VariableVersion& version_counter, bool allow_tensor_metadata_change) { - copy_tensor_metadata_except_version_counter(src_impl, dest_impl, allow_tensor_metadata_change); + copy_tensor_metadata_except_version_counter( + src_impl, dest_impl, allow_tensor_metadata_change); // TODO: In the ideal end state, it's okay to set disabled version_counter - // on inference tensor since it's a no-op. This requires refactor on call sites. + // on inference tensor since it's a no-op. This requires refactor on call + // sites. 
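These hunks, and the Note reflowed just below, deal with version counters and their disabled form on inference tensors. A minimal, purely illustrative sketch of c10::VariableVersion's observable behaviour; the class lives in c10/core/TensorImpl.h, and constructing one outside a TensorImpl is done here only for demonstration.

#include <c10/core/TensorImpl.h>
#include <iostream>

int main() {
  // A regular counter starts at the given value and is bumped on every
  // in-place modification of the tensor that owns it.
  c10::VariableVersion vc(/*version=*/0);
  vc.bump();
  vc.bump();
  std::cout << "version=" << vc.current_version() << "\n"; // version=2

  // The cheap DISABLED form carries no counter at all; inference tensors
  // use it, and enabled() reports whether a counter is actually present.
  c10::VariableVersion disabled(c10::VariableVersion::DISABLED);
  std::cout << std::boolalpha << "enabled=" << disabled.enabled() << "\n"; // false
}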
if (!dest_impl->is_inference_tensor()) { dest_impl->set_version_counter(version_counter); } @@ -467,7 +512,8 @@ void TensorImpl::copy_tensor_metadata( TensorImpl* dest_impl, c10::VariableVersion&& version_counter, bool allow_tensor_metadata_change) { - copy_tensor_metadata_except_version_counter(src_impl, dest_impl, allow_tensor_metadata_change); + copy_tensor_metadata_except_version_counter( + src_impl, dest_impl, allow_tensor_metadata_change); if (!dest_impl->is_inference_tensor()) { dest_impl->set_version_counter(std::move(version_counter)); } @@ -478,13 +524,15 @@ namespace impl { namespace { // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) AutogradMetaFactory* meta_factory = nullptr; -} +} // namespace void SetAutogradMetaFactory(AutogradMetaFactory* factory) { meta_factory = factory; } AutogradMetaFactory* GetAutogradMetaFactory() { - TORCH_CHECK(meta_factory, "Support for autograd has not been loaded; have you linked against libtorch.so?") + TORCH_CHECK( + meta_factory, + "Support for autograd has not been loaded; have you linked against libtorch.so?") return meta_factory; } diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 3eeb1efcb1e..b9e704d85ab 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -8,17 +8,17 @@ #include #include #include -#include -#include #include #include #include #include -#include +#include +#include #include #include #include #include +#include #include // A global boolean variable to control whether we free memory when a Tensor @@ -36,7 +36,6 @@ C10_DECLARE_bool(caffe2_keep_on_shrink); // respect caffe2_keep_on_shrink. C10_DECLARE_int64(caffe2_max_keep_on_shrink_memory); - namespace at { class Tensor; } @@ -135,12 +134,19 @@ struct C10_API PlacementDeleteContext { struct TensorImpl; struct C10_API AutogradMetaInterface { - virtual void set_requires_grad(bool requires_grad, at::TensorImpl* self_impl) = 0; + virtual void set_requires_grad( + bool requires_grad, + at::TensorImpl* self_impl) = 0; virtual bool requires_grad() const = 0; virtual at::Tensor& mutable_grad() = 0; virtual const at::Tensor& grad() const = 0; - virtual const at::Tensor& fw_grad(uint64_t level, const at::Tensor& self) const = 0; - virtual void set_fw_grad(const at::Tensor& new_grad, const at::Tensor& self, uint64_t level, bool is_inplace_op) = 0; + virtual const at::Tensor& fw_grad(uint64_t level, const at::Tensor& self) + const = 0; + virtual void set_fw_grad( + const at::Tensor& new_grad, + const at::Tensor& self, + uint64_t level, + bool is_inplace_op) = 0; virtual ~AutogradMetaInterface(); }; @@ -172,25 +178,24 @@ struct C10_API AutogradMetaFactoryRegisterer { } // namespace impl struct C10_API NamedTensorMetaInterface { - virtual ~NamedTensorMetaInterface() {}; + virtual ~NamedTensorMetaInterface(){}; virtual std::unique_ptr clone() const { TORCH_INTERNAL_ASSERT( - false, - "Not implemented: NamedTensorMetaInterface::clone"); + false, "Not implemented: NamedTensorMetaInterface::clone"); }; virtual int64_t slow_dim() const { TORCH_INTERNAL_ASSERT( - false, - "Not implemented: NamedTensorMetaInterface::slow_dim"); + false, "Not implemented: NamedTensorMetaInterface::slow_dim"); }; }; // NOTE [ Version Counter Sharing ] // -// Every Tensor has a version counter. Version counters are incremented whenever the -// data or size of a tensor changes through in-place Variable operations. Version -// counters are used to detect modifications to saved variables which would result in -// incorrect gradient calculations. 
Version counters may be shared between Variables: +// Every Tensor has a version counter. Version counters are incremented whenever +// the data or size of a tensor changes through in-place Variable operations. +// Version counters are used to detect modifications to saved variables which +// would result in incorrect gradient calculations. Version counters may be +// shared between Variables: // // 1. A view shares the version counter of the base Variable, // 2. `x.detach()` shares the version counter of `x`, @@ -198,27 +203,32 @@ struct C10_API NamedTensorMetaInterface { // // Version counters are not shared in these scenarios: // -// 1. When we replace a `Variable`'s underlying `Tensor` by calling `set_data(...)`, +// 1. When we replace a `Variable`'s underlying `Tensor` by calling +// `set_data(...)`, // 2. `x.data` does not share the version counter of `x`. (See discussion at // https://github.com/pytorch/pytorch/issues/5396) // -// Question: Why do we put the version counter in TensorImpl instead of AutogradMeta? +// Question: Why do we put the version counter in TensorImpl instead of +// AutogradMeta? // -// Answer: After the Variable/Tensor merge, a tensor will not have AutogradMeta when -// its `requires_grad_` is false, but when we use this tensor in the forward pass of -// a function that requires saving this tensor for backward, we need to keep track of -// this tensor's version to make sure it's always valid in the autograd graph. +// Answer: After the Variable/Tensor merge, a tensor will not have AutogradMeta +// when its `requires_grad_` is false, but when we use this tensor in the +// forward pass of a function that requires saving this tensor for backward, we +// need to keep track of this tensor's version to make sure it's always valid in +// the autograd graph. // -// To achieve this goal, we put the version counter in TensorImpl instead of AutogradMeta, -// and have it always be available. This allows us to have the optimization of not -// carrying AutogradMeta when a tensor doesn't require gradient. +// To achieve this goal, we put the version counter in TensorImpl instead of +// AutogradMeta, and have it always be available. This allows us to have the +// optimization of not carrying AutogradMeta when a tensor doesn't require +// gradient. // -// A hypothetical alternative way to achieve this goal is to initialize AutogradMeta and -// create the version counter for the non-requires-grad tensor only when it's saved for -// backward. However, since saving a tensor for backward happens in the forward pass, and -// our invariant is that forward pass needs to be thread-safe, lazy-initializing AutogradMeta -// when saving a tensor can introduce race conditions when we are running the forward -// pass in multi-thread scenarios, thus making the forward pass not thread-safe anymore, +// A hypothetical alternative way to achieve this goal is to initialize +// AutogradMeta and create the version counter for the non-requires-grad tensor +// only when it's saved for backward. However, since saving a tensor for +// backward happens in the forward pass, and our invariant is that forward pass +// needs to be thread-safe, lazy-initializing AutogradMeta when saving a tensor +// can introduce race conditions when we are running the forward pass in +// multi-thread scenarios, thus making the forward pass not thread-safe anymore, // which breaks the invariant. 
struct C10_API VariableVersion { private: @@ -238,7 +248,8 @@ struct C10_API VariableVersion { // Example use cases are: // - Inference tensors don't track version counter, so they'll just always // have disbaled VariableVersion. - // - In SavedVariable class we override version_counter_ inside its construtor + // - In SavedVariable class we override version_counter_ inside its + // construtor // so that we can use the cheap constructor there. enum Disabled { DISABLED }; // It's okay to return true even for inference tensor which @@ -254,7 +265,7 @@ struct C10_API VariableVersion { // https://cplusplus.github.io/LWG/issue2334. VariableVersion(uint32_t version) : version_counter_(c10::make_intrusive(version)) {} - VariableVersion(Disabled=DISABLED) {} + VariableVersion(Disabled = DISABLED) {} bool enabled() const { return version_counter_; @@ -283,10 +294,11 @@ struct C10_API VariableVersion { // - e.g. inference_tensor.add_(normal_tensor) void bump() { // TODO: Replace the link to the documentation once it's available. - TORCH_CHECK(version_counter_ || InferenceMode::is_enabled(), - "Inplace update to inference tensor outside InferenceMode is not allowed." - "You can make a clone to get a normal tensor before doing inplace update." - "See https://github.com/pytorch/rfcs/pull/17 for more details."); + TORCH_CHECK( + version_counter_ || InferenceMode::is_enabled(), + "Inplace update to inference tensor outside InferenceMode is not allowed." + "You can make a clone to get a normal tensor before doing inplace update." + "See https://github.com/pytorch/rfcs/pull/17 for more details."); if (version_counter_) { ++version_counter_->version_; } @@ -295,7 +307,8 @@ struct C10_API VariableVersion { // Inference tensor doesn't have version counter so it shouldn't be // accessed. uint32_t current_version() const { - TORCH_CHECK(version_counter_, "Inference tensor do not track version counter."); + TORCH_CHECK( + version_counter_, "Inference tensor do not track version counter."); return version_counter_->version_; } }; @@ -414,7 +427,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { /** * Construct a 1-dim 0 size tensor that doesn't have a storage. */ - TensorImpl(DispatchKeySet, const caffe2::TypeMeta data_type, c10::optional device_opt); + TensorImpl( + DispatchKeySet, + const caffe2::TypeMeta data_type, + c10::optional device_opt); // Legacy constructors so I don't have to go update call sites. // TODO: When Variable is added, delete these constructors @@ -426,15 +442,22 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { std::move(storage), DispatchKeySet(dispatch_key), data_type) {} - TensorImpl(DispatchKey dispatch_key, const caffe2::TypeMeta data_type, c10::optional device_opt) - : TensorImpl(DispatchKeySet(dispatch_key), data_type, device_opt) {} + TensorImpl( + DispatchKey dispatch_key, + const caffe2::TypeMeta data_type, + c10::optional device_opt) + : TensorImpl(DispatchKeySet(dispatch_key), data_type, device_opt) {} private: // This constructor is private, because the data_type is redundant with // storage. Still, we pass it in separately because it's easier to write // the initializer list if we're not worried about storage being moved out // from under us. 
- TensorImpl(Storage&& storage, DispatchKeySet, const caffe2::TypeMeta data_type, c10::optional); + TensorImpl( + Storage&& storage, + DispatchKeySet, + const caffe2::TypeMeta data_type, + c10::optional); public: TensorImpl(const TensorImpl&) = delete; @@ -454,7 +477,9 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * all of the DispatchKeys that this Tensor identifies as. This is the * information used to dispatch operations on this tensor. */ - DispatchKeySet key_set() const { return key_set_; } + DispatchKeySet key_set() const { + return key_set_; + } /** * Return a reference to the sizes of this tensor. This reference remains @@ -466,7 +491,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return sizes_and_strides_.sizes_arrayref(); } #else - ; + ; #endif /** @@ -485,19 +510,21 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return sizes_and_strides_.size(); } #else - ; + ; #endif /** * True if this tensor has storage. See storage() for details. */ #ifdef DEBUG -// Allow subclasses to check that their storage_ is never getting set in debug builds. + // Allow subclasses to check that their storage_ is never getting set in debug + // builds. virtual #else TENSORIMPL_MAYBE_VIRTUAL #endif - bool has_storage() const + bool + has_storage() const // NOTE: we devirtualize this because it arguably shouldn't be an // error just to ask subclasses if they have storage. // This used to throw for most subclasses, but OpaqueTensorImpl @@ -508,7 +535,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return storage_; } #else - ; + ; #endif /** @@ -556,15 +583,17 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * backward compatibility. See `set_has_contiguity_policy` and * `is_contiguous_custom` for the encouraged customization point. */ - TENSORIMPL_MAYBE_VIRTUAL bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const { - if (C10_UNLIKELY(has_contiguity_ != static_cast(HasContiguityPolicy::Default))) { + TENSORIMPL_MAYBE_VIRTUAL bool is_contiguous( + at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const { + if (C10_UNLIKELY( + has_contiguity_ != + static_cast(HasContiguityPolicy::Default))) { return is_contiguous_nondefault_policy_impl(memory_format); } TORCH_INTERNAL_ASSERT_DEBUG_ONLY(compute_contiguous() == is_contiguous_); if (memory_format == at::MemoryFormat::ChannelsLast) { return is_channels_last_contiguous_; - } - else if (memory_format == at::MemoryFormat::ChannelsLast3d) { + } else if (memory_format == at::MemoryFormat::ChannelsLast3d) { return is_channels_last_3d_contiguous_; } return is_contiguous_; @@ -583,33 +612,38 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { public: bool is_sparse() const { - // NB: This method is not virtual and avoid dispatches for performance reasons. + // NB: This method is not virtual and avoid dispatches for performance + // reasons. return key_set_.has(DispatchKey::SparseCPU) || key_set_.has(DispatchKey::SparseCUDA) || key_set_.has(DispatchKey::SparseHIP) || key_set_.has(DispatchKey::SparseXPU); } - // Whether a tensor is sparse COO or not. Use is_sparse_csr for checking CSR format. + // Whether a tensor is sparse COO or not. Use is_sparse_csr for checking CSR + // format. 
bool is_sparse_csr() const { return key_set_.has(DispatchKey::SparseCsrCPU) || - key_set_.has(DispatchKey::SparseCsrCUDA); + key_set_.has(DispatchKey::SparseCsrCUDA); } bool is_quantized() const { - // NB: This method is not virtual and avoid dispatches for performance reasons. + // NB: This method is not virtual and avoid dispatches for performance + // reasons. return key_set_.has(DispatchKey::QuantizedCPU) || key_set_.has(DispatchKey::QuantizedCUDA) || key_set_.has(DispatchKey::QuantizedXPU); } bool is_meta() const { - // NB: This method is not virtual and avoid dispatches for performance reasons. + // NB: This method is not virtual and avoid dispatches for performance + // reasons. return key_set_.has(DispatchKey::Meta); } bool is_cpu() const { - // NB: This method is not virtual and avoid dispatches for performance reasons. + // NB: This method is not virtual and avoid dispatches for performance + // reasons. return key_set_.has(DispatchKey::CPU) || key_set_.has(DispatchKey::SparseCPU) || key_set_.has(DispatchKey::SparseCsrCPU) || @@ -618,7 +652,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { } bool is_cuda() const { - // NB: This method is not virtual and avoid dispatches for performance reasons. + // NB: This method is not virtual and avoid dispatches for performance + // reasons. return key_set_.has(DispatchKey::CUDA) || key_set_.has(DispatchKey::SparseCUDA) || key_set_.has(DispatchKey::SparseCsrCUDA) || @@ -638,9 +673,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { } bool is_hip() const { - // NB: This method is not virtual and avoid dispatches for performance reasons. + // NB: This method is not virtual and avoid dispatches for performance + // reasons. return key_set_.has(DispatchKey::HIP) || - key_set_.has(DispatchKey::SparseHIP); + key_set_.has(DispatchKey::SparseHIP); } bool is_mkldnn() const { @@ -659,7 +695,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return key_set_.has(DispatchKey::MLC); } - // TODO: remove this once we don't automatically enabled Autograd dispatch keys + // TODO: remove this once we don't automatically enabled Autograd dispatch + // keys // in TensorImpl constructor. // DON'T USE THIS API!! 
It's only created for testing purpose in // file aten/src/ATen/core/boxing/impl/test_helpers.h @@ -673,23 +710,20 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { bool is_inference_tensor() { bool no_InplaceOrView = !key_set_.has(c10::DispatchKey::InplaceOrView); bool no_Autograd = (key_set_ & c10::autograd_dispatch_keyset).empty(); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(no_InplaceOrView == no_Autograd, - "InplaceOrView and Autograd keys must be on/off at the same time."); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + no_InplaceOrView == no_Autograd, + "InplaceOrView and Autograd keys must be on/off at the same time."); return no_InplaceOrView && no_Autograd; } int64_t get_device() const { - TORCH_CHECK( - device_opt_.has_value(), - "tensor does not have a device"); + TORCH_CHECK(device_opt_.has_value(), "tensor does not have a device"); // See NOTE [c10::optional operator usage in CUDA] return (*device_opt_).index(); } Device device() const { - TORCH_CHECK( - device_opt_.has_value(), - "tensor does not have a device"); + TORCH_CHECK(device_opt_.has_value(), "tensor does not have a device"); // See NOTE [c10::optional operator usage in CUDA] return *device_opt_; } @@ -795,9 +829,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * - "level" allows to specify the level of forward AD nesting for which the * gradient should be returned. Note that since levels are not fully * supported yet, this argument should be 0. See documentation for - * torch::autograd::enter_dual_level for more details about forward AD nesting. - * - "self" should represent the Tensor whose forward grad is accessed. It is - * required when dealing with view. + * torch::autograd::enter_dual_level for more details about forward AD + * nesting. + * - "self" should represent the Tensor whose forward grad is accessed. It + * is required when dealing with view. */ const at::Tensor& _fw_grad(uint64_t level, const at::Tensor& self) const; @@ -808,18 +843,24 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * This is an internal API that should never be used by end users. * * The API is as follows: - * - "new_grad" is a Tensor containing the new value of the gradient that should - * be set - * - "self" should represent the Tensor whose forward grad is accessed. It is - * required when dealing with view. + * - "new_grad" is a Tensor containing the new value of the gradient that + * should be set + * - "self" should represent the Tensor whose forward grad is accessed. It + * is required when dealing with view. * - "level" allows to specify the level of forward AD nesting for which the * gradient should be set. Note that since levels are not fully supported - * yet, this argument should be 0. See documentation for torch::autograd::enter_dual_level - * for more details about forward AD nesting. - * - "is_inplace_op" is a boolean flag that tells if this gradient was generated - * by an inplace operation or an out of place one. This allows better error checking. + * yet, this argument should be 0. See documentation for + * torch::autograd::enter_dual_level for more details about forward AD + * nesting. + * - "is_inplace_op" is a boolean flag that tells if this gradient was + * generated by an inplace operation or an out of place one. This allows + * better error checking. 
*/ - void _set_fw_grad(const at::Tensor& new_grad, const at::Tensor& self, uint64_t level, bool is_inplace_op); + void _set_fw_grad( + const at::Tensor& new_grad, + const at::Tensor& self, + uint64_t level, + bool is_inplace_op); /** * Return a typed data pointer to the actual data which this tensor refers to. @@ -835,7 +876,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * for you; this class is available from 'Tensor'. */ template - inline T * data() const { + inline T* data() const { TORCH_CHECK( data_type_.Match(), "Tensor type mismatch, caller expects elements to be ", @@ -852,8 +893,9 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * check has_storage() and storage_initialized(). */ template - inline T * data_ptr_impl() const { - TORCH_CHECK(has_storage(), + inline T* data_ptr_impl() const { + TORCH_CHECK( + has_storage(), "Cannot access data pointer of Tensor that doesn't have storage"); TORCH_CHECK( storage_initialized(), @@ -875,9 +917,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * can be validly read from this tensor. */ inline void* data() const { - TORCH_CHECK(has_storage(), + TORCH_CHECK( + has_storage(), "Cannot access data pointer of Tensor that doesn't have storage"); - TORCH_CHECK(dtype_initialized(), + TORCH_CHECK( + dtype_initialized(), "Cannot access data pointer of Tensor that doesn't have initialized dtype " "(e.g., caffe2::Tensor x(CPU), prior to calling mutable_data() on x)"); return static_cast( @@ -890,7 +934,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * that all invariants required by data() are upheld here. */ template - inline T * unsafe_data() const { + inline T* unsafe_data() const { return storage_.unsafe_data() + storage_offset_; } @@ -906,7 +950,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * Return the size of a single element of this tensor in bytes. */ size_t itemsize() const { - TORCH_CHECK(dtype_initialized(), + TORCH_CHECK( + dtype_initialized(), "Cannot report itemsize of Tensor that doesn't have initialized dtype " "(e.g., caffe2::Tensor x(CPU), prior to calling mutable_data() on x)"); return data_type_.itemsize(); @@ -952,7 +997,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * which is harder to misuse. */ virtual void set_size(int64_t dim, int64_t new_size) { - TORCH_CHECK(allow_tensor_metadata_change(), "set_size ", err_msg_tensor_metadata_change_not_allowed); + TORCH_CHECK( + allow_tensor_metadata_change(), + "set_size ", + err_msg_tensor_metadata_change_not_allowed); sizes_and_strides_.size_at(dim) = new_size; refresh_numel(); refresh_contiguous(); @@ -965,7 +1013,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * which is harder to misuse. */ virtual void set_stride(int64_t dim, int64_t new_stride) { - TORCH_CHECK(allow_tensor_metadata_change(), "set_stride ", err_msg_tensor_metadata_change_not_allowed); + TORCH_CHECK( + allow_tensor_metadata_change(), + "set_stride ", + err_msg_tensor_metadata_change_not_allowed); sizes_and_strides_.stride_at_unchecked(dim) = new_stride; refresh_contiguous(); } @@ -978,7 +1029,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * (and resizing if necessary.) 
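// Usage sketch of the typed accessors documented above, via the public
// at::Tensor::data_ptr<T>() wrapper (assumes libtorch's at:: factories).
#include <ATen/ATen.h>

void data_access_demo() {
  at::Tensor t = at::arange(6, at::kFloat).reshape({2, 3});
  float* p = t.data_ptr<float>(); // dtype must match and storage must be initialized
  float first = p[0]; // 0.0f
  // t.data_ptr<int64_t>(); // would throw: "Tensor type mismatch, caller expects ..."
  (void)first;
}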
*/ virtual void set_storage_offset(int64_t storage_offset) { - TORCH_CHECK(allow_tensor_metadata_change(), "set_storage_offset ", err_msg_tensor_metadata_change_not_allowed); + TORCH_CHECK( + allow_tensor_metadata_change(), + "set_storage_offset ", + err_msg_tensor_metadata_change_not_allowed); storage_offset_ = storage_offset; } @@ -990,7 +1044,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * this is the responsibility of the caller */ void set_sizes_contiguous(IntArrayRef new_size) { - TORCH_CHECK(allow_tensor_metadata_change(), "set_sizes_contiguous ", err_msg_tensor_metadata_change_not_allowed); + TORCH_CHECK( + allow_tensor_metadata_change(), + "set_sizes_contiguous ", + err_msg_tensor_metadata_change_not_allowed); sizes_and_strides_.set_sizes(new_size); @@ -1006,7 +1063,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * this is the responsibility of the caller */ void set_sizes_and_strides(IntArrayRef new_size, IntArrayRef new_stride) { - TORCH_CHECK(allow_tensor_metadata_change(), "set_sizes_and_strides ", err_msg_tensor_metadata_change_not_allowed); + TORCH_CHECK( + allow_tensor_metadata_change(), + "set_sizes_and_strides ", + err_msg_tensor_metadata_change_not_allowed); TORCH_CHECK( new_size.size() == new_stride.size(), "dimensionality of sizes (", @@ -1019,7 +1079,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { sizes_and_strides_.set_sizes(new_size); if (new_dim > 0) { - for (size_t dim = new_dim - 1; ; dim--) { + for (size_t dim = new_dim - 1;; dim--) { if (new_stride[dim] >= 0) { sizes_and_strides_.stride_at_unchecked(dim) = new_stride[dim]; } else { @@ -1031,11 +1091,13 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { } else { // Keep stride monotonically increasing to match NumPy. sizes_and_strides_.stride_at_unchecked(dim) = - std::max(sizes_and_strides_.size_at_unchecked(dim + 1), 1) * - sizes_and_strides_.stride_at_unchecked(dim + 1); + std::max( + sizes_and_strides_.size_at_unchecked(dim + 1), 1) * + sizes_and_strides_.stride_at_unchecked(dim + 1); } } - if (dim == 0) break; + if (dim == 0) + break; } } @@ -1054,16 +1116,18 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { virtual int64_t stride(int64_t d) const; /** - * Set whether a tensor allows changes to its metadata (e.g. sizes / strides / storage / storage_offset). - * See NOTE [ Metadata Change for a Detached Tensor ] for details. + * Set whether a tensor allows changes to its metadata (e.g. sizes / strides / + * storage / storage_offset). See NOTE [ Metadata Change for a Detached Tensor + * ] for details. */ void set_allow_tensor_metadata_change(bool value) { allow_tensor_metadata_change_ = value; } /** - * True if a tensor allows changes to its metadata (e.g. sizes / strides / storage / storage_offset). - * See NOTE [ Metadata Change for a Detached Tensor ] for details. + * True if a tensor allows changes to its metadata (e.g. sizes / strides / + * storage / storage_offset). See NOTE [ Metadata Change for a Detached Tensor + * ] for details. */ bool allow_tensor_metadata_change() const { return allow_tensor_metadata_change_; @@ -1072,7 +1136,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { /** * Set the pointer to autograd metadata. */ - void set_autograd_meta(std::unique_ptr autograd_meta); + void set_autograd_meta( + std::unique_ptr autograd_meta); /** * Return the pointer to autograd metadata. 
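// Standalone sketch (hypothetical helper, not a c10 function) of the stride
// back-fill in set_sizes_and_strides() above: the last dim gets stride 1 and each
// earlier dim's stride is the next dim's stride times max(size, 1), keeping
// strides monotonically increasing toward dim 0 (NumPy-style).
#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<int64_t> contiguous_strides(const std::vector<int64_t>& sizes) {
  std::vector<int64_t> strides(sizes.size());
  int64_t running = 1;
  for (int64_t d = static_cast<int64_t>(sizes.size()) - 1; d >= 0; --d) {
    strides[d] = running;
    running *= std::max<int64_t>(sizes[d], 1); // size-0/1 dims don't shrink the stride
  }
  return strides;
}
// contiguous_strides({2, 3, 4}) == {12, 4, 1}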
May return nullptr if the @@ -1083,7 +1148,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { /** * Set the pointer to named tensor metadata. */ - void set_named_tensor_meta(std::unique_ptr named_tensor_meta) { + void set_named_tensor_meta( + std::unique_ptr named_tensor_meta) { TORCH_WARN_ONCE( "Named tensors and all their associated APIs are an experimental feature ", "and subject to change. Please do not use them for anything important ", @@ -1116,34 +1182,44 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return named_tensor_meta_ != nullptr; } - // NOTE [ TensorImpl Shallow-Copying ] // - // TensorImpl shallow-copying is used when we want to have two Variables share the same tensor metadata - // (e.g. sizes / strides / storage pointer / storage_offset), but each with a different autograd history. - // Example call sites: + // TensorImpl shallow-copying is used when we want to have two Variables share + // the same tensor metadata (e.g. sizes / strides / storage pointer / + // storage_offset), but each with a different autograd history. Example call + // sites: // - // 1. `var_detached = var.detach()` uses `shallow_copy_and_detach()` to create `var_detached` that shares - // the same tensor metadata with `var`, but with a completely new autograd history. - // 2. `var.set_data(tensor)` uses `shallow_copy_from()` to copy tensor metadata from - // `tensor` into `var`, while keeping `var`'s original AutogradMeta. + // 1. `var_detached = var.detach()` uses `shallow_copy_and_detach()` to create + // `var_detached` that shares the same tensor metadata with `var`, but with a + // completely new autograd history. + // 2. `var.set_data(tensor)` uses `shallow_copy_from()` to copy tensor + // metadata from `tensor` into `var`, while keeping `var`'s original + // AutogradMeta. // - // Functions that shallow-copy a TensorImpl (such as `shallow_copy_and_detach()` / `shallow_copy_from()` / - // `copy_tensor_metadata()`) copy the tensor metadata fields (e.g. sizes / strides / storage pointer / - // storage_offset) by value. However, the following fields are not copied: + // Functions that shallow-copy a TensorImpl (such as + // `shallow_copy_and_detach()` / `shallow_copy_from()` / + // `copy_tensor_metadata()`) copy the tensor metadata fields (e.g. sizes / + // strides / storage pointer / storage_offset) by value. However, the + // following fields are not copied: // // 1. the AutogradMeta pointer, because it is unique for each Variable. - // 2. the version counter, because the destination TensorImpl's version counter is either set to the - // passed-in `version_counter` (in `shallow_copy_and_detach()` and `copy_tensor_metadata()`), or it is kept - // intact (in `shallow_copy_from()`). See NOTE [ Version Counter Sharing ] for details. + // 2. the version counter, because the destination TensorImpl's version + // counter is either set to the passed-in `version_counter` (in + // `shallow_copy_and_detach()` and `copy_tensor_metadata()`), or it is kept + // intact (in `shallow_copy_from()`). See NOTE [ Version Counter Sharing ] for + // details. // - // In `shallow_copy_and_detach()` and `copy_tensor_metadata()`, the passed-in `allow_tensor_metadata_change` - // determines whether the TensorImpl shallow-copy allows changes to its metadata (e.g. sizes / strides / - // storage / storage_offset). See NOTE [ Metadata Change for a Detached Tensor ] for details. 
+ // In `shallow_copy_and_detach()` and `copy_tensor_metadata()`, the passed-in + // `allow_tensor_metadata_change` determines whether the TensorImpl + // shallow-copy allows changes to its metadata (e.g. sizes / strides / storage + // / storage_offset). See NOTE [ Metadata Change for a Detached Tensor ] for + // details. // - // In `shallow_copy_from()`, we don't check the destination TensorImpl's `allow_tensor_metadata_change_`, - // because `shallow_copy_from()` is used for implementing functions such as `var.set_data(tensor)`, which - // changes `var`'s tensor metadata and expects its `allow_tensor_metadata_change_` to be ignored. + // In `shallow_copy_from()`, we don't check the destination TensorImpl's + // `allow_tensor_metadata_change_`, because `shallow_copy_from()` is used for + // implementing functions such as `var.set_data(tensor)`, which changes + // `var`'s tensor metadata and expects its `allow_tensor_metadata_change_` to + // be ignored. /** * One TensorImpl can be copied to another TensorImpl if they have the same @@ -1161,7 +1237,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { ts.has(DispatchKey::SparseCUDA) || ts.has(DispatchKey::SparseHIP) || ts.has(DispatchKey::SparseXPU); }; - return (key_set_ == from) || (is_dense(key_set_) && is_dense(from)) || (is_sparse(key_set_) && is_sparse(from)); + return (key_set_ == from) || (is_dense(key_set_) && is_dense(from)) || + (is_sparse(key_set_) && is_sparse(from)); } /** @@ -1187,32 +1264,32 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { /** * Shallow-copies data from another TensorImpl into this TensorImpl. * - * For why this function doesn't check this TensorImpl's `allow_tensor_metadata_change_`, - * see NOTE [ TensorImpl Shallow-Copying ]. + * For why this function doesn't check this TensorImpl's + * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ]. */ virtual void shallow_copy_from(const c10::intrusive_ptr& impl) { copy_tensor_metadata( - /*src_impl=*/impl.get(), - /*dest_impl=*/this, - /*version_counter=*/version_counter(), - /*allow_tensor_metadata_change=*/allow_tensor_metadata_change()); + /*src_impl=*/impl.get(), + /*dest_impl=*/this, + /*version_counter=*/version_counter(), + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change()); refresh_numel(); refresh_contiguous(); } // Inference tensor doesn't have version counter, // set_version_counter is no-op for them. - void set_version_counter( - const c10::VariableVersion& version_counter) { - TORCH_CHECK(!(is_inference_tensor() && version_counter.enabled()), - "Cannot set version_counter for inference tensor"); + void set_version_counter(const c10::VariableVersion& version_counter) { + TORCH_CHECK( + !(is_inference_tensor() && version_counter.enabled()), + "Cannot set version_counter for inference tensor"); version_counter_ = version_counter; } - void set_version_counter( - c10::VariableVersion&& version_counter) { - TORCH_CHECK(!(is_inference_tensor() && version_counter.enabled()), - "Cannot set version_counter for inference tensor"); + void set_version_counter(c10::VariableVersion&& version_counter) { + TORCH_CHECK( + !(is_inference_tensor() && version_counter.enabled()), + "Cannot set version_counter for inference tensor"); version_counter_ = std::move(version_counter); } @@ -1241,14 +1318,15 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { } public: - /** * The device type of a Tensor, e.g., DeviceType::CPU or DeviceType::CUDA. 
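// User-level sketch of the two call sites described in the note above (assumes
// the at:: API).
#include <ATen/ATen.h>

void shallow_copy_demo() {
  at::Tensor var = at::rand({2, 2}).set_requires_grad(true);
  // shallow_copy_and_detach(): same sizes/strides/storage, fresh autograd history.
  at::Tensor detached = var.detach();
  bool shares_storage = detached.data_ptr() == var.data_ptr(); // true
  bool tracks_grad = detached.requires_grad(); // false
  // shallow_copy_from(): var.set_data(other) copies other's metadata into var
  // while keeping var's original AutogradMeta.
  (void)shares_storage;
  (void)tracks_grad;
}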
*/ DeviceType device_type() const { // TODO: A useful internal assert would be to show that device_opt_ is null // only if you are an undefined tensor - TORCH_CHECK(device_opt_.has_value(), "device_type cannot be run on undefined Tensor"); + TORCH_CHECK( + device_opt_.has_value(), + "device_type cannot be run on undefined Tensor"); // See NOTE [c10::optional operator usage in CUDA] return (*device_opt_).type(); } @@ -1271,29 +1349,33 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { is_contiguous_, "Right now Extend is only supported for contiguous Tensor."); using SizesVector = SmallVector; - SizesVector newDims(sizes_and_strides_.sizes_begin(), sizes_and_strides_.sizes_end()); + SizesVector newDims( + sizes_and_strides_.sizes_begin(), sizes_and_strides_.sizes_end()); newDims[0] += num; if (!storage_.data()) { Resize(newDims); return; } - const auto newNumel = c10::multiply_integers(newDims.begin(), newDims.end()); + const auto newNumel = + c10::multiply_integers(newDims.begin(), newDims.end()); if (newNumel * data_type_.itemsize() <= storage_.nbytes()) { sizes_and_strides_.set_sizes(newDims); numel_ = newNumel; return; } - SizesVector newCapacity(sizes_and_strides_.sizes_begin(), sizes_and_strides_.sizes_end()); + SizesVector newCapacity( + sizes_and_strides_.sizes_begin(), sizes_and_strides_.sizes_end()); newCapacity[0] = std::max( - newDims[0], static_cast(std::ceil(sizes_and_strides_.size_at_unchecked(0) * (1 + growthPct / 100)))); + newDims[0], + static_cast(std::ceil( + sizes_and_strides_.size_at_unchecked(0) * (1 + growthPct / 100)))); auto oldData = std::move(storage_.data_ptr()); auto oldSize = numel_; Resize(newCapacity); auto* newData = raw_mutable_data(data_type_); if (data_type_.copy()) { TORCH_CHECK( - device_type() == DeviceType::CPU, - "non-POD types work only on CPU"); + device_type() == DeviceType::CPU, "non-POD types work only on CPU"); data_type_.copy()(oldData.get(), newData, oldSize); } else { // The following copy uses the current (thread local) stream for copying @@ -1332,7 +1414,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { TORCH_CHECK( storage_.unique(), "Can't call ReserveSpace on shared storage."); // TODO: eliminate newCapacity. - SmallVector newCapacity(sizes_and_strides_.sizes_begin(), sizes_and_strides_.sizes_end()); + SmallVector newCapacity( + sizes_and_strides_.sizes_begin(), sizes_and_strides_.sizes_end()); newCapacity[0] = outer_dim; auto newNumel = c10::multiply_integers(newCapacity); if (newNumel * data_type_.itemsize() <= storage_.nbytes()) { @@ -1341,7 +1424,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // Old data is discarded storage_.data_ptr().clear(); auto oldSize = numel_; - SmallVector oldDims(sizes_and_strides_.sizes_begin(), sizes_and_strides_.sizes_end()); + SmallVector oldDims( + sizes_and_strides_.sizes_begin(), sizes_and_strides_.sizes_end()); Resize(newCapacity); // Allocate new memory but don't copy over the data raw_mutable_data(data_type_); @@ -1416,7 +1500,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { storage_offset_ = 0; } - /** + /** * @brief Shares the data with another tensor. * * To share data between two tensors, the sizes of the two tensors must be @@ -1441,10 +1525,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // know what to share yet. 
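// Standalone sketch (illustrative helper, not a member) of the outer-dimension
// growth rule Extend() applies above: grow dim 0 by at least growthPct percent,
// and at least to the requested size.
#include <algorithm>
#include <cmath>
#include <cstdint>

int64_t grown_outer_dim(int64_t current_dim0, int64_t needed_dim0, double growthPct) {
  return std::max(
      needed_dim0,
      static_cast<int64_t>(std::ceil(current_dim0 * (1 + growthPct / 100))));
}
// grown_outer_dim(/*current=*/100, /*needed=*/101, /*growthPct=*/40) == 140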
// TODO: Add the assert after all uninitialized states are eliminated // TORCH_CHECK(src.dtype_initialized(), - // "Source tensor don't have a data type (did you call mutable_data on the tensor?)"); + // "Source tensor don't have a data type (did you call + // mutable_data on the tensor?)"); if (!src.dtype_initialized()) { - C10_LOG_EVERY_MS(WARNING, 1000) << - "Source tensor don't have a data type (did you call mutable_data on the tensor?)"; + C10_LOG_EVERY_MS(WARNING, 1000) + << "Source tensor don't have a data type (did you call mutable_data on the tensor?)"; } TORCH_CHECK( src.storage_initialized(), @@ -1504,7 +1589,9 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { inline void* raw_mutable_data(const caffe2::TypeMeta meta) { // For 0-size tensors it's fine to return any pointer (including nullptr) if (data_type_ == meta && storage_initialized()) { - return static_cast(static_cast(storage_.data()) + storage_offset_ * meta.itemsize()); + return static_cast( + static_cast(storage_.data()) + + storage_offset_ * meta.itemsize()); } else { bool had_special_dtor = data_type_.placementDelete() != nullptr; storage_offset_ = 0; @@ -1517,7 +1604,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { if (numel_ == 0 || (meta.placementNew() == nullptr && !had_special_dtor && (storage_.nbytes() >= (numel_ * data_type_.itemsize())))) { - TORCH_INTERNAL_ASSERT(storage_offset_ == 0); // because we just reallocated + TORCH_INTERNAL_ASSERT( + storage_offset_ == 0); // because we just reallocated return storage_.data(); } const Allocator* allocator = storage_.allocator(); @@ -1544,7 +1632,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { allocator->allocate(numel_ * data_type_.itemsize())); } storage_.set_nbytes(numel_ * data_type_.itemsize()); - TORCH_INTERNAL_ASSERT(storage_offset_ == 0); // because we just reallocated + TORCH_INTERNAL_ASSERT( + storage_offset_ == 0); // because we just reallocated device_opt_ = storage_.device(); return storage_.data(); } @@ -1574,7 +1663,9 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * storage UNINITIALIZED after a Resize() or FreeMemory() */ bool storage_initialized() const { - TORCH_CHECK(has_storage(), "cannot call storage_initialized on tensor that does not have storage"); + TORCH_CHECK( + has_storage(), + "cannot call storage_initialized on tensor that does not have storage"); return storage_.data() || numel_ == 0; } @@ -1588,7 +1679,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { } void set_storage_keep_dtype(at::Storage storage) { - TORCH_CHECK(allow_tensor_metadata_change(), "set_storage ", err_msg_tensor_metadata_change_not_allowed); + TORCH_CHECK( + allow_tensor_metadata_change(), + "set_storage ", + err_msg_tensor_metadata_change_not_allowed); storage_ = std::move(storage); device_opt_ = storage_.device(); } @@ -1603,15 +1697,16 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { /** * Set the strides of the tensor to match memory_format * - * WARNING: This function doesn't rearrange data and assumes tensor is a memory - * contiguous + * WARNING: This function doesn't rearrange data and assumes tensor is a + * memory contiguous */ void empty_tensor_restride(MemoryFormat memory_format) { - #ifdef DEBUG - TORCH_INTERNAL_ASSERT(compute_numel() == numel_, +#ifdef DEBUG + TORCH_INTERNAL_ASSERT( + compute_numel() == numel_, "If you are seeing this error, that means empty_tensor_restride was " "called before setting correct numel"); - #endif +#endif switch 
(memory_format) { case MemoryFormat::Contiguous: { // dim_ is a virtual call, don't repeat it @@ -1621,15 +1716,17 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { const auto last_idx = dim_ - 1; sizes_and_strides_.stride_at_unchecked(last_idx) = 1; for (auto i = last_idx - 1; i >= 0; --i) { - sizes_and_strides_.stride_at_unchecked(i) = sizes_and_strides_.stride_at_unchecked(i + 1) * std::max(sizes_and_strides_.size_at_unchecked(i + 1), 1); + sizes_and_strides_.stride_at_unchecked(i) = + sizes_and_strides_.stride_at_unchecked(i + 1) * + std::max( + sizes_and_strides_.size_at_unchecked(i + 1), 1); } } break; } case MemoryFormat::ChannelsLast: { TORCH_CHECK( - dim() == 4, - "required rank 4 tensor to use channels_last format"); + dim() == 4, "required rank 4 tensor to use channels_last format"); set_sizes_and_strides(sizes(), get_channels_last_strides_2d(sizes())); break; } @@ -1663,7 +1760,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return is_non_overlapping_and_dense_; } -private: + private: void HandleResize(); // The Caffe2 Resize() method supports being called both as Resize({2,2}) as @@ -1719,7 +1816,11 @@ private: return SetDims(IntArrayRef{d0, d1, d2}); } - bool SetDims(const int64_t d0, const int64_t d1, const int64_t d2, const int64_t d3) { + bool SetDims( + const int64_t d0, + const int64_t d1, + const int64_t d2, + const int64_t d3) { return SetDims(IntArrayRef{d0, d1, d2, d3}); } @@ -1750,7 +1851,7 @@ private: bool compute_non_overlapping_and_dense() const; -protected: + protected: /** * Recompute the cached numel of a tensor. Call this if you modify sizes. */ @@ -1766,41 +1867,51 @@ protected: is_contiguous_ = compute_contiguous(); // Note: // Dim 0, 1, 2 will never be a channels last 2d/3d format - // Dim 3+ is possibly be a channels last 2d format (Dim 4 only at this point) - // Dim 4+ is possibly be a channels last 3d format (Dim 5 only at this point) + // Dim 3+ is possibly be a channels last 2d format (Dim 4 only at this + // point) Dim 4+ is possibly be a channels last 3d format (Dim 5 only at + // this point) switch (dim()) { case 4: is_channels_last_contiguous_ = compute_channels_last_contiguous_2d(); is_channels_last_3d_contiguous_ = false; is_channels_last_ = compute_strides_like_channels_last_2d(); is_channels_last_3d_ = false; - is_non_overlapping_and_dense_ = is_contiguous_ || is_channels_last_contiguous_ || compute_non_overlapping_and_dense(); + is_non_overlapping_and_dense_ = is_contiguous_ || + is_channels_last_contiguous_ || compute_non_overlapping_and_dense(); break; case 5: is_channels_last_contiguous_ = compute_channels_last_contiguous_2d(); - is_channels_last_3d_contiguous_ = !is_channels_last_contiguous_ && compute_channels_last_contiguous_3d(); - is_channels_last_ = !is_channels_last_3d_contiguous_ && compute_strides_like_channels_last_2d(); - is_channels_last_3d_ = !is_channels_last_ && compute_strides_like_channels_last_3d(); - is_non_overlapping_and_dense_ = is_contiguous_ || is_channels_last_contiguous_ || is_channels_last_3d_contiguous_|| compute_non_overlapping_and_dense(); + is_channels_last_3d_contiguous_ = !is_channels_last_contiguous_ && + compute_channels_last_contiguous_3d(); + is_channels_last_ = !is_channels_last_3d_contiguous_ && + compute_strides_like_channels_last_2d(); + is_channels_last_3d_ = + !is_channels_last_ && compute_strides_like_channels_last_3d(); + is_non_overlapping_and_dense_ = is_contiguous_ || + is_channels_last_contiguous_ || is_channels_last_3d_contiguous_ || + 
compute_non_overlapping_and_dense(); break; default: is_channels_last_contiguous_ = false; is_channels_last_3d_contiguous_ = false; - // is_channels_last_ and is_channels_last_3d_ are suggested memory_format. - // Being channels_last_contiguous doesn't necessarily mean the tensor is - // strided like channels_last: for strides on channel dimension could suggest - // desired memory_layout, but it doesn't affect memory storage + // is_channels_last_ and is_channels_last_3d_ are suggested + // memory_format. Being channels_last_contiguous doesn't necessarily + // mean the tensor is strided like channels_last: for strides on channel + // dimension could suggest desired memory_layout, but it doesn't affect + // memory storage is_channels_last_ = false; is_channels_last_3d_ = false; - is_non_overlapping_and_dense_ = is_contiguous_ || compute_non_overlapping_and_dense(); + is_non_overlapping_and_dense_ = + is_contiguous_ || compute_non_overlapping_and_dense(); } } /** - * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / storage_offset) - * from one TensorImpl to another TensorImpl. + * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / + * storage_offset) from one TensorImpl to another TensorImpl. * - * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE [ TensorImpl Shallow-Copying ]. + * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE + * [ TensorImpl Shallow-Copying ]. */ static void copy_tensor_metadata( const TensorImpl* src_impl, @@ -1809,10 +1920,11 @@ protected: bool allow_tensor_metadata_change); /** - * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / storage_offset) - * from one TensorImpl to another TensorImpl. + * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / + * storage_offset) from one TensorImpl to another TensorImpl. * - * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE [ TensorImpl Shallow-Copying ]. + * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE + * [ TensorImpl Shallow-Copying ]. */ static void copy_tensor_metadata( const TensorImpl* src_impl, @@ -1820,25 +1932,25 @@ protected: c10::VariableVersion&& version_counter, bool allow_tensor_metadata_change); -private: + private: static void copy_tensor_metadata_except_version_counter( const TensorImpl* src_impl, TensorImpl* dest_impl, bool allow_tensor_metadata_change); -protected: + protected: // Error message to show when the user tries to change tensor metadata on // Tensor created from .data or .detach(). // // See NOTE [ Metadata Change for a Detached Tensor ] for details. - static const char * const err_msg_tensor_metadata_change_not_allowed; + static const char* const err_msg_tensor_metadata_change_not_allowed; -public: + public: void set_storage_access_should_throw() { storage_access_should_throw_ = true; } -protected: + protected: // Policy for adjusting the behavior of is_contiguous(). Allows // subclass customization while still being able to inline // is_contiguous() in the common case. @@ -1859,11 +1971,10 @@ protected: Storage storage_; -private: - // This pointer points to an AutogradMeta struct that stores autograd-specific fields - // (such as grad_ / grad_fn_ / grad_accumulator_). - // This pointer always has unique ownership (meaning only one TensorImpl can own it - // at a time). 
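// Sketch of the channels-last (2d) layout the flags above track: for NCHW sizes,
// get_channels_last_strides_2d makes C the fastest-varying dimension, i.e. the
// data is laid out as NHWC while the logical dim order stays NCHW. Simplified
// illustration; the real helper also special-cases size-0/1 dims.
#include <array>
#include <cstdint>

std::array<int64_t, 4> channels_last_strides_2d_sketch(
    const std::array<int64_t, 4>& s /* NCHW sizes */) {
  const int64_t C = s[1], H = s[2], W = s[3];
  return {C * H * W, 1, W * C, C}; // strides for N, C, H, W
}
// For sizes {2, 3, 4, 5} the strides are {60, 1, 15, 3}.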
+ private: + // This pointer points to an AutogradMeta struct that stores autograd-specific + // fields (such as grad_ / grad_fn_ / grad_accumulator_). This pointer always + // has unique ownership (meaning only one TensorImpl can own it at a time). // // autograd_meta_ can be nullptr, as an optimization. When this occurs, it is // equivalent to having an autograd_meta_ pointing to a default constructed @@ -1886,7 +1997,7 @@ private: // std::unique_ptr autograd_meta_ = nullptr; -protected: + protected: std::unique_ptr named_tensor_meta_ = nullptr; c10::VariableVersion version_counter_; @@ -1941,7 +2052,8 @@ protected: // Tensor is a subclass that does not permit storage access. bool storage_access_should_throw_ = false; - // default member initializers for bit-fields only available with -std=c++2a or -std=gnu++2a + // default member initializers for bit-fields only available with -std=c++2a + // or -std=gnu++2a inline void init_bitfields() { is_contiguous_ = true; has_contiguity_ = static_cast(HasContiguityPolicy::Default); @@ -1967,18 +2079,18 @@ protected: bool is_channels_last_contiguous_ : 1; // Tensor is stored in the channels last 3d memory format, when dimensions - // order is (N)CDHW and C-strides < W-strides < H-strides < D - strides (< N-strides) - // (If size of any dimension is equal to 1, this dimension strides value - // is not taken into account). + // order is (N)CDHW and C-strides < W-strides < H-strides < D - strides (< + // N-strides) (If size of any dimension is equal to 1, this dimension strides + // value is not taken into account). bool is_channels_last_3d_ : 1; // Channels last 3d contiguous tensor is channel last 3d tensor which occupies // contiguous memory block. bool is_channels_last_3d_contiguous_ : 1; - // Dense tensor is the tensor that store values in a contiguous block of memory. - // Non-overlapping tensor is the tensor in which elements occupy individual - // non-repetitive memory. + // Dense tensor is the tensor that store values in a contiguous block of + // memory. Non-overlapping tensor is the tensor in which elements occupy + // individual non-repetitive memory. bool is_non_overlapping_and_dense_ : 1; bool is_wrapped_number_ : 1; @@ -2010,7 +2122,8 @@ protected: // does NOT include Autograd (historically, it did, but // not anymore!) // - // INVARIANT: named_tensor_meta_ != nullptr <==> key_set_.has(DispatchKey::Named) + // INVARIANT: named_tensor_meta_ != nullptr <==> + // key_set_.has(DispatchKey::Named) DispatchKeySet key_set_; }; @@ -2063,8 +2176,9 @@ protected: // data type, device, is_contiguous, storage_access_should_throw_, bitfields // DispatchKeySet // -static_assert(sizeof(void*) != sizeof(int64_t) || // if 64-bit... - sizeof(TensorImpl) == sizeof(int64_t) * 23, - "You changed the size of TensorImpl on 64-bit arch." - "See Note [TensorImpl size constraints] on how to proceed."); +static_assert( + sizeof(void*) != sizeof(int64_t) || // if 64-bit... + sizeof(TensorImpl) == sizeof(int64_t) * 23, + "You changed the size of TensorImpl on 64-bit arch." + "See Note [TensorImpl size constraints] on how to proceed."); } // namespace c10 diff --git a/c10/core/TensorOptions.cpp b/c10/core/TensorOptions.cpp index e1fdc5b39de..a9380f51392 100644 --- a/c10/core/TensorOptions.cpp +++ b/c10/core/TensorOptions.cpp @@ -17,18 +17,20 @@ namespace c10 { // internal state and what its getters will return. 
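// Sketch of the size-budget pattern used by the static_assert above: pin the
// struct's size on 64-bit builds so an accidentally added member fails to
// compile. The type and budget here are illustrative only.
#include <cstdint>

struct ExampleImpl {
  void* some_pointer;
  int64_t fields[3];
};
static_assert(
    sizeof(void*) != sizeof(int64_t) || // only enforce on 64-bit
        sizeof(ExampleImpl) == sizeof(int64_t) * 4,
    "You changed the size of ExampleImpl on 64-bit arch; update the budget.");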
std::ostream& operator<<(std::ostream& stream, const TensorOptions& options) { - - auto print = [&](const char *label, auto prop, bool has_prop) { + auto print = [&](const char* label, auto prop, bool has_prop) { stream << label << std::boolalpha << prop << (has_prop ? "" : " (default)"); }; print("TensorOptions(dtype=", options.dtype(), options.has_dtype()); print(", device=", options.device(), options.has_device()); print(", layout=", options.layout(), options.has_layout()); - print(", requires_grad=", options.requires_grad(), options.has_requires_grad()); - print(", pinned_memory=", options.pinned_memory(), options.has_pinned_memory()); + print( + ", requires_grad=", options.requires_grad(), options.has_requires_grad()); + print( + ", pinned_memory=", options.pinned_memory(), options.has_pinned_memory()); - // note: default-supplying memory_format() getter not provided; no canonical default + // note: default-supplying memory_format() getter not provided; no canonical + // default stream << ", memory_format="; if (options.has_memory_format()) { stream << *options.memory_format_opt(); diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h index 4f750de0fff..06ddb2f008a 100644 --- a/c10/core/TensorOptions.h +++ b/c10/core/TensorOptions.h @@ -1,17 +1,17 @@ #pragma once -#include #include +#include +#include +#include #include +#include #include #include -#include -#include -#include -#include -#include #include +#include +#include #include #include @@ -19,14 +19,18 @@ namespace c10 { -DispatchKey computeDispatchKey(c10::optional dtype, c10::optional layout, c10::optional device); +DispatchKey computeDispatchKey( + c10::optional dtype, + c10::optional layout, + c10::optional device); inline ScalarType dtype_or_default(c10::optional dtype) { - return value_or_else(dtype, [] {return get_default_dtype_as_scalartype();}); + return value_or_else(dtype, [] { return get_default_dtype_as_scalartype(); }); } -inline caffe2::TypeMeta dtype_or_default(c10::optional dtype) { - return value_or_else(dtype, [] {return get_default_dtype();}); +inline caffe2::TypeMeta dtype_or_default( + c10::optional dtype) { + return value_or_else(dtype, [] { return get_default_dtype(); }); } inline Layout layout_or_default(c10::optional layout) { @@ -34,7 +38,7 @@ inline Layout layout_or_default(c10::optional layout) { } inline Device device_or_default(c10::optional device) { - return value_or_else(device, [] {return Device(kCPU);}); + return value_or_else(device, [] { return Device(kCPU); }); } inline bool pinned_memory_or_default(c10::optional pinned_memory) { @@ -65,7 +69,8 @@ inline bool pinned_memory_or_default(c10::optional pinned_memory) { /// at::dtype(at::kInt) /// /// Additionally, anywhere a TensorOptions is expected, you can directly -/// pass at::kCUDA / at::kInt, and it will implicitly convert to a TensorOptions. +/// pass at::kCUDA / at::kInt, and it will implicitly convert to a +/// TensorOptions. /// /// Here are some recommended ways to create a 2x2 tensor of zeros /// with certain properties. These all *implicitly* make use of @@ -108,7 +113,8 @@ inline bool pinned_memory_or_default(c10::optional pinned_memory) { /// } /// /// template ::value>> +/// typename = std::enable_if_t::value>> /// /* implicit */ TensorOptions(Args&&... args) /// : TensorOptions(Device(std::forward(args)...)) {} /// @@ -121,20 +127,21 @@ inline bool pinned_memory_or_default(c10::optional pinned_memory) { /// To get around this, we templatize the `Device` constructor. 
Since overload /// resolution is done before template resolution, our problem is solved. -DispatchKey computeDispatchKey(optional dtype, optional layout, optional device); - +DispatchKey computeDispatchKey( + optional dtype, + optional layout, + optional device); struct C10_API TensorOptions { TensorOptions() - : requires_grad_(false) - , pinned_memory_(false) - , has_device_(false) - , has_dtype_(false) - , has_layout_(false) - , has_requires_grad_(false) - , has_pinned_memory_(false) - , has_memory_format_(false) - {} + : requires_grad_(false), + pinned_memory_(false), + has_device_(false), + has_dtype_(false), + has_layout_(false), + has_requires_grad_(false), + has_pinned_memory_(false), + has_memory_format_(false) {} /// Constructs a `TensorOptions` object with the given layout. /* implicit */ TensorOptions(Layout layout) : TensorOptions() { @@ -143,8 +150,9 @@ struct C10_API TensorOptions { /// Constructs a `TensorOptions` object with the given device. /// See NOTE [ TensorOptions Constructors ] on why this is templatized. - template, Device>::value>> + template < + typename T, + typename = std::enable_if_t, Device>::value>> /* implicit */ TensorOptions(T&& device) : TensorOptions() { this->set_device(std::forward(device)); } @@ -157,10 +165,12 @@ struct C10_API TensorOptions { /// NB: Ideally we only allow implicit constructors here. But there is no easy /// way to detect them. So we have this one that allows explicit /// constructors too. - template ::value>> - /* implicit */ TensorOptions(Args&&... args) - : TensorOptions(Device(std::forward(args)...)) {} + template < + typename... Args, + typename = + std::enable_if_t::value>> + /* implicit */ TensorOptions(Args&&... args) + : TensorOptions(Device(std::forward(args)...)) {} /// Constructs a `TensorOptions` object with the given dtype. /* implicit */ TensorOptions(caffe2::TypeMeta dtype) : TensorOptions() { @@ -179,7 +189,8 @@ struct C10_API TensorOptions { /// Return a copy of `TensorOptions` with `device` set to the given one, or /// cleared if `device` is `nullopt`. - C10_NODISCARD TensorOptions device(c10::optional device) const noexcept { + C10_NODISCARD TensorOptions + device(c10::optional device) const noexcept { TensorOptions r = *this; r.set_device(device); return r; @@ -188,9 +199,10 @@ struct C10_API TensorOptions { /// Return a copy of `TensorOptions` with `device` set to the given one. /// (This overload ensures that variadic template c10::optional constructor /// for Device work correctly.) - template + template C10_NODISCARD TensorOptions device(Args&&... args) const noexcept { - return device(c10::optional(c10::in_place, std::forward(args)...)); + return device( + c10::optional(c10::in_place, std::forward(args)...)); } /// Return a copy of `TensorOptions`, but with device set to CUDA, and the @@ -198,19 +210,22 @@ struct C10_API TensorOptions { /// /// TODO: This function encourages bad behavior (assuming CUDA is /// the only device that matters). Get rid of it / rename it. - C10_NODISCARD TensorOptions device_index(int16_t device_index) const noexcept { + C10_NODISCARD TensorOptions + device_index(int16_t device_index) const noexcept { return device(Device::Type::CUDA, device_index); } /// Return a copy of `TensorOptions` with `dtype` set to the given one. 
- C10_NODISCARD TensorOptions dtype(c10::optional dtype) const noexcept { + C10_NODISCARD TensorOptions + dtype(c10::optional dtype) const noexcept { TensorOptions r = *this; r.set_dtype(dtype); return r; } // legacy function to support ScalarType - C10_NODISCARD TensorOptions dtype(c10::optional dtype) const noexcept { + C10_NODISCARD TensorOptions + dtype(c10::optional dtype) const noexcept { TensorOptions r = *this; r.set_dtype(dtype); return r; @@ -225,28 +240,32 @@ struct C10_API TensorOptions { } /// Sets the layout of the `TensorOptions`. - C10_NODISCARD TensorOptions layout(c10::optional layout) const noexcept { + C10_NODISCARD TensorOptions + layout(c10::optional layout) const noexcept { TensorOptions r = *this; r.set_layout(layout); return r; } /// Sets the `requires_grad` property of the `TensorOptions`. - C10_NODISCARD TensorOptions requires_grad(c10::optional requires_grad) const noexcept { + C10_NODISCARD TensorOptions + requires_grad(c10::optional requires_grad) const noexcept { TensorOptions r = *this; r.set_requires_grad(requires_grad); return r; } /// Sets the `pinned_memory` property on the `TensorOptions`. - C10_NODISCARD TensorOptions pinned_memory(c10::optional pinned_memory) const noexcept { + C10_NODISCARD TensorOptions + pinned_memory(c10::optional pinned_memory) const noexcept { TensorOptions r = *this; r.set_pinned_memory(pinned_memory); return r; } /// Sets the `memory_format` property on `TensorOptions`. - C10_NODISCARD TensorOptions memory_format(c10::optional memory_format) const noexcept { + C10_NODISCARD TensorOptions + memory_format(c10::optional memory_format) const noexcept { TensorOptions r = *this; r.set_memory_format(memory_format); return r; @@ -343,13 +362,15 @@ struct C10_API TensorOptions { // For compatibility with legacy tensor.type() comparisons bool type_equal(const TensorOptions& other) const { - return computeDispatchKey() == other.computeDispatchKey() && typeMetaToScalarType(dtype_) == typeMetaToScalarType(other.dtype()); + return computeDispatchKey() == other.computeDispatchKey() && + typeMetaToScalarType(dtype_) == typeMetaToScalarType(other.dtype()); } /// Returns the `pinned_memory` property of the `TensorOptions`, or /// `c10::nullopt` if `pinned_memory` is not specified. c10::optional pinned_memory_opt() const noexcept { - return has_pinned_memory_ ? c10::make_optional(pinned_memory_) : c10::nullopt; + return has_pinned_memory_ ? c10::make_optional(pinned_memory_) + : c10::nullopt; } /// Returns whether the `memory_layout` is specified @@ -363,7 +384,8 @@ struct C10_API TensorOptions { /// Returns the `memory_layout` property of `TensorOptions, or /// `c10::nullopt` if `memory_format` is not specified. c10::optional memory_format_opt() const noexcept { - return has_memory_format_ ? c10::make_optional(memory_format_) : c10::nullopt; + return has_memory_format_ ? c10::make_optional(memory_format_) + : c10::nullopt; } // Resolves the ATen backend specified by the current construction axes. 
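// Usage sketch of the fluent, value-returning setters above (assumes the at::
// factory functions and the dtype()/device() convenience helpers from this header).
#include <ATen/ATen.h>

void tensor_options_demo() {
  at::TensorOptions opts =
      at::TensorOptions().dtype(at::kFloat).layout(at::kStrided).device(at::kCPU);
  at::Tensor a = at::zeros({2, 2}, opts);
  // Shorthand: a ScalarType or Device implicitly converts to TensorOptions,
  // and the free function at::dtype(...) starts a chain directly.
  at::Tensor b = at::zeros({2, 2}, at::kInt);
  at::Tensor c = at::zeros({2, 2}, at::dtype(at::kHalf).device(at::kCPU));
  (void)a;
  (void)b;
  (void)c;
}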
@@ -384,18 +406,25 @@ struct C10_API TensorOptions { /// TensorOptions merge_in(TensorOptions options) const noexcept { TensorOptions merged = *this; - if (options.has_device()) merged.set_device(options.device_opt()); - if (options.has_dtype()) merged.set_dtype(options.dtype_opt()); - if (options.has_layout()) merged.set_layout(options.layout_opt()); + if (options.has_device()) + merged.set_device(options.device_opt()); + if (options.has_dtype()) + merged.set_dtype(options.dtype_opt()); + if (options.has_layout()) + merged.set_layout(options.layout_opt()); // NB: requires grad is right biased; not a logical AND/OR! - if (options.has_requires_grad()) merged.set_requires_grad(options.requires_grad_opt()); - if (options.has_pinned_memory()) merged.set_pinned_memory(options.pinned_memory_opt()); - if (options.has_memory_format()) merged.set_memory_format(options.memory_format_opt()); + if (options.has_requires_grad()) + merged.set_requires_grad(options.requires_grad_opt()); + if (options.has_pinned_memory()) + merged.set_pinned_memory(options.pinned_memory_opt()); + if (options.has_memory_format()) + merged.set_memory_format(options.memory_format_opt()); return merged; } // TODO remove after TensorOptions rationalization - TensorOptions merge_memory_format(c10::optional optional_memory_format) const noexcept { + TensorOptions merge_memory_format( + c10::optional optional_memory_format) const noexcept { TensorOptions merged = *this; if (optional_memory_format.has_value()) { merged.set_memory_format(*optional_memory_format); @@ -408,11 +437,11 @@ struct C10_API TensorOptions { // the most part, this just means that this function never returns an // Autograd key) DispatchKey computeDispatchKey() const { - return c10::computeDispatchKey(optTypeMetaToScalarType(dtype_opt()), layout_opt(), device_opt()); + return c10::computeDispatchKey( + optTypeMetaToScalarType(dtype_opt()), layout_opt(), device_opt()); } private: - // These methods are currently private because I'm not sure if it's wise // to actually publish them. They are methods because I need them in // the constructor and the functional API implementation. @@ -515,13 +544,12 @@ struct C10_API TensorOptions { // Bitmask required here to get this to fit inside 32 bits (or even 64 bits, // for that matter) - bool requires_grad_ : 1; - bool pinned_memory_ : 1; + bool requires_grad_ : 1; + bool pinned_memory_ : 1; - - bool has_device_ : 1; - bool has_dtype_ : 1; - bool has_layout_ : 1; + bool has_device_ : 1; + bool has_dtype_ : 1; + bool has_layout_ : 1; bool has_requires_grad_ : 1; bool has_pinned_memory_ : 1; bool has_memory_format_ : 1; @@ -530,8 +558,9 @@ struct C10_API TensorOptions { // We should aspire to fit in one machine-size word; but a size greater than two // words is too much. (We are doing terribly on 32-bit archs, where we require // three machine size words to store tensor options. Eek!) -static_assert( sizeof(TensorOptions) <= sizeof(int64_t) * 2, - "TensorOptions must fit in 128-bits" ); +static_assert( + sizeof(TensorOptions) <= sizeof(int64_t) * 2, + "TensorOptions must fit in 128-bits"); /// Convenience function that returns a `TensorOptions` object with the `dtype` /// set to the given one. @@ -591,88 +620,106 @@ inline std::string toString(const TensorOptions options) { // This is intended to be a centralized location by which we can determine // what an appropriate DispatchKey for a tensor is. 
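// Usage sketch of merge_in() above: fields set on the argument win, unset fields
// fall through from *this (requires_grad / pinned_memory are right-biased too,
// as the comment notes).
#include <ATen/ATen.h>

void merge_demo() {
  at::TensorOptions defaults = at::dtype(at::kFloat).device(at::kCPU);
  at::TensorOptions overrides = at::TensorOptions().dtype(at::kHalf); // device unset
  at::TensorOptions merged = defaults.merge_in(overrides);
  // merged: dtype is Half (overridden), device stays CPU (inherited).
  (void)merged;
}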
-inline DispatchKey computeDispatchKey(c10::optional dtype, c10::optional layout, c10::optional device) { +inline DispatchKey computeDispatchKey( + c10::optional dtype, + c10::optional layout, + c10::optional device) { const auto layout_ = layout_or_default(layout); const auto device_ = device_or_default(device); switch (layout_) { - case Layout::Strided: { - const auto dtype_ = dtype_or_default(dtype); - switch (device_.type()) { - case DeviceType::CPU: { - if (isQIntType(dtype_)) { - return DispatchKey::QuantizedCPU; - } - return DispatchKey::CPU; + case Layout::Strided: { + const auto dtype_ = dtype_or_default(dtype); + switch (device_.type()) { + case DeviceType::CPU: { + if (isQIntType(dtype_)) { + return DispatchKey::QuantizedCPU; } - case DeviceType::CUDA: { - if (isQIntType(dtype_)) { - return DispatchKey::QuantizedCUDA; - } - return DispatchKey::CUDA; - } - case DeviceType::XPU: { - if (isQIntType(dtype_)) { - return DispatchKey::QuantizedXPU; - } - return DispatchKey::XPU; - } - case DeviceType::MKLDNN: - case DeviceType::OPENGL: - case DeviceType::OPENCL: - case DeviceType::IDEEP: - TORCH_INTERNAL_ASSERT(0, "This is a grandfathered Caffe2 device type ", device_.type(), ", it shouldn't ever convert to a DispatchKey. File a bug describing what you were doing if you think this is in error."); - case DeviceType::HIP: - return DispatchKey::HIP; - case DeviceType::FPGA: - return DispatchKey::FPGA; - case DeviceType::MSNPU: - return DispatchKey::MSNPU; - case DeviceType::XLA: - return DispatchKey::XLA; - case DeviceType::MLC: - return DispatchKey::MLC; - case DeviceType::Vulkan: - return DispatchKey::Vulkan; - case DeviceType::Metal: - return DispatchKey::Metal; - case DeviceType::Meta: - return DispatchKey::Meta; - default: - TORCH_CHECK_NOT_IMPLEMENTED(false, "Unsupported device type for dense layout: ", device_.type()); + return DispatchKey::CPU; } + case DeviceType::CUDA: { + if (isQIntType(dtype_)) { + return DispatchKey::QuantizedCUDA; + } + return DispatchKey::CUDA; + } + case DeviceType::XPU: { + if (isQIntType(dtype_)) { + return DispatchKey::QuantizedXPU; + } + return DispatchKey::XPU; + } + case DeviceType::MKLDNN: + case DeviceType::OPENGL: + case DeviceType::OPENCL: + case DeviceType::IDEEP: + TORCH_INTERNAL_ASSERT( + 0, + "This is a grandfathered Caffe2 device type ", + device_.type(), + ", it shouldn't ever convert to a DispatchKey. 
File a bug describing what you were doing if you think this is in error."); + case DeviceType::HIP: + return DispatchKey::HIP; + case DeviceType::FPGA: + return DispatchKey::FPGA; + case DeviceType::MSNPU: + return DispatchKey::MSNPU; + case DeviceType::XLA: + return DispatchKey::XLA; + case DeviceType::MLC: + return DispatchKey::MLC; + case DeviceType::Vulkan: + return DispatchKey::Vulkan; + case DeviceType::Metal: + return DispatchKey::Metal; + case DeviceType::Meta: + return DispatchKey::Meta; + default: + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "Unsupported device type for dense layout: ", + device_.type()); } - case Layout::Sparse: - switch (device_.type()) { - case DeviceType::CPU: - return DispatchKey::SparseCPU; - case DeviceType::CUDA: - return DispatchKey::SparseCUDA; - case DeviceType::HIP: - return DispatchKey::SparseHIP; - case DeviceType::XPU: - return DispatchKey::SparseXPU; - default: - TORCH_CHECK_NOT_IMPLEMENTED(false, "Unsupported device type for sparse layout: ", device_.type()); - } - case Layout::Mkldnn: - switch (device_.type()) { - case DeviceType::CPU: - return DispatchKey::MkldnnCPU; - default: - TORCH_CHECK_NOT_IMPLEMENTED(false, "Unsupported device type for mkldnn layout: ", device_.type()); - } - case Layout::SparseCsr: - switch(device_.type()) { - case DeviceType::CPU: - return DispatchKey::SparseCsrCPU; - case DeviceType::CUDA: - return DispatchKey::SparseCsrCUDA; - default: - AT_ERROR("Unsupported device type for sparse CSR layout: ", device_.type()); - } - default: - TORCH_CHECK(false, "Unsupported layout: ", layout_); } + case Layout::Sparse: + switch (device_.type()) { + case DeviceType::CPU: + return DispatchKey::SparseCPU; + case DeviceType::CUDA: + return DispatchKey::SparseCUDA; + case DeviceType::HIP: + return DispatchKey::SparseHIP; + case DeviceType::XPU: + return DispatchKey::SparseXPU; + default: + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "Unsupported device type for sparse layout: ", + device_.type()); + } + case Layout::Mkldnn: + switch (device_.type()) { + case DeviceType::CPU: + return DispatchKey::MkldnnCPU; + default: + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "Unsupported device type for mkldnn layout: ", + device_.type()); + } + case Layout::SparseCsr: + switch (device_.type()) { + case DeviceType::CPU: + return DispatchKey::SparseCsrCPU; + case DeviceType::CUDA: + return DispatchKey::SparseCsrCUDA; + default: + AT_ERROR( + "Unsupported device type for sparse CSR layout: ", + device_.type()); + } + default: + TORCH_CHECK(false, "Unsupported layout: ", layout_); + } } inline Layout dispatchKeyToLayout(DispatchKey dispatch_key) { @@ -692,7 +739,7 @@ inline Layout dispatchKeyToLayout(DispatchKey dispatch_key) { } inline DeviceType dispatchKeyToDeviceType(DispatchKey dispatch_key) { - switch(dispatch_key) { + switch (dispatch_key) { // stuff that's real case DispatchKey::CPU: case DispatchKey::SparseCPU: @@ -730,14 +777,18 @@ inline DeviceType dispatchKeyToDeviceType(DispatchKey dispatch_key) { case DispatchKey::MSNPU: return DeviceType::MSNPU; default: - TORCH_CHECK(false, "DispatchKey ", dispatch_key, " doesn't correspond to a device"); + TORCH_CHECK( + false, + "DispatchKey ", + dispatch_key, + " doesn't correspond to a device"); } } inline TensorOptions dispatchKeyToTensorOptions(DispatchKey dispatch_key) { return TensorOptions() - .layout(dispatchKeyToLayout(dispatch_key)) - .device(dispatchKeyToDeviceType(dispatch_key)); + .layout(dispatchKeyToLayout(dispatch_key)) + .device(dispatchKeyToDeviceType(dispatch_key)); } } // namespace c10 
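// Usage sketch of computeDispatchKey() above: dtype / layout / device (each
// optional, with defaults filled in) map to a backend dispatch key.
#include <c10/core/TensorOptions.h>

void dispatch_key_demo() {
  c10::DispatchKey dense_cpu = c10::computeDispatchKey(
      c10::ScalarType::Float, c10::Layout::Strided, c10::Device(c10::kCPU));
  // dense_cpu == c10::DispatchKey::CPU
  c10::DispatchKey sparse_cuda = c10::computeDispatchKey(
      c10::nullopt, c10::Layout::Sparse, c10::Device(c10::kCUDA, 0));
  // sparse_cuda == c10::DispatchKey::SparseCUDA
  (void)dense_cpu;
  (void)sparse_cuda;
}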
diff --git a/c10/core/UndefinedTensorImpl.cpp b/c10/core/UndefinedTensorImpl.cpp index 7252c1cc4dc..dafa0742e75 100644 --- a/c10/core/UndefinedTensorImpl.cpp +++ b/c10/core/UndefinedTensorImpl.cpp @@ -5,7 +5,7 @@ namespace c10 { // should this use the globalContext? Can it get a context passed in somehow? UndefinedTensorImpl::UndefinedTensorImpl() -: TensorImpl(DispatchKey::Undefined, caffe2::TypeMeta(), c10::nullopt) { + : TensorImpl(DispatchKey::Undefined, caffe2::TypeMeta(), c10::nullopt) { set_storage_access_should_throw(); } @@ -19,7 +19,8 @@ int64_t UndefinedTensorImpl::stride(int64_t d) const { #ifdef DEBUG bool UndefinedTensorImpl::has_storage() const { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!storage_, "UndefinedTensorImpl assumes that storage_ is never set"); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + !storage_, "UndefinedTensorImpl assumes that storage_ is never set"); return false; } #endif @@ -39,4 +40,4 @@ const char* UndefinedTensorImpl::tensorimpl_type_name() const { // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) UndefinedTensorImpl UndefinedTensorImpl::_singleton; -} +} // namespace c10 diff --git a/c10/core/UndefinedTensorImpl.h b/c10/core/UndefinedTensorImpl.h index 23d48426dbb..fc650185049 100644 --- a/c10/core/UndefinedTensorImpl.h +++ b/c10/core/UndefinedTensorImpl.h @@ -7,13 +7,14 @@ namespace c10 { struct C10_API UndefinedTensorImpl final : public TensorImpl { public: // Without this, we get: - // error: identifier "at::UndefinedTensorImpl::_singleton" is undefined in device code + // error: identifier "at::UndefinedTensorImpl::_singleton" is undefined in + // device code // (ostensibly because the constexpr tricks MSVC into trying to compile this // function for device as well). #ifdef _WIN32 - static inline TensorImpl * singleton() { + static inline TensorImpl* singleton() { #else - static constexpr inline TensorImpl * singleton() { + static constexpr inline TensorImpl* singleton() { #endif return &_singleton; } @@ -24,7 +25,8 @@ struct C10_API UndefinedTensorImpl final : public TensorImpl { bool has_storage() const override; #endif void set_storage_offset(int64_t offset) override; -private: + + private: UndefinedTensorImpl(); static UndefinedTensorImpl _singleton; const char* tensorimpl_type_name() const override; diff --git a/c10/core/WrapDimMinimal.h b/c10/core/WrapDimMinimal.h index 32d3788c7f1..01cb1c641a1 100644 --- a/c10/core/WrapDimMinimal.h +++ b/c10/core/WrapDimMinimal.h @@ -4,10 +4,17 @@ namespace c10 { -static inline int64_t maybe_wrap_dim(int64_t dim, int64_t dim_post_expr, bool wrap_scalar=true) { +static inline int64_t maybe_wrap_dim( + int64_t dim, + int64_t dim_post_expr, + bool wrap_scalar = true) { if (dim_post_expr <= 0) { if (!wrap_scalar) { - TORCH_CHECK_INDEX(false, "dimension specified as ", dim, " but tensor has no dimensions"); + TORCH_CHECK_INDEX( + false, + "dimension specified as ", + dim, + " but tensor has no dimensions"); } dim_post_expr = 1; // this will make range [-1, 0] } @@ -15,12 +22,19 @@ static inline int64_t maybe_wrap_dim(int64_t dim, int64_t dim_post_expr, bool wr int64_t min = -dim_post_expr; int64_t max = dim_post_expr - 1; if (dim < min || dim > max) { - TORCH_CHECK_INDEX(false, - "Dimension out of range (expected to be in range of [", - min, ", ", max, "], but got ", dim, ")"); + TORCH_CHECK_INDEX( + false, + "Dimension out of range (expected to be in range of [", + min, + ", ", + max, + "], but got ", + dim, + ")"); } - if (dim < 0) dim += dim_post_expr; + if (dim < 0) + dim += dim_post_expr; return dim; 
} -} +} // namespace c10 diff --git a/c10/core/impl/DeviceGuardImplInterface.cpp b/c10/core/impl/DeviceGuardImplInterface.cpp index e44c76376ec..6cee6ed583a 100644 --- a/c10/core/impl/DeviceGuardImplInterface.cpp +++ b/c10/core/impl/DeviceGuardImplInterface.cpp @@ -5,11 +5,15 @@ namespace impl { // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) std::atomic -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -device_guard_impl_registry[static_cast(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)]; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) + device_guard_impl_registry[static_cast( + DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)]; -DeviceGuardImplRegistrar::DeviceGuardImplRegistrar(DeviceType type, const DeviceGuardImplInterface* impl) { +DeviceGuardImplRegistrar::DeviceGuardImplRegistrar( + DeviceType type, + const DeviceGuardImplInterface* impl) { device_guard_impl_registry[static_cast(type)].store(impl); } -}} // namespace c10::impl +} // namespace impl +} // namespace c10 diff --git a/c10/core/impl/DeviceGuardImplInterface.h b/c10/core/impl/DeviceGuardImplInterface.h index b9edf56ddde..e9a3d880c72 100644 --- a/c10/core/impl/DeviceGuardImplInterface.h +++ b/c10/core/impl/DeviceGuardImplInterface.h @@ -30,16 +30,16 @@ class DataPtr; * should map one-to-one with actual event flags for those backends. */ enum class EventFlag { - PYTORCH_DEFAULT, - BACKEND_DEFAULT, - // CUDA flags - CUDA_EVENT_DEFAULT, - CUDA_EVENT_DISABLE_TIMING, // PyTorch-default for CUDA - // HIP flags - HIP_EVENT_DEFAULT, - HIP_EVENT_DISABLE_TIMING, // PyTorch-default for HIP - // FOR TESTING ONLY - INVALID + PYTORCH_DEFAULT, + BACKEND_DEFAULT, + // CUDA flags + CUDA_EVENT_DEFAULT, + CUDA_EVENT_DISABLE_TIMING, // PyTorch-default for CUDA + // HIP flags + HIP_EVENT_DEFAULT, + HIP_EVENT_DISABLE_TIMING, // PyTorch-default for HIP + // FOR TESTING ONLY + INVALID }; namespace impl { @@ -126,47 +126,44 @@ struct C10_API DeviceGuardImplInterface { */ virtual Stream exchangeStream(Stream) const noexcept = 0; -/** - * Destroys the given event. - */ - virtual void destroyEvent ( - void* event, - const DeviceIndex device_index) const noexcept { } + /** + * Destroys the given event. + */ + virtual void destroyEvent(void* event, const DeviceIndex device_index) + const noexcept {} -/** - * Increments the event's version and enqueues a job with this version - * in the stream's work queue. When the stream process that job - * it notifies all streams waiting on / blocked by that version of the - * event to continue and marks that version as recorded. - * */ + /** + * Increments the event's version and enqueues a job with this version + * in the stream's work queue. When the stream process that job + * it notifies all streams waiting on / blocked by that version of the + * event to continue and marks that version as recorded. + * */ virtual void record( - void** event, - const Stream& stream, - const DeviceIndex device_index, - const c10::EventFlag flag) const { + void** event, + const Stream& stream, + const DeviceIndex device_index, + const c10::EventFlag flag) const { TORCH_CHECK(false, "Backend doesn't support events."); } -/** - * Does nothing if the event has not been scheduled to be recorded. - * If the event was previously enqueued to be recorded, a command - * to wait for the version of the event that exists at the time of this call - * is inserted in the stream's work queue. 
- * When the stream reaches this command it will stop processing - * additional commands until that version of the event is marked as recorded. - */ - virtual void block( - void* event, - const Stream& stream) const { + /** + * Does nothing if the event has not been scheduled to be recorded. + * If the event was previously enqueued to be recorded, a command + * to wait for the version of the event that exists at the time of this call + * is inserted in the stream's work queue. + * When the stream reaches this command it will stop processing + * additional commands until that version of the event is marked as recorded. + */ + virtual void block(void* event, const Stream& stream) const { TORCH_CHECK(false, "Backend doesn't support events."); } -/** - * Returns true if (and only if) - * (1) the event has never been scheduled to be recorded - * (2) the current version is marked as recorded. - * Returns false otherwise. - */ + /** + * Returns true if (and only if) + * (1) the event has never been scheduled to be recorded + * (2) the current version is marked as recorded. + * Returns false otherwise. + */ virtual bool queryEvent(void* event) const { TORCH_CHECK(false, "Backend doesn't support events."); } @@ -178,13 +175,13 @@ struct C10_API DeviceGuardImplInterface { */ virtual DeviceIndex deviceCount() const noexcept = 0; - /** * Ensure the caching allocator (if any) is aware that the given DataPtr is * being used on the given stream, and that it should thus avoid recycling the * DataPtr until all work on that stream is done. */ - virtual void recordDataPtrOnStream(const c10::DataPtr&, const Stream&) const { } + virtual void recordDataPtrOnStream(const c10::DataPtr&, const Stream&) const { + } /** * Intended use of this class is to leak the DeviceGuardImpl at program end. @@ -228,23 +225,21 @@ struct NoOpDeviceGuardImpl final : public DeviceGuardImplInterface { } // Event-related functions - void record(void** event, - const Stream& stream, - const DeviceIndex device_index, - const EventFlag flag) const override { + void record( + void** event, + const Stream& stream, + const DeviceIndex device_index, + const EventFlag flag) const override { TORCH_CHECK(false, D, " backend doesn't support events."); } - void block( - void* event, - const Stream& stream) const override { + void block(void* event, const Stream& stream) const override { TORCH_CHECK(false, D, " backend doesn't support events.") } bool queryEvent(void* event) const override { TORCH_CHECK(false, D, " backend doesn't support events.") } - void destroyEvent( - void* event, - const DeviceIndex device_index) const noexcept override { } + void destroyEvent(void* event, const DeviceIndex device_index) + const noexcept override {} }; // The registry is NON-owning. Each stored pointer is std::atomic so @@ -262,7 +257,8 @@ struct NoOpDeviceGuardImpl final : public DeviceGuardImplInterface { // putting them in the registry. This is done by deleting the destructor // on DeviceGuardImplInterface. extern C10_API std::atomic -device_guard_impl_registry[static_cast(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)]; + device_guard_impl_registry[static_cast( + DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)]; // I can't conveniently use c10/util/Registry.h for the following reason: // c10/util/Registry.h gives me a slow way of Create'ing a object of some @@ -273,12 +269,13 @@ device_guard_impl_registry[static_cast(DeviceType::COMPILE_TIME_MAX_DEVI // into device_guard_impl_registry. 
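For illustration only (not part of the patch), the event contract sketched by record/block/queryEvent/destroyEvent can be exercised end to end against the FakeGuardImpl test backend that appears later in this diff; that backend treats record and block as no-ops and reports every event as complete.

// Illustrative sketch only; assumes the FakeGuardImpl test backend.
#include <c10/core/impl/FakeGuardImpl.h>

void event_protocol_sketch(const c10::Stream& stream) {
  c10::impl::FakeGuardImpl<c10::DeviceType::CUDA> impl;
  void* event = nullptr; // backends lazily create the underlying event
  impl.record(&event, stream, /*device_index=*/0, c10::EventFlag::PYTORCH_DEFAULT);
  impl.block(event, stream);          // a real backend makes `stream` wait here
  bool done = impl.queryEvent(event); // FakeGuardImpl always reports completion
  impl.destroyEvent(event, /*device_index=*/0);
  (void)done;
}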
class C10_API DeviceGuardImplRegistrar { -public: + public: DeviceGuardImplRegistrar(DeviceType, const DeviceGuardImplInterface*); }; -#define C10_REGISTER_GUARD_IMPL(DevType, DeviceGuardImpl) \ - static ::c10::impl::DeviceGuardImplRegistrar C10_ANONYMOUS_VARIABLE(g_##DeviceType)(::c10::DeviceType::DevType, new DeviceGuardImpl()); +#define C10_REGISTER_GUARD_IMPL(DevType, DeviceGuardImpl) \ + static ::c10::impl::DeviceGuardImplRegistrar C10_ANONYMOUS_VARIABLE( \ + g_##DeviceType)(::c10::DeviceType::DevType, new DeviceGuardImpl()); inline const DeviceGuardImplInterface* getDeviceGuardImpl(DeviceType type) { // Two adjacent int16_t fields DeviceType and DeviceIndex has field access @@ -301,4 +298,5 @@ inline bool hasDeviceGuardImpl(DeviceType type) { return device_guard_impl_registry[static_cast(type)].load(); } -}} // namespace c10::impl +} // namespace impl +} // namespace c10 diff --git a/c10/core/impl/FakeGuardImpl.h b/c10/core/impl/FakeGuardImpl.h index 24f48fb8fed..2d47db0fdb1 100644 --- a/c10/core/impl/FakeGuardImpl.h +++ b/c10/core/impl/FakeGuardImpl.h @@ -60,17 +60,16 @@ struct FakeGuardImpl final : public DeviceGuardImplInterface { // Event-related functions void record( - void** event, - const Stream& stream, - const DeviceIndex device_index, - const EventFlag flag) const override { } - void block( - void* event, - const Stream& stream) const override { } - bool queryEvent(void* event) const override { return true; } - void destroyEvent( - void* event, - const DeviceIndex device_index) const noexcept override { } + void** event, + const Stream& stream, + const DeviceIndex device_index, + const EventFlag flag) const override {} + void block(void* event, const Stream& stream) const override {} + bool queryEvent(void* event) const override { + return true; + } + void destroyEvent(void* event, const DeviceIndex device_index) + const noexcept override {} // Convenience methods for testing static DeviceIndex getDeviceIndex() { @@ -87,9 +86,11 @@ struct FakeGuardImpl final : public DeviceGuardImplInterface { static void resetStreams() { current_streams_.fill(0); } -private: + + private: thread_local static DeviceIndex current_device_; - thread_local static std::array current_streams_; + thread_local static std::array + current_streams_; }; template @@ -99,7 +100,8 @@ template constexpr DeviceType FakeGuardImpl::static_type; template -thread_local std::array FakeGuardImpl::current_streams_ = {0,0,0,0,0,0,0,0}; +thread_local std::array + FakeGuardImpl::current_streams_ = {0, 0, 0, 0, 0, 0, 0, 0}; - -}} // namespace c10::impl +} // namespace impl +} // namespace c10 diff --git a/c10/core/impl/InlineDeviceGuard.h b/c10/core/impl/InlineDeviceGuard.h index fc3d64444a6..007c9e9bcfc 100644 --- a/c10/core/impl/InlineDeviceGuard.h +++ b/c10/core/impl/InlineDeviceGuard.h @@ -1,18 +1,17 @@ #pragma once -// This file provides implementations of InlineDeviceGuard and InlineOptionalDeviceGuard. +// This file provides implementations of InlineDeviceGuard and +// InlineOptionalDeviceGuard. 
#include #include #include -#include #include +#include namespace c10 { namespace impl { - - /** * A DeviceGuard is an RAII class that sets a device to some value * on construction, and resets the device to its original value on @@ -54,7 +53,7 @@ namespace impl { */ template class InlineDeviceGuard { -public: + public: // Note [Omitted default constructor from RAII] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // In principle, we could add a default constructor to @@ -69,25 +68,36 @@ public: /// Set the current device to the passed Device. explicit InlineDeviceGuard(Device device) - : impl_(device.type()) - , original_device_(device.index() == -1 ? impl_.getDevice() : impl_.exchangeDevice(device)) - , current_device_(device.index() == -1 ? original_device_ : device) - {} + : impl_(device.type()), + original_device_( + device.index() == -1 ? impl_.getDevice() + : impl_.exchangeDevice(device)), + current_device_(device.index() == -1 ? original_device_ : device) {} /// Set the current device index to the passed DeviceIndex. (The /// device type is inferred from the template parameter T). - template ::value>::type> + template < + typename U = T, + typename = typename std::enable_if< + !std::is_same::value>::type> explicit InlineDeviceGuard(DeviceIndex device_index) - : InlineDeviceGuard(Device(U::static_type, device_index)) {} + : InlineDeviceGuard(Device(U::static_type, device_index)) {} /// Construct an InlineDeviceGuard using VirtualGuardImpl with an explicit /// DeviceGuardImplInterface pointer. - template ::value>::type> - explicit InlineDeviceGuard(Device device, const DeviceGuardImplInterface* impl) - : impl_(VirtualGuardImpl(impl ? impl : getDeviceGuardImpl(device.type()))) - , original_device_(device.index() == -1 ? impl_.getDevice() : impl_.exchangeDevice(device)) - , current_device_(device.index() == -1 ? original_device_ : device) - {} + template < + typename U = T, + typename = typename std::enable_if< + std::is_same::value>::type> + explicit InlineDeviceGuard( + Device device, + const DeviceGuardImplInterface* impl) + : impl_( + VirtualGuardImpl(impl ? impl : getDeviceGuardImpl(device.type()))), + original_device_( + device.index() == -1 ? impl_.getDevice() + : impl_.exchangeDevice(device)), + current_device_(device.index() == -1 ? original_device_ : device) {} /// Copy is disallowed InlineDeviceGuard(const InlineDeviceGuard&) = delete; @@ -103,12 +113,17 @@ public: } /// Sets the device to the given one. - template ::value, int>::type = 0> + template < + typename U = T, + typename std::enable_if::value, int>:: + type = 0> void set_device(at::Device device) { - AT_ASSERT((U::static_type == DeviceType::HIP && device.is_cuda()) || - device.type() == U::static_type); + AT_ASSERT( + (U::static_type == DeviceType::HIP && device.is_cuda()) || + device.type() == U::static_type); auto index = device.index(); - if (index == -1) return; + if (index == -1) + return; impl_.setDevice(device); current_device_ = device; } @@ -116,8 +131,8 @@ public: /// Resets the currently set device to its original device, and then sets the /// current device to the passed device. This is effectively equivalent to /// set_device when a guard supports only a single device type. - template - typename std::enable_if::value >::type + template + typename std::enable_if::value>::type reset_device(at::Device device) { set_device(device); } @@ -139,11 +154,14 @@ public: /// that it is unnecessary. /// /// Optional argument is for testing only. 
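For illustration only (not part of the patch), the RAII behaviour described above reads roughly like this in a test, using the FakeGuardImpl backend from earlier in this diff; the assumption that the fake backend starts on device 0 is mine.

// Illustrative sketch only; assumes FakeGuardImpl's thread-local device starts at 0.
#include <c10/core/impl/FakeGuardImpl.h>
#include <c10/core/impl/InlineDeviceGuard.h>

using TestGuardImpl = c10::impl::FakeGuardImpl<c10::DeviceType::CUDA>;

void device_guard_sketch() {
  {
    c10::impl::InlineDeviceGuard<TestGuardImpl> g(c10::DeviceIndex(1));
    // Inside the scope: g.original_device() is (CUDA, 0),
    //                   g.current_device()  is (CUDA, 1),
    // and the fake backend's current device index is 1.
  }
  // Leaving the scope restores the fake backend to device 0.
}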
- template - typename std::enable_if::value >::type - reset_device(at::Device device, const impl::DeviceGuardImplInterface* impl = nullptr) { + template + typename std::enable_if::value>::type + reset_device( + at::Device device, + const impl::DeviceGuardImplInterface* impl = nullptr) { auto index = device.index(); - if (index == -1) return; + if (index == -1) + return; if (device.type() == original_device_.type()) { AT_ASSERT(impl == nullptr || impl->type() == device.type()); impl_.setDevice(device); @@ -175,10 +193,10 @@ public: return current_device_; } -protected: + protected: T impl_; -private: + private: Device original_device_; Device current_device_; }; @@ -193,29 +211,33 @@ private: */ template class InlineOptionalDeviceGuard { -public: + public: // Note [Explicit initialization of optional fields] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // Explicit initialization of optional fields - // required to workaround an nvcc bug; see https://github.com/pytorch/pytorch/issues/12117 + // required to workaround an nvcc bug; see + // https://github.com/pytorch/pytorch/issues/12117 /// Creates an uninitialized OptionalDeviceGuard. explicit InlineOptionalDeviceGuard() - : guard_() // See Note [Explicit initialization of optional fields] - {} + : guard_() // See Note [Explicit initialization of optional fields] + {} /// Set the current device to the passed Device, if it is not nullopt. explicit InlineOptionalDeviceGuard(optional device_opt) - : guard_() { // See Note [Explicit initialization of optional fields] + : guard_() { // See Note [Explicit initialization of optional fields] if (device_opt.has_value()) { guard_.emplace(device_opt.value()); } } /// Set the current device to the passed DeviceIndex, if it is not nullopt. - template ::value>::type> + template < + typename U = T, + typename = typename std::enable_if< + !std::is_same::value>::type> explicit InlineOptionalDeviceGuard(optional device_index_opt) - : guard_() { // See Note [Explicit initialization of optional fields] + : guard_() { // See Note [Explicit initialization of optional fields] if (device_index_opt.has_value()) { guard_.emplace(device_index_opt.value()); } @@ -225,7 +247,7 @@ public: /// and result in initialized OptionalDeviceGuard. template explicit InlineOptionalDeviceGuard(Args&&... args) - : guard_(in_place, std::forward(args)...) {} + : guard_(in_place, std::forward(args)...) {} // TODO: Consider readding Tensor and TensorList constructors here, when // Tensor moves to c10. (These are only valid on OptionalDeviceGuard, @@ -313,11 +335,15 @@ public: // // We could solve this with an extra thread-local variable. But no one is // actually using move-assignment. So just get rid of it. - InlineOptionalDeviceGuard& operator=(InlineOptionalDeviceGuard&& other) = delete; + InlineOptionalDeviceGuard& operator=(InlineOptionalDeviceGuard&& other) = + delete; /// Sets the device to the given one. Initializes OptionalDeviceGuard if it /// is not already initialized. - template ::value>::type> + template < + typename U = T, + typename = typename std::enable_if< + !std::is_same::value>::type> void set_device(at::Device device) { if (!guard_.has_value()) { guard_.emplace(device); @@ -333,8 +359,13 @@ public: /// See notes on why this is called reset_device on InlineDeviceGuard. /// /// Optional argument is for testing only. 
- template ::value>::type> - void reset_device(at::Device device, const DeviceGuardImplInterface* impl = nullptr) { + template < + typename U = T, + typename = typename std::enable_if< + std::is_same::value>::type> + void reset_device( + at::Device device, + const DeviceGuardImplInterface* impl = nullptr) { if (!guard_.has_value()) { guard_.emplace(device, impl); } else { @@ -346,7 +377,10 @@ public: /// current device to the passed device. Initializes the guard if it is /// not already initialized. This is effectively equivalent to set_device /// when a guard supports only a single device type. - template ::value>::type> + template < + typename U = T, + typename = typename std::enable_if< + !std::is_same::value>::type> void reset_device(at::Device device) { if (!guard_.has_value()) { guard_.emplace(device); @@ -357,7 +391,10 @@ public: /// Sets the device index to the given one. The device type is statically /// known. - template ::value >::type> + template < + typename U = T, + typename = typename std::enable_if< + !std::is_same::value>::type> void set_index(DeviceIndex index) { if (!guard_.has_value()) { guard_.emplace(index); @@ -366,17 +403,19 @@ public: } } - /// Returns the device that was set immediately prior to initialization of the, - /// guard, or nullopt if the guard is uninitialized. + /// Returns the device that was set immediately prior to initialization of + /// the, guard, or nullopt if the guard is uninitialized. optional original_device() const { - return guard_.has_value() ? make_optional(guard_->original_device()) : nullopt; + return guard_.has_value() ? make_optional(guard_->original_device()) + : nullopt; } /// Returns the most recent device that was set using this device guard, /// either from construction, or via set_device, if the guard is initialized, /// or nullopt if the guard is uninitialized. optional current_device() const { - return guard_.has_value() ? make_optional(guard_->current_device()) : nullopt; + return guard_.has_value() ? make_optional(guard_->current_device()) + : nullopt; } /// Restore the original device, resetting this guard to uninitialized state. 
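Again purely for illustration (not part of the patch), the deferred-initialization pattern InlineOptionalDeviceGuard supports looks roughly like this; it assumes a guard implementation is registered at runtime for whatever device type actually arrives.

// Illustrative sketch only; assumes a DeviceGuardImpl is registered for the
// incoming device's type at runtime.
#include <c10/core/impl/InlineDeviceGuard.h>
#include <c10/core/impl/VirtualGuardImpl.h>
#include <c10/util/Optional.h>

void maybe_switch_device(c10::optional<c10::Device> maybe_device) {
  c10::impl::InlineOptionalDeviceGuard<c10::impl::VirtualGuardImpl> guard(
      maybe_device); // stays uninitialized when nullopt
  // guard.original_device() / guard.current_device() return nullopt unless
  // maybe_device held a value; the destructor only restores if initialized.
}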
@@ -384,8 +423,9 @@ public: guard_.reset(); } -private: + private: optional> guard_; }; -}} // namespace c10::impl +} // namespace impl +} // namespace c10 diff --git a/c10/core/impl/InlineEvent.h b/c10/core/impl/InlineEvent.h index 4127b69a78b..003a0b8407a 100644 --- a/c10/core/impl/InlineEvent.h +++ b/c10/core/impl/InlineEvent.h @@ -2,22 +2,19 @@ #include #include -#include #include +#include namespace c10 { namespace impl { template struct InlineEvent final { - InlineEvent() = delete; InlineEvent( - const DeviceType _device_type, - const EventFlag _flag = EventFlag::PYTORCH_DEFAULT) - : backend_{_device_type}, - device_type_{_device_type}, - flag_{_flag} { } + const DeviceType _device_type, + const EventFlag _flag = EventFlag::PYTORCH_DEFAULT) + : backend_{_device_type}, device_type_{_device_type}, flag_{_flag} {} // Copy constructor and copy assignment operator (deleted) InlineEvent(const InlineEvent&) = delete; @@ -25,7 +22,7 @@ struct InlineEvent final { // Move constructor and move assignment operator InlineEvent(InlineEvent&& other) - : InlineEvent(other.device_type_, other.flag_) { + : InlineEvent(other.device_type_, other.flag_) { swap(std::move(other)); } InlineEvent& operator=(InlineEvent&& other) { @@ -43,27 +40,36 @@ struct InlineEvent final { } ~InlineEvent() noexcept { - if (event_) backend_.destroyEvent(event_, device_index_); + if (event_) + backend_.destroyEvent(event_, device_index_); } - DeviceType device_type() const noexcept { return device_type_; } - DeviceIndex device_index() const noexcept { return device_index_; } - EventFlag flag() const noexcept { return flag_; } - bool was_marked_for_recording() const noexcept { return was_marked_for_recording_; } - + DeviceType device_type() const noexcept { + return device_type_; + } + DeviceIndex device_index() const noexcept { + return device_index_; + } + EventFlag flag() const noexcept { + return flag_; + } + bool was_marked_for_recording() const noexcept { + return was_marked_for_recording_; + } void recordOnce(const Stream& stream) { - if (!was_marked_for_recording_) record(stream); + if (!was_marked_for_recording_) + record(stream); } void record(const Stream& stream) { TORCH_CHECK( - stream.device_type() == device_type_, - "Event device type ", - DeviceTypeName(device_type_), - " does not match recording stream's device type ", - DeviceTypeName(stream.device_type()), - "."); + stream.device_type() == device_type_, + "Event device type ", + DeviceTypeName(device_type_), + " does not match recording stream's device type ", + DeviceTypeName(stream.device_type()), + "."); backend_.record(&event_, stream, device_index_, flag_); was_marked_for_recording_ = true; @@ -71,25 +77,27 @@ struct InlineEvent final { } void block(const Stream& stream) const { - if (!was_marked_for_recording_) return; + if (!was_marked_for_recording_) + return; TORCH_CHECK( - stream.device_type() == device_type_, - "Event device type ", - DeviceTypeName(device_type_), - " does not match blocking stream's device type ", - DeviceTypeName(stream.device_type()), - "."); + stream.device_type() == device_type_, + "Event device type ", + DeviceTypeName(device_type_), + " does not match blocking stream's device type ", + DeviceTypeName(stream.device_type()), + "."); backend_.block(event_, stream); } bool query() const { - if (!was_marked_for_recording_) return true; + if (!was_marked_for_recording_) + return true; return backend_.queryEvent(event_); } -private: + private: void* event_ = nullptr; T backend_; DeviceType device_type_; @@ -98,5 +106,5 @@ 
private: bool was_marked_for_recording_ = false; }; -} // impl -} // c10 +} // namespace impl +} // namespace c10 diff --git a/c10/core/impl/InlineStreamGuard.h b/c10/core/impl/InlineStreamGuard.h index 42d4a473c4e..295e3095e7a 100644 --- a/c10/core/impl/InlineStreamGuard.h +++ b/c10/core/impl/InlineStreamGuard.h @@ -16,27 +16,34 @@ namespace impl { */ template class InlineStreamGuard : private InlineDeviceGuard { -public: + public: /// No default constructor, see Note [Omitted default constructor from RAII] explicit InlineStreamGuard() = delete; /// Set the current device to the device associated with the passed stream, /// and set the current stream on that device to the passed stream. explicit InlineStreamGuard(Stream stream) - : InlineDeviceGuard(stream.device()) - , original_stream_of_original_device_(this->impl_.getStream(original_device())) - , original_stream_of_current_device_(this->impl_.exchangeStream(stream)) - , current_stream_(stream) - {} + : InlineDeviceGuard(stream.device()), + original_stream_of_original_device_( + this->impl_.getStream(original_device())), + original_stream_of_current_device_(this->impl_.exchangeStream(stream)), + current_stream_(stream) {} /// This constructor exists purely for testing - template ::value>::type> - explicit InlineStreamGuard(Stream stream, const DeviceGuardImplInterface* impl) - : InlineDeviceGuard(stream.device(), impl ? impl : getDeviceGuardImpl(stream.device_type())) - , original_stream_of_original_device_(this->impl_.getStream(original_device())) - , original_stream_of_current_device_(this->impl_.exchangeStream(stream)) - , current_stream_(stream) - {} + template < + typename U = T, + typename = typename std::enable_if< + std::is_same::value>::type> + explicit InlineStreamGuard( + Stream stream, + const DeviceGuardImplInterface* impl) + : InlineDeviceGuard( + stream.device(), + impl ? impl : getDeviceGuardImpl(stream.device_type())), + original_stream_of_original_device_( + this->impl_.getStream(original_device())), + original_stream_of_current_device_(this->impl_.exchangeStream(stream)), + current_stream_(stream) {} /// Copy is disallowed InlineStreamGuard(const InlineStreamGuard&) = delete; @@ -110,8 +117,9 @@ public: return InlineDeviceGuard::original_device(); } -private: - Stream original_stream_of_original_device_; // what the user probably cares about + private: + Stream + original_stream_of_original_device_; // what the user probably cares about Stream original_stream_of_current_device_; // what we need to restore Stream current_stream_; }; @@ -123,17 +131,16 @@ private: */ template class InlineOptionalStreamGuard { -public: + public: /// Creates an uninitialized stream guard. explicit InlineOptionalStreamGuard() - : guard_() // See Note [Explicit initialization of optional fields] - {} + : guard_() // See Note [Explicit initialization of optional fields] + {} /// Set the current device to the device associated with the passed stream, /// and set the current stream on that device to the passed stream, /// if the passed stream is not nullopt. - explicit InlineOptionalStreamGuard(optional stream_opt) - : guard_() { + explicit InlineOptionalStreamGuard(optional stream_opt) : guard_() { if (stream_opt.has_value()) { guard_.emplace(stream_opt.value()); } @@ -142,13 +149,14 @@ public: /// All constructors of StreamGuard are valid for OptionalStreamGuard template explicit InlineOptionalStreamGuard(Args&&... args) - : guard_(in_place, std::forward(args)...) {} + : guard_(in_place, std::forward(args)...) 
{} // See Note [Move construction for RAII guards is tricky] InlineOptionalStreamGuard(InlineOptionalStreamGuard&& other) = delete; // See Note [Move assignment for RAII guards is tricky] - InlineOptionalStreamGuard& operator=(InlineOptionalStreamGuard&& other) = delete; + InlineOptionalStreamGuard& operator=(InlineOptionalStreamGuard&& other) = + delete; /// Resets the currently set stream to the original stream and /// the currently set device to the original device. Then, @@ -166,28 +174,31 @@ public: /// Returns the stream that was set at the time the guard was most recently /// initialized, or nullopt if the guard is uninitialized. optional original_stream() const { - return guard_.has_value() ? make_optional(guard_->original_stream()) : nullopt; + return guard_.has_value() ? make_optional(guard_->original_stream()) + : nullopt; } /// Returns the most recent stream that was set using this stream guard, - /// either from construction, or via reset_stream, if the guard is initialized, - /// or nullopt if the guard is uninitialized. + /// either from construction, or via reset_stream, if the guard is + /// initialized, or nullopt if the guard is uninitialized. optional current_stream() const { - return guard_.has_value() ? make_optional(guard_->current_stream()) : nullopt; + return guard_.has_value() ? make_optional(guard_->current_stream()) + : nullopt; } - /// Restore the original device and stream, resetting this guard to uninitialized state. + /// Restore the original device and stream, resetting this guard to + /// uninitialized state. void reset() { guard_.reset(); } -private: + private: optional> guard_; }; template class InlineMultiStreamGuard { -public: + public: /// Calls `set_stream` on each of the streams in the list. /// This may be useful if you need to set different streams /// for different devices. @@ -216,10 +227,10 @@ public: } } -protected: + protected: optional impl_; -private: + private: /// The original streams that were active on all devices. std::vector original_streams_; @@ -228,13 +239,17 @@ private: DeviceType type = streams[0].device_type(); for (size_t idx = 1; idx < streams.size(); idx++) { TORCH_CHECK_VALUE( - streams[idx].device_type() == type, - "Streams have a mix of device types: stream 0 is on ", - streams[0].device(), " while stream ", idx, " is on device ", - streams[idx].device()); + streams[idx].device_type() == type, + "Streams have a mix of device types: stream 0 is on ", + streams[0].device(), + " while stream ", + idx, + " is on device ", + streams[idx].device()); } return type; } }; -}} // namespace c10::impl +} // namespace impl +} // namespace c10 diff --git a/c10/core/impl/LocalDispatchKeySet.cpp b/c10/core/impl/LocalDispatchKeySet.cpp index 16936e4fa62..e0a3c6b46ad 100644 --- a/c10/core/impl/LocalDispatchKeySet.cpp +++ b/c10/core/impl/LocalDispatchKeySet.cpp @@ -8,12 +8,12 @@ namespace impl { // NB: POD, must be zero initialized! // Note [TLS Initialization] // We wanted raw_local_dispatch_key_set to be initialized with non-zero state -// e.g. BackendSelect and InplaceOrView in included set. But certain Windows compiler (e.g the one -// used in ARVR tests) only allow TLS to be zero-initialized. -// To preserve the invariant that raw TLS storage of the default state is zero, -// we obtain the actual include keyset by XORing raw_local_dispatch_key_set.included_ -// with c10::default_included_set. This logic is encapsulated in struct -// PODLocalDispatchKeySet. +// e.g. BackendSelect and InplaceOrView in included set. 
But certain Windows +// compiler (e.g the one used in ARVR tests) only allow TLS to be +// zero-initialized. To preserve the invariant that raw TLS storage of the +// default state is zero, we obtain the actual include keyset by XORing +// raw_local_dispatch_key_set.included_ with c10::default_included_set. This +// logic is encapsulated in struct PODLocalDispatchKeySet. // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set; @@ -28,26 +28,29 @@ void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set) { raw_local_dispatch_key_set.set_excluded(key_set.excluded_); } -// An RAII guard could snapshot and restore the entire state (entire DispatchKeySet) as -// opposed to only snapshotting and restoring the state of its assigned DispatchKeySet. -// I'm not sure which is better. If only the RAII API is used, the two choices are -// not distinguishable. +// An RAII guard could snapshot and restore the entire state (entire +// DispatchKeySet) as opposed to only snapshotting and restoring the state of +// its assigned DispatchKeySet. I'm not sure which is better. If only the RAII +// API is used, the two choices are not distinguishable. // -// However, if the guard chooses to snapshot and restore the entire DispatchKeySet, -// the interaction with the non-RAII API changes. Consider this sequence of events: -// - An RAII guard is declared for a particular DispatchKeySet, but snapshots the entire +// However, if the guard chooses to snapshot and restore the entire +// DispatchKeySet, the interaction with the non-RAII API changes. Consider this +// sequence of events: +// - An RAII guard is declared for a particular DispatchKeySet, but snapshots +// the entire // current DispatchKeySet. -// - A call to the non-RAII API changes the state for DispatchKeys outside the assigned +// - A call to the non-RAII API changes the state for DispatchKeys outside the +// assigned // set. -// - The RAII guard goes out of scope, restoring the entire DispatchKeySet it snapshotted -// (which restores the state for its own assigned DispatchKey and wipes out the state -// for the other DispatchKeys set by the non-RAII API). +// - The RAII guard goes out of scope, restoring the entire DispatchKeySet it +// snapshotted +// (which restores the state for its own assigned DispatchKey and wipes out +// the state for the other DispatchKeys set by the non-RAII API). // RAII API IncludeDispatchKeyGuard::IncludeDispatchKeyGuard(DispatchKeySet include) - : tls_(&raw_local_dispatch_key_set) - , include_(include - tls_->included()) { + : tls_(&raw_local_dispatch_key_set), include_(include - tls_->included()) { if (!include_.empty()) { tls_->set_included(tls_->included() | include_); } @@ -60,8 +63,7 @@ IncludeDispatchKeyGuard::~IncludeDispatchKeyGuard() { } ExcludeDispatchKeyGuard::ExcludeDispatchKeyGuard(DispatchKeySet exclude) - : tls_(&raw_local_dispatch_key_set) - , exclude_(exclude - tls_->excluded()) { + : tls_(&raw_local_dispatch_key_set), exclude_(exclude - tls_->excluded()) { if (!exclude_.empty()) { tls_->set_excluded(tls_->excluded() | exclude_); } @@ -74,7 +76,8 @@ ExcludeDispatchKeyGuard::~ExcludeDispatchKeyGuard() { } // Non-RAII API -// Please prefer using the RAII API. See declarations in LocalDispatchKeySet.h for details. +// Please prefer using the RAII API. See declarations in LocalDispatchKeySet.h +// for details. 
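A short usage sketch of the RAII path being recommended here (illustration only, not part of the patch); DispatchKey::BackendSelect is used merely as an example key.

// Illustrative sketch only; BackendSelect is an arbitrary example key.
#include <c10/core/impl/LocalDispatchKeySet.h>

void run_without_backend_select() {
  c10::impl::ExcludeDispatchKeyGuard guard(c10::DispatchKey::BackendSelect);
  // Dispatch performed in this scope skips BackendSelect; the constructor
  // only added what was not already excluded, and the destructor removes
  // exactly that, so nesting composes correctly.
}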
bool tls_is_dispatch_key_excluded(DispatchKey x) { return raw_local_dispatch_key_set.excluded().has(x); @@ -115,4 +118,5 @@ bool tls_is_dispatch_keyset_excluded(DispatchKeySet ks) { bool tls_is_dispatch_keyset_included(DispatchKeySet ks) { return raw_local_dispatch_key_set.included().isSupersetOf(ks); } -}} // namespace c10::impl +} // namespace impl +} // namespace c10 diff --git a/c10/core/impl/LocalDispatchKeySet.h b/c10/core/impl/LocalDispatchKeySet.h index b18b4d4de1d..c436dd51af1 100644 --- a/c10/core/impl/LocalDispatchKeySet.h +++ b/c10/core/impl/LocalDispatchKeySet.h @@ -36,10 +36,12 @@ struct C10_API PODLocalDispatchKeySet { // See Note [TLS Initialization] DispatchKeySet included() const { - return DispatchKeySet(DispatchKeySet::RAW, included_) ^ c10::default_included_set; + return DispatchKeySet(DispatchKeySet::RAW, included_) ^ + c10::default_included_set; } DispatchKeySet excluded() const { - return DispatchKeySet(DispatchKeySet::RAW, excluded_) ^ c10::default_excluded_set; + return DispatchKeySet(DispatchKeySet::RAW, excluded_) ^ + c10::default_excluded_set; } void set_included(DispatchKeySet x) { @@ -49,11 +51,13 @@ struct C10_API PODLocalDispatchKeySet { excluded_ = (x ^ c10::default_excluded_set).raw_repr(); } }; -static_assert(std::is_pod::value, "PODLocalDispatchKeySet must be a POD type."); +static_assert( + std::is_pod::value, + "PODLocalDispatchKeySet must be a POD type."); struct C10_API LocalDispatchKeySet { /* implicit */ LocalDispatchKeySet(PODLocalDispatchKeySet x) - : included_(x.included()), excluded_(x.excluded()) {} + : included_(x.included()), excluded_(x.excluded()) {} DispatchKeySet included_; DispatchKeySet excluded_; }; @@ -63,7 +67,7 @@ struct C10_API LocalDispatchKeySet { #if defined(_MSC_VER) || defined(C10_ANDROID) C10_API LocalDispatchKeySet tls_local_dispatch_key_set(); #else // defined(_MSC_VER) || defined(C10_ANDROID) - extern C10_API thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set; +extern C10_API thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set; inline C10_API LocalDispatchKeySet tls_local_dispatch_key_set() { // Don't let people fiddle with the thread_local directly just @@ -78,15 +82,17 @@ C10_API void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set); // RAII API for manipulating the thread-local dispatch state. 
class C10_API IncludeDispatchKeyGuard { -public: + public: IncludeDispatchKeyGuard(DispatchKeySet); - IncludeDispatchKeyGuard(DispatchKey k) : IncludeDispatchKeyGuard(DispatchKeySet(k)) {} + IncludeDispatchKeyGuard(DispatchKey k) + : IncludeDispatchKeyGuard(DispatchKeySet(k)) {} IncludeDispatchKeyGuard(const IncludeDispatchKeyGuard&) = delete; IncludeDispatchKeyGuard operator=(const IncludeDispatchKeyGuard&) = delete; IncludeDispatchKeyGuard(IncludeDispatchKeyGuard&&) = delete; IncludeDispatchKeyGuard operator=(IncludeDispatchKeyGuard&&) = delete; ~IncludeDispatchKeyGuard(); -private: + + private: // A little micro-optimization to save us from tls_get_addr call // on destruction PODLocalDispatchKeySet* tls_; @@ -94,15 +100,17 @@ private: }; class C10_API ExcludeDispatchKeyGuard { -public: + public: ExcludeDispatchKeyGuard(DispatchKeySet); - ExcludeDispatchKeyGuard(DispatchKey k) : ExcludeDispatchKeyGuard(DispatchKeySet(k)) {} + ExcludeDispatchKeyGuard(DispatchKey k) + : ExcludeDispatchKeyGuard(DispatchKeySet(k)) {} ExcludeDispatchKeyGuard(const ExcludeDispatchKeyGuard&) = delete; ExcludeDispatchKeyGuard operator=(const ExcludeDispatchKeyGuard&) = delete; ExcludeDispatchKeyGuard(ExcludeDispatchKeyGuard&&) = delete; ExcludeDispatchKeyGuard operator=(ExcludeDispatchKeyGuard&&) = delete; ~ExcludeDispatchKeyGuard(); -private: + + private: // A little micro-optimization to save us from tls_get_addr call // on destruction PODLocalDispatchKeySet* tls_; @@ -120,7 +128,8 @@ private: // through that DispatchKey's registered overrides. // // The non-RAII API is less efficient than the RAII guards because both the -// getter and setter will do a tls_getaddr lookup (the RAII struct only needs one!) +// getter and setter will do a tls_getaddr lookup (the RAII struct only needs +// one!) C10_API bool tls_is_dispatch_key_excluded(DispatchKey x); C10_API void tls_set_dispatch_key_excluded(DispatchKey x, bool desired_state); @@ -129,4 +138,5 @@ C10_API void tls_set_dispatch_key_included(DispatchKey x, bool desired_state); C10_API bool tls_is_dispatch_keyset_excluded(DispatchKeySet ks); C10_API bool tls_is_dispatch_keyset_included(DispatchKeySet ks); -}} // namespace c10::impl +} // namespace impl +} // namespace c10 diff --git a/c10/core/impl/SizesAndStrides.cpp b/c10/core/impl/SizesAndStrides.cpp index f6f5abcb5bb..db1d7c61e98 100644 --- a/c10/core/impl/SizesAndStrides.cpp +++ b/c10/core/impl/SizesAndStrides.cpp @@ -3,9 +3,13 @@ namespace c10 { namespace impl { -void SizesAndStrides::resizeSlowPath(const size_t newSize, const size_t oldSize) { +void SizesAndStrides::resizeSlowPath( + const size_t newSize, + const size_t oldSize) { if (newSize <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!isInline(), "resizeSlowPath called when fast path should have been hit!"); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + !isInline(), + "resizeSlowPath called when fast path should have been hit!"); int64_t* tempStorage = outOfLineStorage_; memcpy( &inlineStorage_[0], @@ -24,15 +28,23 @@ void SizesAndStrides::resizeSlowPath(const size_t newSize, const size_t oldSize) // CANNOT USE allocateOutOfLineStorage(newSize) HERE! WOULD // OVERWRITE inlineStorage_! 
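The inline-versus-heap layout this resize code shuffles between can be summarized with a small standalone sketch (illustration only, not part of the patch): up to five dimensions live inline, sizes in the first half of the array and strides in the second; anything larger switches the union over to one malloc'd buffer holding all sizes followed by all strides.

// Illustrative sketch only; mirrors the SizesAndStrides layout, not its API.
#include <cstddef>
#include <cstdint>

struct SizesAndStridesLayoutSketch {
  static constexpr size_t kMaxInline = 5; // C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE
  size_t size_;                           // number of dimensions
  union {
    int64_t* outOfLineStorage_;             // [size_ sizes][size_ strides]
    int64_t inlineStorage_[kMaxInline * 2]; // [5 sizes][5 strides]
  };
  bool isInline() const { return size_ <= kMaxInline; }
  const int64_t* sizes_data() const {
    return isInline() ? inlineStorage_ : outOfLineStorage_;
  }
  const int64_t* strides_data() const {
    return isInline() ? inlineStorage_ + kMaxInline : outOfLineStorage_ + size_;
  }
};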
// NOLINTNEXTLINE(cppcoreguidelines-no-malloc) - int64_t* tempStorage = static_cast(malloc(storageBytes(newSize))); - TORCH_CHECK(tempStorage, "Could not allocate memory to change Tensor SizesAndStrides!"); + int64_t* tempStorage = + static_cast(malloc(storageBytes(newSize))); + TORCH_CHECK( + tempStorage, + "Could not allocate memory to change Tensor SizesAndStrides!"); const auto bytesToCopy = oldSize * sizeof(inlineStorage_[0]); - const auto bytesToZero = (newSize > oldSize) ? (newSize - oldSize) * sizeof(tempStorage[0]) : 0; + const auto bytesToZero = (newSize > oldSize) + ? (newSize - oldSize) * sizeof(tempStorage[0]) + : 0; memcpy(&tempStorage[0], &inlineStorage_[0], bytesToCopy); if (bytesToZero) { memset(&tempStorage[oldSize], 0, bytesToZero); } - memcpy(&tempStorage[newSize], &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE], bytesToCopy); + memcpy( + &tempStorage[newSize], + &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE], + bytesToCopy); if (bytesToZero) { memset(&tempStorage[newSize + oldSize], 0, bytesToZero); } @@ -55,7 +67,8 @@ void SizesAndStrides::resizeSlowPath(const size_t newSize, const size_t oldSize) resizeOutOfLineStorage(newSize); } else { // Zero the end of the sizes portion. - const auto bytesToZero = (newSize - oldSize) * sizeof(outOfLineStorage_[0]); + const auto bytesToZero = + (newSize - oldSize) * sizeof(outOfLineStorage_[0]); memset(&outOfLineStorage_[oldSize], 0, bytesToZero); memset(&outOfLineStorage_[newSize + oldSize], 0, bytesToZero); } diff --git a/c10/core/impl/SizesAndStrides.h b/c10/core/impl/SizesAndStrides.h index 4f7e19330ac..330779bad5e 100644 --- a/c10/core/impl/SizesAndStrides.h +++ b/c10/core/impl/SizesAndStrides.h @@ -19,7 +19,8 @@ namespace impl { // the number of strides. The memory layout is as follows: // // 1 size_t for the size -// 5 eightbytes of inline sizes and 5 eightbytes of inline strides, OR pointer to out-of-line array +// 5 eightbytes of inline sizes and 5 eightbytes of inline strides, OR pointer +// to out-of-line array class C10_API SizesAndStrides { public: // TODO: different iterator types for sizes & strides to prevent @@ -238,11 +239,16 @@ class C10_API SizesAndStrides { if (newSize == oldSize) { return; } - if (C10_LIKELY(newSize <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE && isInline())) { + if (C10_LIKELY( + newSize <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE && isInline())) { if (oldSize < newSize) { - const auto bytesToZero = (newSize - oldSize) * sizeof(inlineStorage_[0]); + const auto bytesToZero = + (newSize - oldSize) * sizeof(inlineStorage_[0]); memset(&inlineStorage_[oldSize], 0, bytesToZero); - memset(&inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE + oldSize], 0, bytesToZero); + memset( + &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE + oldSize], + 0, + bytesToZero); } size_ = newSize; } else { @@ -267,14 +273,19 @@ class C10_API SizesAndStrides { } void allocateOutOfLineStorage(size_t size) { - outOfLineStorage_ = static_cast(malloc(storageBytes(size))); - TORCH_CHECK(outOfLineStorage_, "Could not allocate memory for Tensor SizesAndStrides!"); + outOfLineStorage_ = static_cast(malloc(storageBytes(size))); + TORCH_CHECK( + outOfLineStorage_, + "Could not allocate memory for Tensor SizesAndStrides!"); } void resizeOutOfLineStorage(size_t newSize) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!isInline()); - outOfLineStorage_ = static_cast(realloc(outOfLineStorage_, storageBytes(newSize))); - TORCH_CHECK(outOfLineStorage_, "Could not allocate memory for Tensor SizesAndStrides!"); + outOfLineStorage_ = 
static_cast( + realloc(outOfLineStorage_, storageBytes(newSize))); + TORCH_CHECK( + outOfLineStorage_, + "Could not allocate memory for Tensor SizesAndStrides!"); } void copyDataOutline(const SizesAndStrides& rhs) noexcept { @@ -283,10 +294,9 @@ class C10_API SizesAndStrides { size_t size_; union { - int64_t *outOfLineStorage_; + int64_t* outOfLineStorage_; int64_t inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * 2]{}; }; - }; } // namespace impl diff --git a/c10/core/impl/VirtualGuardImpl.h b/c10/core/impl/VirtualGuardImpl.h index 1334cd7cfe9..896a11aa350 100644 --- a/c10/core/impl/VirtualGuardImpl.h +++ b/c10/core/impl/VirtualGuardImpl.h @@ -10,12 +10,11 @@ namespace impl { * to virtual dispatch on the DeviceGuardImpl registry. */ class VirtualGuardImpl final : public DeviceGuardImplInterface { -public: + public: VirtualGuardImpl(DeviceType device_type) - : impl_(getDeviceGuardImpl(device_type)) {} + : impl_(getDeviceGuardImpl(device_type)) {} // This constructor exists purely for testing - VirtualGuardImpl(const DeviceGuardImplInterface* impl) - : impl_(impl) {} + VirtualGuardImpl(const DeviceGuardImplInterface* impl) : impl_(impl) {} // Copying and moving is OK! @@ -51,34 +50,32 @@ public: } // Event functions - void record(void** event, - const Stream& stream, - const DeviceIndex device_index, - const EventFlag flag) const override { + void record( + void** event, + const Stream& stream, + const DeviceIndex device_index, + const EventFlag flag) const override { impl_->record(event, stream, device_index, flag); } - void block( - void* event, - const Stream& stream) const override { + void block(void* event, const Stream& stream) const override { impl_->block(event, stream); } bool queryEvent(void* event) const override { return impl_->queryEvent(event); } - void destroyEvent( - void* event, - const DeviceIndex device_index) const noexcept override { + void destroyEvent(void* event, const DeviceIndex device_index) + const noexcept override { impl_->destroyEvent(event, device_index); } - void recordDataPtrOnStream( - const c10::DataPtr& data_ptr, - const Stream& stream) const override { + void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream) + const override { impl_->recordDataPtrOnStream(data_ptr, stream); } -private: + private: const DeviceGuardImplInterface* impl_ = nullptr; }; -}} // namespace c10::impl +} // namespace impl +} // namespace c10 diff --git a/c10/core/thread_pool.cpp b/c10/core/thread_pool.cpp index b567c018a68..122732fd1b2 100644 --- a/c10/core/thread_pool.cpp +++ b/c10/core/thread_pool.cpp @@ -3,9 +3,9 @@ namespace c10 { ThreadPool::ThreadPool( - int pool_size, - int numa_node_id, - std::function init_thread) + int pool_size, + int numa_node_id, + std::function init_thread) : threads_(pool_size < 0 ? 
defaultNumThreads() : pool_size), running_(true), complete_(true), @@ -13,7 +13,7 @@ ThreadPool::ThreadPool( total_(threads_.size()), numa_node_id_(numa_node_id) { for (std::size_t i = 0; i < threads_.size(); ++i) { - threads_[i] = std::thread([this, i, init_thread](){ + threads_[i] = std::thread([this, i, init_thread]() { if (init_thread) { init_thread(); } diff --git a/c10/core/thread_pool.h b/c10/core/thread_pool.h index 33a60aa14cc..3fadd2add89 100644 --- a/c10/core/thread_pool.h +++ b/c10/core/thread_pool.h @@ -50,9 +50,9 @@ class C10_API ThreadPool : public c10::TaskThreadPoolBase { const std::function with_id; explicit task_element_t(std::function f) - : run_with_id(false), no_id(std::move(f)), with_id(nullptr) {} + : run_with_id(false), no_id(std::move(f)), with_id(nullptr) {} explicit task_element_t(std::function f) - : run_with_id(true), no_id(nullptr), with_id(std::move(f)) {} + : run_with_id(true), no_id(nullptr), with_id(std::move(f)) {} }; std::queue tasks_; @@ -105,13 +105,11 @@ class C10_API ThreadPool : public c10::TaskThreadPoolBase { class C10_API TaskThreadPool : public c10::ThreadPool { public: - explicit TaskThreadPool( - std::size_t pool_size, - int numa_node_id = -1) - : ThreadPool(pool_size, numa_node_id, [numa_node_id](){ - setThreadName("CaffeTaskThread"); - NUMABind(numa_node_id); - }) {} + explicit TaskThreadPool(std::size_t pool_size, int numa_node_id = -1) + : ThreadPool(pool_size, numa_node_id, [numa_node_id]() { + setThreadName("CaffeTaskThread"); + NUMABind(numa_node_id); + }) {} }; C10_DECLARE_SHARED_REGISTRY( diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 4e3b04165cd..2b303b8960d 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -1,9 +1,9 @@ #include -#include #include #include +#include #include #include @@ -56,47 +56,57 @@ namespace CUDACachingAllocator { /** * Note [Interaction with CUDA graph capture] * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - * Graph capture performs a dry run of a region of execution, freezing all CUDA work - * (and virtual addresses used during that work) into a "graph." The graph may be "replayed" - * like a single giant kernel, with greatly reduced CPU overhead as well as modestly improved - * GPU performance. + * Graph capture performs a dry run of a region of execution, freezing all CUDA + * work (and virtual addresses used during that work) into a "graph." The graph + * may be "replayed" like a single giant kernel, with greatly reduced CPU + * overhead as well as modestly improved GPU performance. * - * Because capture bakes in memory addresses, the memory used during capture must be available - * for the graph to use during replay. DeviceCachingAllocator assigns and frees memory eagerly - * and dynamically, so if we're not careful about managing graphs' memory, at replay time those - * memory addresses could be use by other tensors. + * Because capture bakes in memory addresses, the memory used during capture + * must be available for the graph to use during replay. DeviceCachingAllocator + * assigns and frees memory eagerly and dynamically, so if we're not careful + * about managing graphs' memory, at replay time those memory addresses could be + * use by other tensors. * - * To guarantee a graph's baked in addresses are safe to reuse in replay, DeviceAllocator - * satisfies allocations from a graph-private memory pool during capture, and doesn't begin - * cudaFreeing those addresses until the graph is destroyed. 
+ * To guarantee a graph's baked in addresses are safe to reuse in replay, + * DeviceAllocator satisfies allocations from a graph-private memory pool during + * capture, and doesn't begin cudaFreeing those addresses until the graph is + * destroyed. * - * Within the private pool, allocations are freed and reassigned as usual during capture. - * Memory regions will be used in a consistent order during replay. - * So a private pool doesn't use memory more wastefully than the default pools during capture, - * but it does reserve its high-water mark of used memory away from the default pools as long - * as the capture(s) it served survive (regardless whether those captures are idle or replaying). + * Within the private pool, allocations are freed and reassigned as usual during + * capture. Memory regions will be used in a consistent order during replay. So + * a private pool doesn't use memory more wastefully than the default pools + * during capture, but it does reserve its high-water mark of used memory away + * from the default pools as long as the capture(s) it served survive + * (regardless whether those captures are idle or replaying). * - * CUDAGraph's requests for private pools are mediated by DeviceAllocator::notifyCaptureBegin, - * notifyCaptureEnd, and notifyCaptureDestroy. + * CUDAGraph's requests for private pools are mediated by + * DeviceAllocator::notifyCaptureBegin, notifyCaptureEnd, and + * notifyCaptureDestroy. */ namespace { using stream_set = std::unordered_set; -constexpr size_t kMinBlockSize = 512; // all sizes are rounded to at least 512 bytes -constexpr size_t kSmallSize = 1048576; // largest "small" allocation is 1 MiB -constexpr size_t kSmallBuffer = 2097152; // "small" allocations are packed in 2 MiB blocks -constexpr size_t kLargeBuffer = 20971520; // "large" allocations may be packed in 20 MiB blocks -constexpr size_t kMinLargeAlloc = 10485760; // allocations between 1 and 10 MiB may use kLargeBuffer -constexpr size_t kRoundLarge = 2097152; // round up large allocations to 2 MiB +constexpr size_t kMinBlockSize = + 512; // all sizes are rounded to at least 512 bytes +constexpr size_t kSmallSize = 1048576; // largest "small" allocation is 1 MiB +constexpr size_t kSmallBuffer = + 2097152; // "small" allocations are packed in 2 MiB blocks +constexpr size_t kLargeBuffer = + 20971520; // "large" allocations may be packed in 20 MiB blocks +constexpr size_t kMinLargeAlloc = + 10485760; // allocations between 1 and 10 MiB may use kLargeBuffer +constexpr size_t kRoundLarge = 2097152; // round up large allocations to 2 MiB typedef std::bitset(StatType::NUM_TYPES)> StatTypes; void update_stat(Stat& stat, int64_t amount) { stat.current += amount; - TORCH_INTERNAL_ASSERT(stat.current >= 0, "Negative tracked stat in CUDA allocator (likely logic error)."); + TORCH_INTERNAL_ASSERT( + stat.current >= 0, + "Negative tracked stat in CUDA allocator (likely logic error)."); stat.peak = std::max(stat.current, stat.peak); if (amount > 0) { @@ -116,7 +126,10 @@ void reset_peak_stat(Stat& stat) { stat.peak = stat.current; } -void update_stat_array(StatArray& stat_array, int64_t amount, const StatTypes& stat_types) { +void update_stat_array( + StatArray& stat_array, + int64_t amount, + const StatTypes& stat_types) { for (size_t stat_type = 0; stat_type < stat_types.size(); ++stat_type) { if (stat_types[stat_type]) { update_stat(stat_array[stat_type], amount); @@ -129,45 +142,64 @@ struct PrivatePool; typedef bool (*Comparison)(const Block*, const Block*); struct BlockPool { - 
BlockPool(Comparison comparator, - bool small, - PrivatePool* private_pool=nullptr) : - blocks(comparator), - is_small(small), - owner_PrivatePool(private_pool) {} + BlockPool( + Comparison comparator, + bool small, + PrivatePool* private_pool = nullptr) + : blocks(comparator), is_small(small), owner_PrivatePool(private_pool) {} std::set blocks; const bool is_small; PrivatePool* owner_PrivatePool; }; struct Block { - int device; // gpu - cudaStream_t stream; // allocation stream - stream_set stream_uses; // streams on which the block was used - size_t size; // block size in bytes - BlockPool* pool; // owning memory pool - void* ptr; // memory address - bool allocated; // in-use flag - Block* prev; // prev block if split from a larger allocation - Block* next; // next block if split from a larger allocation - int event_count; // number of outstanding CUDA events + int device; // gpu + cudaStream_t stream; // allocation stream + stream_set stream_uses; // streams on which the block was used + size_t size; // block size in bytes + BlockPool* pool; // owning memory pool + void* ptr; // memory address + bool allocated; // in-use flag + Block* prev; // prev block if split from a larger allocation + Block* next; // next block if split from a larger allocation + int event_count; // number of outstanding CUDA events - Block(int device, cudaStream_t stream, size_t size, BlockPool* pool, void* ptr) : - device(device), stream(stream), stream_uses(), size(size), pool(pool), - ptr(ptr), allocated(0), prev(nullptr), next(nullptr), event_count(0) { } + Block( + int device, + cudaStream_t stream, + size_t size, + BlockPool* pool, + void* ptr) + : device(device), + stream(stream), + stream_uses(), + size(size), + pool(pool), + ptr(ptr), + allocated(0), + prev(nullptr), + next(nullptr), + event_count(0) {} // constructor for search key - Block(int device, cudaStream_t stream, size_t size) : - device(device), stream(stream), stream_uses(), size(size), pool(nullptr), - ptr(nullptr), allocated(0), prev(nullptr), next(nullptr), event_count(0) { } + Block(int device, cudaStream_t stream, size_t size) + : device(device), + stream(stream), + stream_uses(), + size(size), + pool(nullptr), + ptr(nullptr), + allocated(0), + prev(nullptr), + next(nullptr), + event_count(0) {} bool is_split() const { return (prev != nullptr) || (next != nullptr); } }; -static bool BlockComparator(const Block* a, const Block* b) -{ +static bool BlockComparator(const Block* a, const Block* b) { if (a->stream != b->stream) { return (uintptr_t)a->stream < (uintptr_t)b->stream; } @@ -197,17 +229,28 @@ static std::string format_size(uint64_t size) { } struct AllocParams { - AllocParams(int device, size_t size, cudaStream_t stream, BlockPool* pool, size_t alloc_size, - DeviceStats& stats) : - search_key(device, stream, size), - pool(pool), - alloc_size(alloc_size), - block(nullptr), - err(cudaSuccess) {} + AllocParams( + int device, + size_t size, + cudaStream_t stream, + BlockPool* pool, + size_t alloc_size, + DeviceStats& stats) + : search_key(device, stream, size), + pool(pool), + alloc_size(alloc_size), + block(nullptr), + err(cudaSuccess) {} - int device() { return search_key.device; } - cudaStream_t stream() { return search_key.stream; } - size_t size() { return search_key.size; } + int device() { + return search_key.device; + } + cudaStream_t stream() { + return search_key.stream; + } + size_t size() { + return search_key.size; + } Block search_key; BlockPool* pool; @@ -217,26 +260,27 @@ struct AllocParams { cudaError_t err; }; - // CUDA 
graphs helper struct PrivatePool { - PrivatePool() : - use_count(1), - cudaMalloc_count(0), - large_blocks(BlockComparator, /*is_small=*/false, this), - small_blocks(BlockComparator, /*is_small=*/true, this) {} + PrivatePool() + : use_count(1), + cudaMalloc_count(0), + large_blocks(BlockComparator, /*is_small=*/false, this), + small_blocks(BlockComparator, /*is_small=*/true, this) {} PrivatePool(const PrivatePool&) = delete; PrivatePool(PrivatePool&&) = delete; PrivatePool& operator=(const PrivatePool&) = delete; // Number of live graphs using this pool int use_count; - // Number of unfreed cudaMallocs made for this pool. When use_count and cudaMalloc_count - // drop to zero, we can delete this PrivatePool from graph_pools. + // Number of unfreed cudaMallocs made for this pool. When use_count and + // cudaMalloc_count drop to zero, we can delete this PrivatePool from + // graph_pools. int cudaMalloc_count; - // Instead of maintaining private BlockPools here, I could stuff all blocks (private or no) - // into the top-level large_blocks and small_blocks, and distinguish private blocks by adding - // a "pool id" check above the stream check in BlockComparator. BlockComparator is performance- - // critial though, I'd rather not add more logic to it. + // Instead of maintaining private BlockPools here, I could stuff all blocks + // (private or no) into the top-level large_blocks and small_blocks, and + // distinguish private blocks by adding a "pool id" check above the stream + // check in BlockComparator. BlockComparator is performance- critial though, + // I'd rather not add more logic to it. BlockPool large_blocks; BlockPool small_blocks; }; @@ -249,12 +293,14 @@ struct MempoolIdHash { cudaError_t cudaMallocMaybeCapturing(void** p, size_t size) { #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 - if (at::cuda::currentStreamCaptureStatusMayInitCtx() == at::cuda::CaptureStatus::None) { + if (at::cuda::currentStreamCaptureStatusMayInitCtx() == + at::cuda::CaptureStatus::None) { #endif return cudaMalloc(p, size); #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 } else { - // It's ok to capture cudaMallocs, as long as we never cudaFree those addresses before replay. + // It's ok to capture cudaMallocs, as long as we never cudaFree those + // addresses before replay. at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeRelaxed}; return cudaMalloc(p, size); } @@ -264,9 +310,7 @@ cudaError_t cudaMallocMaybeCapturing(void** p, size_t size) { } // namespace class DeviceCachingAllocator { - private: - // lock around all operations mutable std::recursive_mutex mutex; @@ -302,38 +346,41 @@ class DeviceCachingAllocator { // Members specific to CUDA graphs // Private pools for CUDA graphs - std::unordered_map, MempoolIdHash> graph_pools; - // Pools no longer referenced by any graph. Their BlockPools are eligible for free_blocks. - // Can't be a vector or deque because we might erase entries in any order. - // Could be an std::list, but we don't care much, access and insert/erase are rare. - std::unordered_map graph_pools_freeable; + std::unordered_map, MempoolIdHash> + graph_pools; + // Pools no longer referenced by any graph. Their BlockPools are eligible for + // free_blocks. Can't be a vector or deque because we might erase entries in + // any order. Could be an std::list, but we don't care much, access and + // insert/erase are rare. 
+ std::unordered_map + graph_pools_freeable; // Maps a capturing stream to its assigned private pool, // in case we want multiple captures to share the same pool std::unordered_map capture_to_pool_map; public: - - DeviceCachingAllocator() : - large_blocks(BlockComparator, /*is_small=*/false), - small_blocks(BlockComparator, /*is_small=*/true) {} + DeviceCachingAllocator() + : large_blocks(BlockComparator, /*is_small=*/false), + small_blocks(BlockComparator, /*is_small=*/true) {} // All public methods (except the above) acquire the allocator mutex. // Thus, do not call a public method from another public method. - Block* malloc(int device, size_t size, cudaStream_t stream) - { + Block* malloc(int device, size_t size, cudaStream_t stream) { std::unique_lock lock(mutex); if (C10_LIKELY(captures_underway == 0)) { - // Processes end-of-life events for outstanding allocations used on multiple streams - // (checks if their GPU-side uses are complete and recycles their memory if so) + // Processes end-of-life events for outstanding allocations used on + // multiple streams (checks if their GPU-side uses are complete and + // recycles their memory if so) // // Q. Why skip process_events if a capture might be underway? - // A. process_events involves cudaEventQueries, illegal during CUDA graph capture. - // Dumb simple solution: defer reclaiming these allocations until after capture. - // Cross-stream memory use is uncommon, so the deferral's effect on memory use - // during capture should be small. + // A. process_events involves cudaEventQueries, illegal during CUDA graph + // capture. + // Dumb simple solution: defer reclaiming these allocations until after + // capture. Cross-stream memory use is uncommon, so the deferral's + // effect on memory use during capture should be small. process_events(); } @@ -345,14 +392,14 @@ class DeviceCachingAllocator { params.stat_types[static_cast(get_stat_type_for_pool(pool))] = true; bool block_found = - // Search pool - get_free_block(params) - // Trigger callbacks and retry search - || (trigger_free_memory_callbacks(params) && get_free_block(params)) - // Attempt allocate - || alloc_block(params, false) - // Free all non-split cached blocks and retry alloc. - || (free_cached_blocks() && alloc_block(params, true)); + // Search pool + get_free_block(params) + // Trigger callbacks and retry search + || (trigger_free_memory_callbacks(params) && get_free_block(params)) + // Attempt allocate + || alloc_block(params, false) + // Free all non-split cached blocks and retry alloc. + || (free_cached_blocks() && alloc_block(params, true)); if (!block_found) { // For any error code other than cudaErrorMemoryAllocation, @@ -388,21 +435,32 @@ class DeviceCachingAllocator { // Note that at this point free_cached_blocks has already returned all // possible "cached" memory to the driver. The only remaining "cached" // memory is split from a larger block that is partially in-use. - TORCH_CHECK_WITH(CUDAOutOfMemoryError, false, - "CUDA out of memory. Tried to allocate ", format_size(alloc_size), - " (GPU ", device, "; ", - format_size(device_total), " total capacity; ", - format_size(stats.allocated_bytes[static_cast(StatType::AGGREGATE)].current), - " already allocated; ", - format_size(device_free), " free; ", - allowed_info, - format_size(stats.reserved_bytes[static_cast(StatType::AGGREGATE)].current), - " reserved in total by PyTorch)"); + TORCH_CHECK_WITH( + CUDAOutOfMemoryError, + false, + "CUDA out of memory. 
Tried to allocate ", + format_size(alloc_size), + " (GPU ", + device, + "; ", + format_size(device_total), + " total capacity; ", + format_size( + stats.allocated_bytes[static_cast(StatType::AGGREGATE)] + .current), + " already allocated; ", + format_size(device_free), + " free; ", + allowed_info, + format_size( + stats.reserved_bytes[static_cast(StatType::AGGREGATE)] + .current), + " reserved in total by PyTorch)"); } - TORCH_INTERNAL_ASSERT(params.err == cudaSuccess && - params.block != nullptr && - params.block->ptr != nullptr); + TORCH_INTERNAL_ASSERT( + params.err == cudaSuccess && params.block != nullptr && + params.block->ptr != nullptr); Block* block = params.block; Block* remaining = nullptr; @@ -425,16 +483,19 @@ class DeviceCachingAllocator { if (already_split) { // An already-split inactive block is being shrunk by size bytes. - update_stat_array(stats.inactive_split_bytes, -block->size, params.stat_types); + update_stat_array( + stats.inactive_split_bytes, -block->size, params.stat_types); } else { - // A new split inactive block is being created from a previously unsplit block, - // size remaining->size bytes. - update_stat_array(stats.inactive_split_bytes, remaining->size, params.stat_types); + // A new split inactive block is being created from a previously unsplit + // block, size remaining->size bytes. + update_stat_array( + stats.inactive_split_bytes, remaining->size, params.stat_types); update_stat_array(stats.inactive_split, 1, params.stat_types); } } else if (already_split) { // An already-split block is becoming active - update_stat_array(stats.inactive_split_bytes, -block->size, params.stat_types); + update_stat_array( + stats.inactive_split_bytes, -block->size, params.stat_types); update_stat_array(stats.inactive_split, -1, params.stat_types); } @@ -453,8 +514,7 @@ class DeviceCachingAllocator { return block; } - void free(Block* block) - { + void free(Block* block) { std::lock_guard lock(mutex); block->allocated = false; @@ -464,15 +524,17 @@ class DeviceCachingAllocator { StatTypes stat_types; stat_types[static_cast(StatType::AGGREGATE)] = true; - stat_types[static_cast(get_stat_type_for_pool(*(block->pool)))] = true; + stat_types[static_cast(get_stat_type_for_pool(*(block->pool)))] = + true; update_stat_array(stats.allocation, -1, {stat_types}); update_stat_array(stats.allocated_bytes, -block->size, {stat_types}); if (!block->stream_uses.empty()) { if (C10_UNLIKELY(captures_underway)) { - // It's forbidden to cudaEventQuery an event recorded during CUDA graph capture. - // We conservatively defer recording end-of-life events until the next call to - // process_events() (which won't happen until no captures are underway) + // It's forbidden to cudaEventQuery an event recorded during CUDA graph + // capture. 
We conservatively defer recording end-of-life events until + // the next call to process_events() (which won't happen until no + // captures are underway) needs_events_deferred_until_no_capture.push_back(block); } else { insert_events(block); @@ -487,7 +549,7 @@ class DeviceCachingAllocator { while (block->prev) { block = block->prev; } - void *basePtr = block->ptr; + void* basePtr = block->ptr; if (outSize) { size_t size = 0; while (block) { @@ -525,13 +587,14 @@ class DeviceCachingAllocator { } /** Retrieves info (total size + largest block) of the memory cache **/ - void cacheInfo(size_t* total, size_t* largest) - { + void cacheInfo(size_t* total, size_t* largest) { std::lock_guard lock(mutex); - if (*largest == 0) { // make an initial guess if a zero *largest is passed in + if (*largest == + 0) { // make an initial guess if a zero *largest is passed in size_t tmp_bytes; - cudaMemGetInfo(largest, // Use free memory as an optimistic initial guess of *largest - &tmp_bytes); + cudaMemGetInfo( + largest, // Use free memory as an optimistic initial guess of *largest + &tmp_bytes); } cache_info_aux(large_blocks, total, largest); cache_info_aux(small_blocks, total, largest); @@ -551,7 +614,9 @@ class DeviceCachingAllocator { void resetAccumulatedStats() { std::lock_guard lock(mutex); - for (size_t statType = 0; statType < static_cast(StatType::NUM_TYPES); ++statType) { + for (size_t statType = 0; + statType < static_cast(StatType::NUM_TYPES); + ++statType) { reset_accumulated_stat(stats.allocation[statType]); reset_accumulated_stat(stats.segment[statType]); reset_accumulated_stat(stats.active[statType]); @@ -570,7 +635,9 @@ class DeviceCachingAllocator { void resetPeakStats() { std::lock_guard lock(mutex); - for (size_t statType = 0; statType < static_cast(StatType::NUM_TYPES); ++statType) { + for (size_t statType = 0; + statType < static_cast(StatType::NUM_TYPES); + ++statType) { reset_peak_stat(stats.allocation[statType]); reset_peak_stat(stats.segment[statType]); reset_peak_stat(stats.active[statType]); @@ -582,7 +649,8 @@ class DeviceCachingAllocator { } } - /** Dump a complete snapshot of the memory held by the allocator. Potentially VERY expensive. **/ + /** Dump a complete snapshot of the memory held by the allocator. Potentially + * VERY expensive. **/ std::vector snapshot() const { std::lock_guard lock(mutex); @@ -606,7 +674,8 @@ class DeviceCachingAllocator { block_info.size = block->size; block_info.allocated = block->allocated; - block_info.active = block->allocated || (block->event_count > 0) || !block->stream_uses.empty(); + block_info.active = block->allocated || (block->event_count > 0) || + !block->stream_uses.empty(); segment_info.total_size += block_info.size; if (block_info.allocated) { @@ -620,9 +689,12 @@ class DeviceCachingAllocator { } } - std::sort(result.begin(), result.end(), [](const SegmentInfo& a, const SegmentInfo& b) { - return a.address < b.address; - }); + std::sort( + result.begin(), + result.end(), + [](const SegmentInfo& a, const SegmentInfo& b) { + return a.address < b.address; + }); return result; } @@ -643,22 +715,26 @@ class DeviceCachingAllocator { captures_underway++; auto it = graph_pools.find(mempool_id); if (it == graph_pools.end()) { - // mempool_id does not reference an existing pool. Make a new pool for this capture. - graph_pools.emplace(std::make_pair(mempool_id, std::unique_ptr(new PrivatePool))); + // mempool_id does not reference an existing pool. Make a new pool for + // this capture. 
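The short-circuit chain inside DeviceCachingAllocator::malloc above (search the cached free lists, run the registered free-memory callbacks and search again, ask the driver, and finally flush unsplit cached blocks and retry) is easy to lose inside the reformatted expression. Here is a hedged, self-contained restatement of that fallback order; every helper name is a made-up stand-in, and the stub bodies simply model a cache miss followed by a successful driver allocation.

static bool findCachedBlock() { return false; }                // look in the free lists
static bool runFreeMemoryCallbacks() { return false; }         // registered callbacks release memory
static bool allocFromDriver(bool /*isRetry*/) { return true; } // cudaMalloc path
static bool flushUnsplitCachedBlocks() { return true; }        // cudaFree all cached, unsplit blocks

static bool tryAllocate() {
  // Later, more expensive stages run only when the cheaper ones fail,
  // mirroring the || chain in the malloc shown above.
  return findCachedBlock()
      || (runFreeMemoryCallbacks() && findCachedBlock())
      || allocFromDriver(/*isRetry=*/false)
      || (flushUnsplitCachedBlocks() && allocFromDriver(/*isRetry=*/true));
}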
+ graph_pools.emplace(std::make_pair( + mempool_id, std::unique_ptr(new PrivatePool))); } else { - // mempool_id references an existing pool, which the current capture will share. - // Check this pool is live (at least one other capture already references it). + // mempool_id references an existing pool, which the current capture will + // share. Check this pool is live (at least one other capture already + // references it). TORCH_INTERNAL_ASSERT(it->second->use_count > 0); it->second->use_count++; } - // Maps this graph_id to mempool_id and makes sure this graph_id wasn't somehow - // assigned a mempool_id already. Keeps essential effect (insert) out of macro. + // Maps this graph_id to mempool_id and makes sure this graph_id wasn't + // somehow assigned a mempool_id already. Keeps essential effect (insert) + // out of macro. bool inserted = capture_to_pool_map.insert({graph_id, mempool_id}).second; TORCH_INTERNAL_ASSERT(inserted); } // Called by CUDAGraph::capture_end - void notifyCaptureEnd(CaptureId_t graph_id) { + void notifyCaptureEnd(CaptureId_t graph_id) { std::lock_guard lock(mutex); captures_underway--; auto it = capture_to_pool_map.find(graph_id); @@ -669,13 +745,15 @@ class DeviceCachingAllocator { // Called by CUDAGraph::reset void notifyCaptureDestroy(MempoolId_t mempool_id) { std::lock_guard lock(mutex); - // The instantiated cudaGraphExec_t has been destroyed. We can't blindly delete and cudaFree - // the mempool its capture used, because + // The instantiated cudaGraphExec_t has been destroyed. We can't blindly + // delete and cudaFree the mempool its capture used, because // 1. other graph(s) might share the same pool - // 2. the user might still hold references to output tensors allocated during capture. - // To handle 1 and 2, we track the number of graphs using this particular mempool. - // When the count reaches 0, we tell free_cached_blocks it may now cudaFree blocks from - // this graph's pool when it discovers they're unused (unsplit). + // 2. the user might still hold references to output tensors allocated + // during capture. + // To handle 1 and 2, we track the number of graphs using this particular + // mempool. When the count reaches 0, we tell free_cached_blocks it may now + // cudaFree blocks from this graph's pool when it discovers they're unused + // (unsplit). auto it = graph_pools.find(mempool_id); TORCH_INTERNAL_ASSERT(it != graph_pools.end()); auto uc = --(it->second->use_count); @@ -683,31 +761,40 @@ class DeviceCachingAllocator { if (uc == 0) { // Allows free_cached_blocks to begin cudaFreeing this pool's memory, // and makes sure this pool wasn't somehow made freeable already. - bool inserted = graph_pools_freeable.insert({mempool_id, it->second.get()}).second; + bool inserted = + graph_pools_freeable.insert({mempool_id, it->second.get()}).second; TORCH_INTERNAL_ASSERT(inserted); } } private: - // All private methods do not acquire the allocator mutex. 
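To keep the use_count bookkeeping of notifyCaptureBegin / notifyCaptureEnd / notifyCaptureDestroy readable, here is a stripped-down sketch of the same lifecycle. All names and types are simplified stand-ins (no locking, no real MempoolId_t or CaptureId_t), so treat it as an outline rather than the allocator's code.

#include <cstdint>
#include <memory>
#include <unordered_map>

using PoolId = uint64_t;  // stand-in for MempoolId_t
using GraphId = uint64_t; // stand-in for CaptureId_t
struct PoolSketch { int use_count = 1; };

std::unordered_map<PoolId, std::unique_ptr<PoolSketch>> pools;
std::unordered_map<GraphId, PoolId> capture_to_pool;
std::unordered_map<PoolId, PoolSketch*> freeable; // pools whose blocks may now be freed

void captureBegin(GraphId g, PoolId p) {
  auto it = pools.find(p);
  if (it == pools.end()) {
    pools.emplace(p, std::make_unique<PoolSketch>()); // new pool, use_count starts at 1
  } else {
    ++it->second->use_count; // this capture shares an existing, live pool
  }
  capture_to_pool.emplace(g, p);
}

void captureEnd(GraphId g) {
  // The pool stays alive: the captured graph may still replay, and tensors
  // allocated during capture may still be referenced.
  capture_to_pool.erase(g);
}

void captureDestroy(PoolId p) {
  auto& pool = pools.at(p);
  if (--pool->use_count == 0) {
    // No graph references the pool any more; its cached blocks become
    // eligible for release back to the driver.
    freeable.emplace(p, pool.get());
  }
}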
std::vector get_all_blocks() const { std::vector blocks; - blocks.insert(blocks.end(), small_blocks.blocks.begin(), small_blocks.blocks.end()); - blocks.insert(blocks.end(), large_blocks.blocks.begin(), large_blocks.blocks.end()); + blocks.insert( + blocks.end(), small_blocks.blocks.begin(), small_blocks.blocks.end()); + blocks.insert( + blocks.end(), large_blocks.blocks.begin(), large_blocks.blocks.end()); for (const auto& gp : graph_pools) { - blocks.insert(blocks.end(), gp.second->small_blocks.blocks.begin(), gp.second->small_blocks.blocks.end()); - blocks.insert(blocks.end(), gp.second->large_blocks.blocks.begin(), gp.second->large_blocks.blocks.end()); + blocks.insert( + blocks.end(), + gp.second->small_blocks.blocks.begin(), + gp.second->small_blocks.blocks.end()); + blocks.insert( + blocks.end(), + gp.second->large_blocks.blocks.begin(), + gp.second->large_blocks.blocks.end()); } blocks.insert(blocks.end(), active_blocks.begin(), active_blocks.end()); return blocks; } /** moves a block into a pool of cached free blocks */ - void free_block(Block* block) - { - TORCH_INTERNAL_ASSERT(!block->allocated && block->event_count == 0 && block->stream_uses.empty()); + void free_block(Block* block) { + TORCH_INTERNAL_ASSERT( + !block->allocated && block->event_count == 0 && + block->stream_uses.empty()); size_t original_block_size = block->size; @@ -717,7 +804,8 @@ class DeviceCachingAllocator { const std::array merge_candidates = {block->prev, block->next}; for (Block* merge_candidate : merge_candidates) { - const int64_t subsumed_size = try_merge_blocks(block, merge_candidate, pool); + const int64_t subsumed_size = + try_merge_blocks(block, merge_candidate, pool); if (subsumed_size > 0) { net_change_inactive_split_blocks -= 1; net_change_inactive_split_size -= subsumed_size; @@ -725,7 +813,8 @@ class DeviceCachingAllocator { } active_blocks.erase(block); - // Makes sure the Block* isn't already present in the pool we're freeing it back into. + // Makes sure the Block* isn't already present in the pool we're freeing it + // back into. bool inserted = pool.blocks.insert(block).second; TORCH_INTERNAL_ASSERT(inserted); @@ -737,16 +826,19 @@ class DeviceCachingAllocator { StatTypes stat_types; stat_types[static_cast(StatType::AGGREGATE)] = true; stat_types[static_cast(get_stat_type_for_pool(pool))] = true; - update_stat_array(stats.inactive_split, net_change_inactive_split_blocks, stat_types); - update_stat_array(stats.inactive_split_bytes, net_change_inactive_split_size, stat_types); + update_stat_array( + stats.inactive_split, net_change_inactive_split_blocks, stat_types); + update_stat_array( + stats.inactive_split_bytes, net_change_inactive_split_size, stat_types); update_stat_array(stats.active, -1, stat_types); update_stat_array(stats.active_bytes, -original_block_size, stat_types); } - /** combine previously split blocks. returns the size of the subsumed block, or 0 on failure. */ - size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) - { - if (!src || src->allocated || src->event_count > 0 || !src->stream_uses.empty()) { + /** combine previously split blocks. returns the size of the subsumed block, + * or 0 on failure. 
*/ + size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) { + if (!src || src->allocated || src->event_count > 0 || + !src->stream_uses.empty()) { return 0; } @@ -776,15 +868,18 @@ class DeviceCachingAllocator { BlockPool& get_pool(size_t size, cudaStream_t stream) { #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 - // captures_underway is a conservative guess that the current stream may be capturing. - // It's only > 0 if some thread has begun and not yet ended a capture, so it's usually 0, - // and we can short-circuit cudaStreamCaptureStatus (which does a TLS lookup). + // captures_underway is a conservative guess that the current stream may be + // capturing. It's only > 0 if some thread has begun and not yet ended a + // capture, so it's usually 0, and we can short-circuit + // cudaStreamCaptureStatus (which does a TLS lookup). if (C10_UNLIKELY(captures_underway)) { CaptureId_t id; cudaStreamCaptureStatus status; C10_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &id)); if (status != cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) { - TORCH_INTERNAL_ASSERT(status != cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated); + TORCH_INTERNAL_ASSERT( + status != + cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated); // Retrieves the private pool assigned to this capture. auto it0 = capture_to_pool_map.find(id); TORCH_INTERNAL_ASSERT(it0 != capture_to_pool_map.end()); @@ -811,9 +906,8 @@ class DeviceCachingAllocator { bool should_split(const Block* block, size_t size) { size_t remaining = block->size - size; - return (block->pool->is_small) ? - (remaining >= kMinBlockSize) : - (remaining > kSmallSize); + return (block->pool->is_small) ? (remaining >= kMinBlockSize) + : (remaining > kSmallSize); } static size_t get_allocation_size(size_t size) { @@ -840,7 +934,7 @@ class DeviceCachingAllocator { bool freed_memory = false; for (const auto& name : FreeCudaMemoryCallbacksRegistry()->Keys()) { freed_memory |= - FreeCudaMemoryCallbacksRegistry()->Create(name)->Execute(); + FreeCudaMemoryCallbacksRegistry()->Create(name)->Execute(); } return freed_memory; } @@ -856,22 +950,27 @@ class DeviceCachingAllocator { stats.num_alloc_retries += 1; } - if (set_fraction && total_allocated_memory + size > allowed_memory_maximum) { + if (set_fraction && + total_allocated_memory + size > allowed_memory_maximum) { p.err = cudaErrorMemoryAllocation; return false; } else { p.err = cudaMallocMaybeCapturing(&ptr, size); if (p.err != cudaSuccess) { if (p.err == cudaErrorMemoryAllocation) { - // If this is the first attempt (!isRetry), we can forgive and clear CUDA's + // If this is the first attempt (!isRetry), we can forgive and clear + // CUDA's // internal error state. - // If this is the second attempt (isRetry), malloc's TORCH_CHECK_WITH will take - // over to throw a helpful exception. The user can choose to catch the exception, - // free some stuff in their script, and attempt their allocation again. - // In this case, we can also forgive and clear CUDA's internal error state. + // If this is the second attempt (isRetry), malloc's TORCH_CHECK_WITH + // will take + // over to throw a helpful exception. The user can choose to catch + // the exception, free some stuff in their script, and attempt their + // allocation again. In this case, we can also forgive and clear + // CUDA's internal error state. cudaGetLastError(); } else { - // If the error's unrelated to memory allocation, we should throw immediately. 
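trigger_free_memory_callbacks above walks FreeCudaMemoryCallbacksRegistry and executes every registered callback before retrying the pool search. A hedged sketch of what such a callback could look like follows; the FreeMemoryCallback interface is declared further down in c10/cuda/CUDACachingAllocator.h, but the surrounding namespace and the REGISTER_FREE_MEMORY_CALLBACK registration form are recalled from memory and should be verified against that header.

#include <c10/cuda/CUDACachingAllocator.h>

// Execute() should release CUDA memory held outside the caching allocator
// and report whether anything was actually freed, so the allocator knows a
// retry of its free-list search is worthwhile.
struct MyEvictionCallback : public c10::cuda::FreeMemoryCallback {
  bool Execute() override {
    // e.g. drop an application-level cache of device buffers here.
    return false; // nothing freed in this stub
  }
};

// Registration macro as we recall it from the header; double-check the
// exact name and argument form before relying on it.
REGISTER_FREE_MEMORY_CALLBACK(my_eviction_callback, MyEvictionCallback);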
+ // If the error's unrelated to memory allocation, we should throw + // immediately. C10_CUDA_CHECK(p.err); } return false; @@ -893,8 +992,7 @@ class DeviceCachingAllocator { return true; } - bool free_cached_blocks() - { + bool free_cached_blocks() { // First ensure that all blocks that can't currently be allocated due to // outstanding events are returned to the pool. synchronize_and_free_events(); @@ -903,7 +1001,8 @@ class DeviceCachingAllocator { free_blocks(large_blocks); free_blocks(small_blocks); - for (auto it = graph_pools_freeable.begin(); it != graph_pools_freeable.end(); ) { + for (auto it = graph_pools_freeable.begin(); + it != graph_pools_freeable.end();) { // See notifyCaptureDestroy for the strategy here. TORCH_INTERNAL_ASSERT(it->second->use_count == 0); free_blocks(it->second->small_blocks); @@ -920,8 +1019,7 @@ class DeviceCachingAllocator { return true; } - void free_blocks(BlockPool& pool) - { + void free_blocks(BlockPool& pool) { // Frees all non-split blocks auto it = pool.blocks.begin(); while (it != pool.blocks.end()) { @@ -986,8 +1084,7 @@ class DeviceCachingAllocator { cuda_events.clear(); } - void insert_events(Block* block) - { + void insert_events(Block* block) { int prev_device; C10_CUDA_CHECK(cudaGetDevice(&prev_device)); @@ -1016,8 +1113,7 @@ class DeviceCachingAllocator { } } - void process_events() - { + void process_events() { insert_events_deferred_until_no_capture(); // Process outstanding cudaEvents. Events that are completed are removed @@ -1050,8 +1146,7 @@ class DeviceCachingAllocator { } // Accumulates sizes of all memory blocks for given device in given pool - void cache_info_aux(const BlockPool& pool, size_t* total, size_t* largest) - { + void cache_info_aux(const BlockPool& pool, size_t* total, size_t* largest) { for (const auto& block : pool.blocks) { size_t blocksize = block->size; *total += blocksize; @@ -1063,9 +1158,7 @@ class DeviceCachingAllocator { }; class THCCachingAllocator { - private: - std::mutex mutex; // allocated blocks by device pointer @@ -1080,14 +1173,13 @@ class THCCachingAllocator { } public: - std::vector> device_allocator; std::mutex* getCudaFreeMutex() const { return &cuda_free_mutex; } - Block* get_allocated_block(void *ptr, bool remove=false) { + Block* get_allocated_block(void* ptr, bool remove = false) { std::lock_guard lock(mutex); auto it = allocated_blocks.find(ptr); if (it == allocated_blocks.end()) { @@ -1105,7 +1197,8 @@ class THCCachingAllocator { if (size < device_count) { device_allocator.resize(device_count); for (int i = size; i < device_count; i++) { - device_allocator[i] = std::unique_ptr(new DeviceCachingAllocator()); + device_allocator[i] = std::unique_ptr( + new DeviceCachingAllocator()); } } } @@ -1140,14 +1233,14 @@ class THCCachingAllocator { device, ": did you call init?"); TORCH_INTERNAL_ASSERT( - 0 <= fraction && fraction <= 1, + 0 <= fraction && fraction <= 1, "invalid fraction:", fraction, ". 
Please set within (0, 1)."); int activated_device; - cudaGetDevice (&activated_device); + cudaGetDevice(&activated_device); if (activated_device != device) { - cudaSetDevice(device); + cudaSetDevice(device); } device_allocator[device]->setMemoryFraction(fraction); } @@ -1158,8 +1251,7 @@ class THCCachingAllocator { device_allocator[i]->emptyCache(); } - void* getBaseAllocation(void* ptr, size_t* outSize) - { + void* getBaseAllocation(void* ptr, size_t* outSize) { Block* block = get_allocated_block(ptr); if (!block) { TORCH_CHECK(false, "invalid device pointer: ", ptr); @@ -1230,7 +1322,8 @@ struct CudaCachingAllocator : public Allocator { return {r, r, &uncached_delete, Device(DeviceType::CUDA, device)}; } if (size != 0) { - caching_allocator.malloc(&r, device, size, cuda::getCurrentCUDAStream(device)); + caching_allocator.malloc( + &r, device, size, cuda::getCurrentCUDAStream(device)); } return {r, r, &raw_delete, Device(DeviceType::CUDA, device)}; } @@ -1245,8 +1338,7 @@ struct CudaCachingAllocator : public Allocator { CudaCachingAllocator device_allocator; -Allocator* get(void) -{ +Allocator* get(void) { return &device_allocator; } @@ -1263,21 +1355,19 @@ void emptyCache(void) { } void cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock) { - caching_allocator.device_allocator[dev_id]->cacheInfo(cachedAndFree, largestBlock); + caching_allocator.device_allocator[dev_id]->cacheInfo( + cachedAndFree, largestBlock); } -void* getBaseAllocation(void *ptr, size_t *size) -{ +void* getBaseAllocation(void* ptr, size_t* size) { return caching_allocator.getBaseAllocation(ptr, size); } -void recordStream(const DataPtr& ptr, cuda::CUDAStream stream) -{ +void recordStream(const DataPtr& ptr, cuda::CUDAStream stream) { caching_allocator.recordStream(ptr, stream); } -std::mutex* getFreeMutex() -{ +std::mutex* getFreeMutex() { return caching_allocator.getCudaFreeMutex(); } @@ -1306,11 +1396,13 @@ std::vector snapshot() { } // CUDAGraph interactions -void notifyCaptureBegin(int device, - CaptureId_t graph_id, - MempoolId_t mempool_id) { +void notifyCaptureBegin( + int device, + CaptureId_t graph_id, + MempoolId_t mempool_id) { assertValidDevice(device); - caching_allocator.device_allocator[device]->notifyCaptureBegin(graph_id, mempool_id); + caching_allocator.device_allocator[device]->notifyCaptureBegin( + graph_id, mempool_id); } void notifyCaptureEnd(int device, CaptureId_t graph_id) { @@ -1328,21 +1420,21 @@ void notifyCaptureDestroy(int device, MempoolId_t mempool_id) { // is called by the receiving process to map the CUDA memory from the sending // process into its own address space. // -// CUDA IPC only allows sharing a big memory block associated with a cudaIpcMemHandle_t -// and it can be opened only **once** per context per process. There can be -// multiple types of storage in the same IPC mem block, so we must cache the -// device ptr to construct typed storage as it comes. +// CUDA IPC only allows sharing a big memory block associated with a +// cudaIpcMemHandle_t and it can be opened only **once** per context per +// process. There can be multiple types of storage in the same IPC mem block, so +// we must cache the device ptr to construct typed storage as it comes. // -// ipcMemHandle_to_devptr maps a cudaIpcMemHandle_t to a device pointer in the process -// that can be used to access the memory block in the sender process. -// It only saves a weak_ptr of the device pointer in the map, the shared_ptr -// will be used to reconstruct all storages in this CudaMalloc allocation. 
-// And it will deleted in cudaIpcCloseMemHandle when its reference count is 0. +// ipcMemHandle_to_devptr maps a cudaIpcMemHandle_t to a device pointer in the +// process that can be used to access the memory block in the sender process. It +// only saves a weak_ptr of the device pointer in the map, the shared_ptr will +// be used to reconstruct all storages in this CudaMalloc allocation. And it +// will deleted in cudaIpcCloseMemHandle when its reference count is 0. // namespace { - std::mutex IpcMutex; - std::unordered_map> ipcMemHandle_to_devptr; -} +std::mutex IpcMutex; +std::unordered_map> ipcMemHandle_to_devptr; +} // namespace std::shared_ptr getIpcDevPtr(std::string handle) { std::lock_guard lock(IpcMutex); @@ -1350,23 +1442,24 @@ std::shared_ptr getIpcDevPtr(std::string handle) { auto iter = ipcMemHandle_to_devptr.find(handle); if (iter != ipcMemHandle_to_devptr.end()) { auto devptr = iter->second.lock(); - if (devptr) return devptr; + if (devptr) + return devptr; } // This ipcMemHandle hasn't been opened, or already expired, open it to // enable IPC access to that mem block. - void *dev = nullptr; + void* dev = nullptr; auto ipc_handle = reinterpret_cast(handle.c_str()); - C10_CUDA_CHECK(cudaIpcOpenMemHandle(&dev, *ipc_handle, cudaIpcMemLazyEnablePeerAccess)); + C10_CUDA_CHECK( + cudaIpcOpenMemHandle(&dev, *ipc_handle, cudaIpcMemLazyEnablePeerAccess)); // devPtr has to be deleted in same device when created. int curr_device; C10_CUDA_CHECK(cudaGetDevice(&curr_device)); - auto sp = std::shared_ptr( - dev, - [handle, curr_device](void *ptr) { - cuda::CUDAGuard device_guard(curr_device); - std::lock_guard deleter_lock(IpcMutex); - C10_CUDA_CHECK(cudaIpcCloseMemHandle(ptr)); - ipcMemHandle_to_devptr.erase(handle);}); + auto sp = std::shared_ptr(dev, [handle, curr_device](void* ptr) { + cuda::CUDAGuard device_guard(curr_device); + std::lock_guard deleter_lock(IpcMutex); + C10_CUDA_CHECK(cudaIpcCloseMemHandle(ptr)); + ipcMemHandle_to_devptr.erase(handle); + }); std::weak_ptr wp = sp; // To eliminate an additional search, we can use insert(). // It doesn't overwrite when key already exists(ptr expired). @@ -1384,7 +1477,8 @@ void* raw_alloc(size_t nbytes) { int device; C10_CUDA_CHECK(cudaGetDevice(&device)); void* r = nullptr; - caching_allocator.malloc(&r, device, nbytes, cuda::getCurrentCUDAStream(device)); + caching_allocator.malloc( + &r, device, nbytes, cuda::getCurrentCUDAStream(device)); return r; } @@ -1405,4 +1499,5 @@ void raw_delete(void* ptr) { } // namespace CUDACachingAllocator -}} // namespace c10::cuda +} // namespace cuda +} // namespace c10 diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index 4b28b1ad40e..51f2c87e628 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -1,9 +1,9 @@ #ifndef THC_DEVICE_ALLOCATOR_INC #define THC_DEVICE_ALLOCATOR_INC -#include -#include #include +#include #include +#include #include #include @@ -19,7 +19,7 @@ class C10_CUDA_API CUDAOutOfMemoryError : public c10::Error { // block inside of already allocated area. 
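getIpcDevPtr above caches each opened IPC handle as a weak_ptr and lets the shared_ptr's custom deleter both close the mapping and erase the cache entry. The sketch below isolates that cache-or-reopen pattern; the open/close stand-ins replace cudaIpcOpenMemHandle / cudaIpcCloseMemHandle, and every name is illustrative.

#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

static std::mutex cache_mutex;
static std::unordered_map<std::string, std::weak_ptr<void>> handle_cache;

// Toy stand-ins for cudaIpcOpenMemHandle / cudaIpcCloseMemHandle.
static void* openHandle(const std::string&) { return new int(0); }
static void closeHandle(void* ptr) { delete static_cast<int*>(ptr); }

static std::shared_ptr<void> getCachedPtr(const std::string& key) {
  std::lock_guard<std::mutex> lock(cache_mutex);
  auto it = handle_cache.find(key);
  if (it != handle_cache.end()) {
    if (auto sp = it->second.lock()) {
      return sp; // mapping still alive: every caller shares the single handle
    }
  }
  void* dev = openHandle(key); // first use, or the previous mapping expired
  auto sp = std::shared_ptr<void>(dev, [key](void* ptr) {
    // Last owner gone: close the mapping and drop the stale cache entry.
    std::lock_guard<std::mutex> deleter_lock(cache_mutex);
    closeHandle(ptr);
    handle_cache.erase(key);
  });
  handle_cache[key] = sp; // cache only a weak reference (the real code uses insert())
  return sp;
}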
class C10_CUDA_API FreeMemoryCallback { public: - virtual ~FreeMemoryCallback() {}; + virtual ~FreeMemoryCallback(){}; virtual bool Execute() = 0; }; @@ -55,7 +55,7 @@ enum struct StatType : uint64_t { AGGREGATE = 0, SMALL_POOL = 1, LARGE_POOL = 2, - NUM_TYPES = 3 // remember to update this whenever a new stat type is added + NUM_TYPES = 3 // remember to update this whenever a new stat type is added }; typedef std::array(StatType::NUM_TYPES)> StatArray; @@ -68,7 +68,8 @@ struct DeviceStats { StatArray segment; // COUNT: number of active memory blocks (allocated or used by stream) StatArray active; - // COUNT: number of inactive, split memory blocks (unallocated but can't be released via cudaFree) + // COUNT: number of inactive, split memory blocks (unallocated but can't be + // released via cudaFree) StatArray inactive_split; // SUM: bytes requested by client code @@ -80,14 +81,16 @@ struct DeviceStats { // SUM: bytes within inactive, split memory blocks StatArray inactive_split_bytes; - // COUNT: total number of failed calls to CUDA malloc necessitating cache flushes. + // COUNT: total number of failed calls to CUDA malloc necessitating cache + // flushes. int64_t num_alloc_retries = 0; // COUNT: total number of OOMs (i.e. failed calls to CUDA after cache flush) int64_t num_ooms = 0; }; -// Struct containing info of an allocation block (i.e. a fractional part of a cudaMalloc).. +// Struct containing info of an allocation block (i.e. a fractional part of a +// cudaMalloc).. struct BlockInfo { int64_t size = 0; bool allocated = false; @@ -113,8 +116,11 @@ C10_CUDA_API Allocator* get(); C10_CUDA_API void init(int device_count); C10_CUDA_API void setMemoryFraction(double fraction, int device); C10_CUDA_API void emptyCache(); -C10_CUDA_API void cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock); -C10_CUDA_API void* getBaseAllocation(void *ptr, size_t *size); +C10_CUDA_API void cacheInfo( + int dev_id, + size_t* cachedAndFree, + size_t* largestBlock); +C10_CUDA_API void* getBaseAllocation(void* ptr, size_t* size); C10_CUDA_API void recordStream(const DataPtr&, CUDAStream stream); C10_CUDA_API DeviceStats getDeviceStats(int device); C10_CUDA_API void resetAccumulatedStats(int device); @@ -122,9 +128,10 @@ C10_CUDA_API void resetPeakStats(int device); C10_CUDA_API std::vector snapshot(); // CUDAGraph interactions -C10_CUDA_API void notifyCaptureBegin(int device, - CaptureId_t graph_id, - MempoolId_t mempool_id); +C10_CUDA_API void notifyCaptureBegin( + int device, + CaptureId_t graph_id, + MempoolId_t mempool_id); C10_CUDA_API void notifyCaptureEnd(int device, CaptureId_t graph_id); C10_CUDA_API void notifyCaptureDestroy(int device, MempoolId_t mempool_id); @@ -133,6 +140,7 @@ C10_CUDA_API std::mutex* getFreeMutex(); C10_CUDA_API std::shared_ptr getIpcDevPtr(std::string handle); } // namespace CUDACachingAllocator -}} // namespace c10::cuda +} // namespace cuda +} // namespace c10 #endif diff --git a/c10/cuda/CUDAException.h b/c10/cuda/CUDAException.h index 5d1a473b559..d438e4eb816 100644 --- a/c10/cuda/CUDAException.h +++ b/c10/cuda/CUDAException.h @@ -1,7 +1,7 @@ #pragma once -#include #include +#include #include // Note [CHECK macro] @@ -21,7 +21,7 @@ } \ } while (0) - #define C10_CUDA_CHECK_WARN(EXPR) \ +#define C10_CUDA_CHECK_WARN(EXPR) \ do { \ cudaError_t __err = EXPR; \ if (__err != cudaSuccess) { \ diff --git a/c10/cuda/CUDAFunctions.cpp b/c10/cuda/CUDAFunctions.cpp index 9a68c3e2246..16730d81de3 100644 --- a/c10/cuda/CUDAFunctions.cpp +++ b/c10/cuda/CUDAFunctions.cpp @@ 
-101,7 +101,9 @@ DeviceIndex device_count() noexcept { static int count = []() { try { auto result = device_count_impl(/*fail_if_no_driver=*/false); - TORCH_INTERNAL_ASSERT(result <= std::numeric_limits::max(), "Too many CUDA devices, DeviceIndex overflowed"); + TORCH_INTERNAL_ASSERT( + result <= std::numeric_limits::max(), + "Too many CUDA devices, DeviceIndex overflowed"); return result; } catch (const c10::Error& ex) { // We don't want to fail, but still log the warning diff --git a/c10/cuda/CUDAGraphsC10Utils.h b/c10/cuda/CUDAGraphsC10Utils.h index 2239d3b0292..79d727feeb1 100644 --- a/c10/cuda/CUDAGraphsC10Utils.h +++ b/c10/cuda/CUDAGraphsC10Utils.h @@ -27,7 +27,7 @@ struct C10_CUDA_API CUDAStreamCaptureModeGuard { C10_CUDA_CHECK_WARN(cudaThreadExchangeStreamCaptureMode(&strictness_)); } - private: + private: cudaStreamCaptureMode strictness_; }; #endif @@ -35,15 +35,18 @@ struct C10_CUDA_API CUDAStreamCaptureModeGuard { #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 // Protects against enum cudaStreamCaptureStatus implementation changes. // Some compilers seem not to like static_assert without the messages. -static_assert(int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) == 0, - "unexpected int(cudaStreamCaptureStatusNone) value"); -static_assert(int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive) == 1, - "unexpected int(cudaStreamCaptureStatusActive) value"); -static_assert(int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) == 2, - "unexpected int(cudaStreamCaptureStatusInvalidated) value"); +static_assert( + int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) == 0, + "unexpected int(cudaStreamCaptureStatusNone) value"); +static_assert( + int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive) == 1, + "unexpected int(cudaStreamCaptureStatusActive) value"); +static_assert( + int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) == 2, + "unexpected int(cudaStreamCaptureStatusInvalidated) value"); #endif -enum class CaptureStatus: int { +enum class CaptureStatus : int { #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 None = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone), Active = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive), @@ -54,7 +57,7 @@ enum class CaptureStatus: int { }; inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) { - switch(status) { + switch (status) { case CaptureStatus::None: os << "cudaStreamCaptureStatusNone"; break; @@ -67,9 +70,8 @@ inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) { break; #endif default: - TORCH_INTERNAL_ASSERT(false, - "Unknown CUDA graph CaptureStatus", - int(status)); + TORCH_INTERNAL_ASSERT( + false, "Unknown CUDA graph CaptureStatus", int(status)); } return os; } @@ -78,13 +80,13 @@ inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) { inline CaptureStatus currentStreamCaptureStatusMayInitCtx() { #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 cudaStreamCaptureStatus is_capturing; - C10_CUDA_CHECK(cudaStreamIsCapturing(c10::cuda::getCurrentCUDAStream(), - &is_capturing)); + C10_CUDA_CHECK( + cudaStreamIsCapturing(c10::cuda::getCurrentCUDAStream(), &is_capturing)); return CaptureStatus(is_capturing); #else return CaptureStatus::None; #endif } -} // namespace c10 } // namespace cuda +} // namespace c10 diff --git a/c10/cuda/CUDAGuard.h b/c10/cuda/CUDAGuard.h index f5ec3434388..905dcf9c6ff 100644 --- a/c10/cuda/CUDAGuard.h +++ b/c10/cuda/CUDAGuard.h @@ -1,16 +1,18 @@ #pragma once -#include 
-#include #include #include #include +#include +#include #include -namespace c10 { namespace cuda { +namespace c10 { +namespace cuda { -// This code is kind of boilerplatey. See Note [Whither the DeviceGuard boilerplate] +// This code is kind of boilerplatey. See Note [Whither the DeviceGuard +// boilerplate] /// A variant of DeviceGuard that is specialized for CUDA. It accepts /// integer indices (interpreting them as CUDA devices) and is a little @@ -38,22 +40,32 @@ struct CUDAGuard { /// Sets the CUDA device to the given device. Errors if the given device /// is not a CUDA device. - void set_device(Device device) { guard_.set_device(device); } + void set_device(Device device) { + guard_.set_device(device); + } /// Sets the CUDA device to the given device. Errors if the given device /// is not a CUDA device. (This method is provided for uniformity with /// DeviceGuard). - void reset_device(Device device) { guard_.reset_device(device); } + void reset_device(Device device) { + guard_.reset_device(device); + } /// Sets the CUDA device to the given device index. - void set_index(DeviceIndex device_index) { guard_.set_index(device_index); } + void set_index(DeviceIndex device_index) { + guard_.set_index(device_index); + } /// Returns the device that was set upon construction of the guard - Device original_device() const { return guard_.original_device(); } + Device original_device() const { + return guard_.original_device(); + } - /// Returns the last device that was set via `set_device`, if any, otherwise the - /// device passed during construction. - Device current_device() const { return guard_.current_device(); } + /// Returns the last device that was set via `set_device`, if any, otherwise + /// the device passed during construction. + Device current_device() const { + return guard_.current_device(); + } private: /// The guard for the current device. @@ -67,11 +79,13 @@ struct OptionalCUDAGuard { explicit OptionalCUDAGuard() : guard_() {} /// Set the current CUDA device to the passed Device, if it is not nullopt. - explicit OptionalCUDAGuard(optional device_opt) : guard_(device_opt) {} + explicit OptionalCUDAGuard(optional device_opt) + : guard_(device_opt) {} /// Set the current CUDA device to the passed device index, if it is not /// nullopt - explicit OptionalCUDAGuard(optional device_index_opt) : guard_(device_index_opt) {} + explicit OptionalCUDAGuard(optional device_index_opt) + : guard_(device_index_opt) {} // Copy is not allowed OptionalCUDAGuard(const OptionalCUDAGuard&) = delete; @@ -84,31 +98,45 @@ struct OptionalCUDAGuard { OptionalCUDAGuard& operator=(OptionalCUDAGuard&& other) = delete; /// Sets the CUDA device to the given device, initializing the guard if it - /// is not already initialized. Errors if the given device is not a CUDA device. - void set_device(Device device) { guard_.set_device(device); } + /// is not already initialized. Errors if the given device is not a CUDA + /// device. + void set_device(Device device) { + guard_.set_device(device); + } /// Sets the CUDA device to the given device, initializing the guard if it is /// not already initialized. Errors if the given device is not a CUDA device. /// (This method is provided for uniformity with OptionalDeviceGuard). - void reset_device(Device device) { guard_.reset_device(device); } + void reset_device(Device device) { + guard_.reset_device(device); + } /// Sets the CUDA device to the given device index, initializing the guard if /// it is not already initialized. 
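For orientation while reading the guard classes being reformatted here, a typical use of CUDAGuard looks like the hedged sketch below; the device index and the work inside the scope are placeholders.

#include <c10/cuda/CUDAGuard.h>

void runOnDevice1() {
  // Makes device 1 current for the lifetime of `guard` and restores the
  // previously current device when the scope ends.
  c10::cuda::CUDAGuard guard(1);
  // ... allocate and launch work on device 1 here ...
} // original device restored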
- void set_index(DeviceIndex device_index) { guard_.set_index(device_index); } + void set_index(DeviceIndex device_index) { + guard_.set_index(device_index); + } /// Returns the device that was set immediately prior to initialization of the /// guard, or nullopt if the guard is uninitialized. - optional original_device() const { return guard_.original_device(); } + optional original_device() const { + return guard_.original_device(); + } /// Returns the most recent device that was set using this device guard, /// either from construction, or via set_device, if the guard is initialized, /// or nullopt if the guard is uninitialized. - optional current_device() const { return guard_.current_device(); } + optional current_device() const { + return guard_.current_device(); + } - /// Restore the original CUDA device, resetting this guard to uninitialized state. - void reset() { guard_.reset(); } + /// Restore the original CUDA device, resetting this guard to uninitialized + /// state. + void reset() { + guard_.reset(); + } -private: + private: c10::impl::InlineOptionalDeviceGuard guard_; }; @@ -118,17 +146,17 @@ struct CUDAStreamGuard { /// No default constructor, see Note [Omitted default constructor from RAII] explicit CUDAStreamGuard() = delete; - /// Set the current CUDA device to the device associated with the passed stream, - /// and set the current CUDA stream on that device to the passed stream. - /// Errors if the Stream is not a CUDA stream. + /// Set the current CUDA device to the device associated with the passed + /// stream, and set the current CUDA stream on that device to the passed + /// stream. Errors if the Stream is not a CUDA stream. explicit CUDAStreamGuard(Stream stream) : guard_(stream) {} /// Copy is disallowed CUDAStreamGuard(const CUDAStreamGuard&) = delete; CUDAStreamGuard& operator=(const CUDAStreamGuard&) = delete; - /// Move is disallowed, as CUDAStreamGuard does not have an uninitialized state, - /// which is required for moves on types with nontrivial destructors. + /// Move is disallowed, as CUDAStreamGuard does not have an uninitialized + /// state, which is required for moves on types with nontrivial destructors. CUDAStreamGuard(CUDAStreamGuard&& other) = delete; CUDAStreamGuard& operator=(CUDAStreamGuard&& other) = delete; @@ -144,9 +172,12 @@ struct CUDAStreamGuard { /// WARNING: reset_stream does NOT preserve previously set streams on /// different devices. If you need to set streams on multiple devices /// on CUDA, use CUDAMultiStreamGuard instead. - void reset_stream(Stream stream) { guard_.reset_stream(stream); } + void reset_stream(Stream stream) { + guard_.reset_stream(stream); + } - /// Returns the CUDA stream that was set at the time the guard was constructed. + /// Returns the CUDA stream that was set at the time the guard was + /// constructed. CUDAStream original_stream() const { return CUDAStream(CUDAStream::UNCHECKED, guard_.original_stream()); } @@ -159,31 +190,36 @@ struct CUDAStreamGuard { /// Returns the most recent CUDA device that was set using this device guard, /// either from construction, or via set_device/reset_device/set_index. - Device current_device() const { return guard_.current_device(); } + Device current_device() const { + return guard_.current_device(); + } /// Returns the CUDA device that was set at the most recent reset_stream(), /// or otherwise the device at construction time. 
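CUDAStreamGuard, whose definition is being reformatted here, additionally switches the current stream; a minimal hedged sketch follows. getStreamFromPool is assumed to be the stream-pool accessor declared elsewhere in c10/cuda/CUDAStream.h.

#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

void runOnSideStream() {
  // Assumed helper: hands out a stream from the per-device pool.
  c10::cuda::CUDAStream side = c10::cuda::getStreamFromPool();
  {
    // Makes `side`'s device current and `side` the current stream there.
    c10::cuda::CUDAStreamGuard guard(side);
    // ... kernels launched via the current stream now run on `side` ...
  } // previous device and stream restored
}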
- Device original_device() const { return guard_.original_device(); } + Device original_device() const { + return guard_.original_device(); + } -private: + private: c10::impl::InlineStreamGuard guard_; }; -/// A variant of OptionalStreamGuard that is specialized for CUDA. See CUDAGuard -/// for when you can use this. +/// A variant of OptionalStreamGuard that is specialized for CUDA. See +/// CUDAGuard for when you can use this. struct OptionalCUDAStreamGuard { /// Create an uninitialized guard. explicit OptionalCUDAStreamGuard() : guard_() {} - /// Set the current CUDA device to the device associated with the passed stream, - /// and set the current CUDA stream on that device to the passed stream. - /// Errors if the Stream is not a CUDA stream. + /// Set the current CUDA device to the device associated with the passed + /// stream, and set the current CUDA stream on that device to the passed + /// stream. Errors if the Stream is not a CUDA stream. explicit OptionalCUDAStreamGuard(Stream stream) : guard_(stream) {} /// Set the current device to the device associated with the passed stream, /// and set the current stream on that device to the passed stream, /// if the passed stream is not nullopt. - explicit OptionalCUDAStreamGuard(optional stream_opt) : guard_(stream_opt) {} + explicit OptionalCUDAStreamGuard(optional stream_opt) + : guard_(stream_opt) {} /// Copy is disallowed OptionalCUDAStreamGuard(const OptionalCUDAStreamGuard&) = delete; @@ -200,10 +236,12 @@ struct OptionalCUDAStreamGuard { /// set the current device to the device associated with the passed stream, /// and set the current stream on that device to the passed stream. /// Initializes the guard if it was not previously initialized. - void reset_stream(Stream stream) { guard_.reset_stream(stream); } + void reset_stream(Stream stream) { + guard_.reset_stream(stream); + } - /// Returns the CUDA stream that was set at the time the guard was most recently - /// initialized, or nullopt if the guard is uninitialized. + /// Returns the CUDA stream that was set at the time the guard was most + /// recently initialized, or nullopt if the guard is uninitialized. optional original_stream() const { auto r = guard_.original_stream(); if (r.has_value()) { @@ -214,8 +252,8 @@ struct OptionalCUDAStreamGuard { } /// Returns the most recent CUDA stream that was set using this stream guard, - /// either from construction, or via reset_stream, if the guard is initialized, - /// or nullopt if the guard is uninitialized. + /// either from construction, or via reset_stream, if the guard is + /// initialized, or nullopt if the guard is uninitialized. optional current_stream() const { auto r = guard_.current_stream(); if (r.has_value()) { @@ -225,17 +263,20 @@ struct OptionalCUDAStreamGuard { } } - /// Restore the original CUDA device and stream, resetting this guard to uninitialized state. - void reset() { guard_.reset(); } + /// Restore the original CUDA device and stream, resetting this guard to + /// uninitialized state. + void reset() { + guard_.reset(); + } -private: + private: c10::impl::InlineOptionalStreamGuard guard_; }; /// A variant of MultiStreamGuard that is specialized for CUDA. 
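OptionalCUDAStreamGuard, whose optional<Stream> constructor appears above, is the variant to reach for when a stream override may or may not be supplied. A small hedged sketch:

#include <c10/core/Stream.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/Optional.h>

void runWithOptionalStream(c10::optional<c10::Stream> maybe_stream) {
  // If a stream is supplied, its device and stream become current for this
  // scope; if nullopt, the guard stays uninitialized and nothing changes.
  c10::cuda::OptionalCUDAStreamGuard guard(maybe_stream);
  // ... enqueue work on whatever stream is now current ...
}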
struct CUDAMultiStreamGuard { explicit CUDAMultiStreamGuard(ArrayRef streams) - : guard_(unwrapStreams(streams)) {} + : guard_(unwrapStreams(streams)) {} /// Copy is disallowed CUDAMultiStreamGuard(const CUDAMultiStreamGuard&) = delete; @@ -247,7 +288,7 @@ struct CUDAMultiStreamGuard { // See Note [Move assignment for RAII guards is tricky] CUDAMultiStreamGuard& operator=(CUDAMultiStreamGuard&& other) = delete; -private: + private: c10::impl::InlineMultiStreamGuard guard_; static std::vector unwrapStreams(ArrayRef cudaStreams) { diff --git a/c10/cuda/CUDAStream.cpp b/c10/cuda/CUDAStream.cpp index a3edfabae9f..c32580bf799 100644 --- a/c10/cuda/CUDAStream.cpp +++ b/c10/cuda/CUDAStream.cpp @@ -1,6 +1,6 @@ -#include #include #include +#include #include #include @@ -204,7 +204,7 @@ static void initGlobalStreamState() { "). Increase that and recompile."); // Initializes default streams - for (const auto i: c10::irange(num_gpus)) { + for (const auto i : c10::irange(num_gpus)) { default_streams[i].device_index = i; low_priority_counters[i] = 0; high_priority_counters[i] = 0; @@ -218,7 +218,7 @@ static void initDeviceStreamState(DeviceIndex device_index) { // with it. CUDAGuard device_guard{device_index}; - for (const auto i: c10::irange(kStreamsPerPool)) { + for (const auto i : c10::irange(kStreamsPerPool)) { auto& lowpri_stream = low_priority_streams[device_index][i]; auto& hipri_stream = high_priority_streams[device_index][i]; @@ -244,7 +244,7 @@ static void initCUDAStreamsOnce() { // Inits current streams (thread local) to default streams current_streams = (LeakyStreamInternals**)malloc(num_gpus * sizeof(LeakyStreamInternals*)); - for (const auto i: c10::irange(num_gpus)) { + for (const auto i : c10::irange(num_gpus)) { current_streams[i] = &default_streams[i]; } } diff --git a/c10/cuda/CUDAStream.h b/c10/cuda/CUDAStream.h index 05eddf5ce12..2e00ecc4a02 100644 --- a/c10/cuda/CUDAStream.h +++ b/c10/cuda/CUDAStream.h @@ -5,53 +5,53 @@ #include +#include +#include #include #include -#include #include -#include /* -* Stream pool note. -* -* A CUDAStream is an abstraction of an actual cuStream on the GPU. CUDAStreams -* are backed by cuStreams, but they use several pools to minimize the costs -* associated with creating, retaining, and destroying cuStreams. -* -* There are three pools per device, and a device's pools are lazily created. -* -* The first pool contains only the default stream. When the default stream -* is requested it's returned. -* -* The second pool is the "low priority" or "default priority" streams. In -* HIP builds there is no distinction between streams in this pool and streams -* in the third pool (below). There are 32 of these streams per device, and -* when a stream is requested one of these streams is returned round-robin. -* That is, the first stream requested is at index 0, the second at index 1... -* to index 31, then index 0 again. -* -* This means that if 33 low priority streams are requested, the first and -* last streams requested are actually the same stream (under the covers) -* and kernels enqueued on them cannot run concurrently. -* -* The third pool is the "high priority" streams. The third pool acts like -* the second pool except the streams are created with a higher priority. -* -* These pools suggest that stream users should prefer many short-lived streams, -* as the cost of acquiring and releasing streams is effectively zero. 
If -* many longer-lived streams are required in performance critical scenarios -* then the functionality here may need to be extended to allow, for example, -* "reserving" a subset of the pool so that other streams do not accidentally -* overlap the performance critical streams. -* -* Note: although the notion of "current stream for device" is thread local -* (every OS thread has a separate current stream, as one might expect), -* the stream pool is global across all threads; stream 0 is always stream 0 -* no matter which thread you use it on. Multiple threads can synchronize -* on the same stream. Although the CUDA documentation is not very clear -* on the matter, streams are thread safe; e.g., it is safe to enqueue -* a kernel on the same stream from two different threads. -*/ + * Stream pool note. + * + * A CUDAStream is an abstraction of an actual cuStream on the GPU. CUDAStreams + * are backed by cuStreams, but they use several pools to minimize the costs + * associated with creating, retaining, and destroying cuStreams. + * + * There are three pools per device, and a device's pools are lazily created. + * + * The first pool contains only the default stream. When the default stream + * is requested it's returned. + * + * The second pool is the "low priority" or "default priority" streams. In + * HIP builds there is no distinction between streams in this pool and streams + * in the third pool (below). There are 32 of these streams per device, and + * when a stream is requested one of these streams is returned round-robin. + * That is, the first stream requested is at index 0, the second at index 1... + * to index 31, then index 0 again. + * + * This means that if 33 low priority streams are requested, the first and + * last streams requested are actually the same stream (under the covers) + * and kernels enqueued on them cannot run concurrently. + * + * The third pool is the "high priority" streams. The third pool acts like + * the second pool except the streams are created with a higher priority. + * + * These pools suggest that stream users should prefer many short-lived streams, + * as the cost of acquiring and releasing streams is effectively zero. If + * many longer-lived streams are required in performance critical scenarios + * then the functionality here may need to be extended to allow, for example, + * "reserving" a subset of the pool so that other streams do not accidentally + * overlap the performance critical streams. + * + * Note: although the notion of "current stream for device" is thread local + * (every OS thread has a separate current stream, as one might expect), + * the stream pool is global across all threads; stream 0 is always stream 0 + * no matter which thread you use it on. Multiple threads can synchronize + * on the same stream. Although the CUDA documentation is not very clear + * on the matter, streams are thread safe; e.g., it is safe to enqueue + * a kernel on the same stream from two different threads. + */ namespace c10 { namespace cuda { @@ -61,8 +61,7 @@ namespace cuda { // functionality (conversion to cudaStream_t), and a guarantee that // the wrapped c10::Stream really is a CUDA stream. class C10_CUDA_API CUDAStream { -public: - + public: enum Unchecked { UNCHECKED }; /// Construct a CUDAStream from a Stream. This construction is checked, @@ -85,21 +84,31 @@ public: } /// Implicit conversion to cudaStream_t. 
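The stream pool note above boils down to a small, cheap, per-device set of reusable streams plus a thread-local notion of "current stream". The hedged sketch below shows typical calls; getStreamFromPool, getCurrentCUDAStream and setCurrentCUDAStream are the accessors used or declared elsewhere in these files, and their exact signatures should be checked against the header.

#include <c10/cuda/CUDAStream.h>

void streamPoolDemo() {
  // Streams come round-robin from a fixed-size per-device pool (32 per the
  // note above), so acquiring one is effectively free; the 33rd request
  // aliases the 1st, and the two cannot run concurrently.
  c10::cuda::CUDAStream a = c10::cuda::getStreamFromPool();
  c10::cuda::CUDAStream b = c10::cuda::getStreamFromPool(/*isHighPriority=*/true);

  // "Current stream" is thread local: making `b` current affects only the
  // calling thread, while the pool itself is shared by all threads.
  c10::cuda::setCurrentCUDAStream(b);
  c10::cuda::CUDAStream cur = c10::cuda::getCurrentCUDAStream();
  (void)a;
  (void)cur; // cur and b now wrap the same underlying cudaStream_t
}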
- operator cudaStream_t() const { return stream(); } + operator cudaStream_t() const { + return stream(); + } /// Implicit conversion to Stream (a.k.a., forget that the stream is a /// CUDA stream). - operator Stream() const { return unwrap(); } + operator Stream() const { + return unwrap(); + } /// Get the CUDA device index that this stream is associated with. - DeviceIndex device_index() const { return stream_.device_index(); } + DeviceIndex device_index() const { + return stream_.device_index(); + } /// Get the full Device that this stream is associated with. The Device /// is guaranteed to be a CUDA device. - Device device() const { return Device(DeviceType::CUDA, device_index()); } + Device device() const { + return Device(DeviceType::CUDA, device_index()); + } /// Return the stream ID corresponding to this particular stream. - StreamId id() const { return stream_.id(); } + StreamId id() const { + return stream_.id(); + } bool query() const { DeviceGuard guard{stream_.device()}; @@ -120,17 +129,19 @@ public: } int priority() const { - DeviceGuard guard{stream_.device()}; - int priority = 0; - C10_CUDA_CHECK(cudaStreamGetPriority(stream(), &priority)); - return priority; + DeviceGuard guard{stream_.device()}; + int priority = 0; + C10_CUDA_CHECK(cudaStreamGetPriority(stream(), &priority)); + return priority; } /// Explicit conversion to cudaStream_t. cudaStream_t stream() const; /// Explicit conversion to Stream. - Stream unwrap() const { return stream_; } + Stream unwrap() const { + return stream_; + } /// Reversibly pack a CUDAStream into a uint64_t representation. This may /// be helpful when storing a CUDAStream in a C struct, where you cannot @@ -150,22 +161,24 @@ public: } static std::tuple priority_range() { - // Note: this returns the range of priority **supported by PyTorch**, not - // the range of priority **supported by CUDA**. The former is a subset of - // the latter. Currently PyTorch only supports 0 and -1, which are "low" and - // "high" priority. - int least_priority, greatest_priority; - C10_CUDA_CHECK( + // Note: this returns the range of priority **supported by PyTorch**, not + // the range of priority **supported by CUDA**. The former is a subset of + // the latter. Currently PyTorch only supports 0 and -1, which are "low" and + // "high" priority. 
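The pack()/unpack() pair mentioned in the comment above is the intended way to smuggle a CUDAStream through a plain integer field, e.g. in a C struct. A hedged round-trip sketch (the static unpack counterpart is assumed from the packed uint64_t representation described above):

#include <cstdint>
#include <c10/cuda/CUDAStream.h>

uint64_t stashStream(c10::cuda::CUDAStream s) {
  // Reversibly encode the stream (device plus id) into a plain integer.
  return s.pack();
}

c10::cuda::CUDAStream restoreStream(uint64_t bits) {
  // Rebuild the original stream handle from the packed representation.
  return c10::cuda::CUDAStream::unpack(bits);
}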
+ int least_priority, greatest_priority; + C10_CUDA_CHECK( cudaDeviceGetStreamPriorityRange(&least_priority, &greatest_priority)); - TORCH_INTERNAL_ASSERT(least_priority >= 0, "Unexpected CUDA stream priority range"); - TORCH_INTERNAL_ASSERT(greatest_priority <= -1, "Unexpected CUDA stream priority range"); - return std::make_tuple(0, -1); + TORCH_INTERNAL_ASSERT( + least_priority >= 0, "Unexpected CUDA stream priority range"); + TORCH_INTERNAL_ASSERT( + greatest_priority <= -1, "Unexpected CUDA stream priority range"); + return std::make_tuple(0, -1); } // Deleted for now; use CUDAEvent::block instead // void synchronize_with(const CUDAEvent& event) const; -private: + private: Stream stream_; }; @@ -214,13 +227,13 @@ TORCH_API void setCurrentCUDAStream(CUDAStream stream); C10_API std::ostream& operator<<(std::ostream& stream, const CUDAStream& s); } // namespace cuda -} // namespace at +} // namespace c10 namespace std { - template <> - struct hash { - size_t operator()(c10::cuda::CUDAStream s) const noexcept { - return std::hash{}(s.unwrap()); - } - }; +template <> +struct hash { + size_t operator()(c10::cuda::CUDAStream s) const noexcept { + return std::hash{}(s.unwrap()); + } +}; } // namespace std diff --git a/c10/cuda/impl/CUDAGuardImpl.cpp b/c10/cuda/impl/CUDAGuardImpl.cpp index b0be6791bd0..6bed0c0c69d 100644 --- a/c10/cuda/impl/CUDAGuardImpl.cpp +++ b/c10/cuda/impl/CUDAGuardImpl.cpp @@ -8,4 +8,6 @@ constexpr DeviceType CUDAGuardImpl::static_type; C10_REGISTER_GUARD_IMPL(CUDA, CUDAGuardImpl); -}}} // namespace c10::cuda::detail +} // namespace impl +} // namespace cuda +} // namespace c10 diff --git a/c10/cuda/impl/CUDAGuardImpl.h b/c10/cuda/impl/CUDAGuardImpl.h index 02a8596c7ba..e938fb0101f 100644 --- a/c10/cuda/impl/CUDAGuardImpl.h +++ b/c10/cuda/impl/CUDAGuardImpl.h @@ -4,10 +4,10 @@ #include #include -#include -#include -#include #include +#include +#include +#include #include @@ -82,9 +82,7 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface { } // Event-related functions - void createEvent( - cudaEvent_t* cuda_event, - const EventFlag flag) const { + void createEvent(cudaEvent_t* cuda_event, const EventFlag flag) const { // Maps PyTorch's Event::Flag to CUDA flag auto cuda_flag = cudaEventDefault; switch (flag) { @@ -103,10 +101,10 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface { C10_CUDA_CHECK(cudaEventCreateWithFlags(cuda_event, cuda_flag)); } - void destroyEvent( - void* event, - const DeviceIndex device_index) const noexcept override { - if (!event) return; + void destroyEvent(void* event, const DeviceIndex device_index) + const noexcept override { + if (!event) + return; auto cuda_event = static_cast(event); int orig_device; C10_CUDA_CHECK_WARN(cudaGetDevice(&orig_device)); @@ -116,16 +114,17 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface { } void record( - void** event, - const Stream& stream, - const DeviceIndex device_index, - const EventFlag flag) const override { - TORCH_CHECK(device_index == -1 || device_index == stream.device_index(), - "Event device index ", - device_index, - " does not match recording stream's device index ", - stream.device_index(), - "."); + void** event, + const Stream& stream, + const DeviceIndex device_index, + const EventFlag flag) const override { + TORCH_CHECK( + device_index == -1 || device_index == stream.device_index(), + "Event device index ", + device_index, + " does not match recording stream's device index ", + stream.device_index(), + "."); 
cudaEvent_t cuda_event = static_cast(*event); CUDAStream cuda_stream{stream}; @@ -135,7 +134,8 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface { setDevice(stream.device()); // Creates the event (lazily) - if (!cuda_event) createEvent(&cuda_event, flag); + if (!cuda_event) + createEvent(&cuda_event, flag); C10_CUDA_CHECK(cudaEventRecord(cuda_event, cuda_stream)); // Makes the void* point to the (possibly just allocated) CUDA event *event = cuda_event; @@ -144,24 +144,24 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface { setDevice(orig_device); } - void block( - void* event, - const Stream& stream) const override { - if (!event) return; + void block(void* event, const Stream& stream) const override { + if (!event) + return; cudaEvent_t cuda_event = static_cast(event); CUDAStream cuda_stream{stream}; const auto orig_device = getDevice(); setDevice(stream.device()); C10_CUDA_CHECK(cudaStreamWaitEvent( - cuda_stream, - cuda_event, - /*flags (must be zero)=*/ 0)); + cuda_stream, + cuda_event, + /*flags (must be zero)=*/0)); setDevice(orig_device); } // May be called from any device bool queryEvent(void* event) const override { - if (!event) return true; + if (!event) + return true; cudaEvent_t cuda_event = static_cast(event); const cudaError_t err = cudaEventQuery(cuda_event); if (err != cudaErrorNotReady) { @@ -170,12 +170,13 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface { return (err == cudaSuccess); } - void recordDataPtrOnStream( - const c10::DataPtr& data_ptr, - const Stream& stream) const override { + void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream) + const override { CUDAStream cuda_stream{stream}; CUDACachingAllocator::recordStream(data_ptr, cuda_stream); } }; -}}} // namespace c10::cuda::impl +} // namespace impl +} // namespace cuda +} // namespace c10 diff --git a/c10/cuda/impl/CUDATest.cpp b/c10/cuda/impl/CUDATest.cpp index 3746d14ae51..fb58d1c3a0f 100644 --- a/c10/cuda/impl/CUDATest.cpp +++ b/c10/cuda/impl/CUDATest.cpp @@ -29,4 +29,6 @@ int c10_cuda_private_test() { return 2; } -}}} // namespace c10::cuda::impl +} // namespace impl +} // namespace cuda +} // namespace c10 diff --git a/c10/cuda/impl/CUDATest.h b/c10/cuda/impl/CUDATest.h index ccfe38e0209..593905d1567 100644 --- a/c10/cuda/impl/CUDATest.h +++ b/c10/cuda/impl/CUDATest.h @@ -8,4 +8,6 @@ namespace impl { C10_CUDA_API int c10_cuda_test(); -}}} /// namespace c10::cuda::impl +} +} // namespace cuda +} // namespace c10 diff --git a/c10/macros/Export.h b/c10/macros/Export.h index 7319081ef34..b439e74b37e 100644 --- a/c10/macros/Export.h +++ b/c10/macros/Export.h @@ -102,17 +102,17 @@ // two pieces with confusing names? // Once upon a time, there _was_ only TORCH_CUDA_API. All was happy until we // tried to compile PyTorch for CUDA 11.1, which ran into relocation marker -// issues when linking big binaries. (https://github.com/pytorch/pytorch/issues/39968) -// We had two choices: +// issues when linking big binaries. +// (https://github.com/pytorch/pytorch/issues/39968) We had two choices: // (1) Stop supporting so many GPU architectures // (2) Do something else // We chose #2 and decided to split the behemoth that was torch_cuda into two // smaller libraries, one with most of the core kernel functions (torch_cuda_cu) -// and the other that had..well..everything else (torch_cuda_cpp). 
The idea was this: -// instead of linking our static libraries (like the hefty libcudnn_static.a) with -// another huge library, torch_cuda, and run into pesky relocation marker issues, -// we could link our static libraries to a smaller part of torch_cuda (torch_cuda_cpp) -// and avoid the issues. +// and the other that had..well..everything else (torch_cuda_cpp). The idea was +// this: instead of linking our static libraries (like the hefty +// libcudnn_static.a) with another huge library, torch_cuda, and run into pesky +// relocation marker issues, we could link our static libraries to a smaller +// part of torch_cuda (torch_cuda_cpp) and avoid the issues. // libtorch_cuda_cu.so #ifdef TORCH_CUDA_CU_BUILD_MAIN_LIB @@ -128,7 +128,8 @@ #define TORCH_CUDA_CPP_API C10_IMPORT #endif -// libtorch_cuda.so (where torch_cuda_cu and torch_cuda_cpp are a part of the same api) +// libtorch_cuda.so (where torch_cuda_cu and torch_cuda_cpp are a part of the +// same api) #ifdef TORCH_CUDA_BUILD_MAIN_LIB #define TORCH_CUDA_CPP_API C10_EXPORT #define TORCH_CUDA_CU_API C10_EXPORT diff --git a/c10/macros/Macros.h b/c10/macros/Macros.h index 7c1ba643f81..4d869b6aa6b 100644 --- a/c10/macros/Macros.h +++ b/c10/macros/Macros.h @@ -24,16 +24,17 @@ #include #if defined(__clang__) - #define __ubsan_ignore_float_divide_by_zero__ __attribute__((no_sanitize("float-divide-by-zero"))) - #define __ubsan_ignore_undefined__ __attribute__((no_sanitize("undefined"))) - #define __ubsan_ignore_signed_int_overflow__ __attribute__((no_sanitize("signed-integer-overflow"))) +#define __ubsan_ignore_float_divide_by_zero__ \ + __attribute__((no_sanitize("float-divide-by-zero"))) +#define __ubsan_ignore_undefined__ __attribute__((no_sanitize("undefined"))) +#define __ubsan_ignore_signed_int_overflow__ \ + __attribute__((no_sanitize("signed-integer-overflow"))) #else - #define __ubsan_ignore_float_divide_by_zero__ - #define __ubsan_ignore_undefined__ - #define __ubsan_ignore_signed_int_overflow__ +#define __ubsan_ignore_float_divide_by_zero__ +#define __ubsan_ignore_undefined__ +#define __ubsan_ignore_signed_int_overflow__ #endif - // Detect address sanitizer as some stuff doesn't work with it #undef C10_ASAN_ENABLED @@ -57,7 +58,6 @@ #define C10_ASAN_ENABLED 0 #endif - // Disable the copy and assignment operator for a class. Note that this will // disable the usage of the class in std containers. #define C10_DISABLE_COPY_AND_ASSIGN(classname) \ @@ -84,7 +84,6 @@ #define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __LINE__) #endif - /// C10_NODISCARD - Warn if a type or return value is discarded. // Technically, we should check if __cplusplus > 201402L here, because @@ -108,24 +107,25 @@ // - gcc 8.3: https://godbolt.org/z/4tLMQS (always advertises support) #define C10_NODISCARD #if defined(__has_cpp_attribute) -# if __has_cpp_attribute(nodiscard) -# undef C10_NODISCARD -# define C10_NODISCARD [[nodiscard]] -# endif +#if __has_cpp_attribute(nodiscard) +#undef C10_NODISCARD +#define C10_NODISCARD [[nodiscard]] +#endif // Workaround for llvm.org/PR23435, since clang 3.6 and below emit a spurious // error when __has_cpp_attribute is given a scoped attribute in C mode. #elif __cplusplus && defined(__has_cpp_attribute) -# if __has_cpp_attribute(clang::warn_unused_result) -// TODO: It's possible this is still triggering https://github.com/pytorch/pytorch/issues/13118 -// on Windows; if it is, better fix it. 
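// ---------------------------------------------------------------------------
// Editorial aside, illustrative only -- not part of this patch. What the
// C10_NODISCARD macro buys in practice; try_reserve is a made-up declaration
// used only for illustration.
#include <cstddef>

#include <c10/macros/Macros.h>

C10_NODISCARD bool try_reserve(std::size_t bytes); // hypothetical

void nodiscard_demo() {
  try_reserve(64);            // compilers that honor [[nodiscard]] warn here
  bool ok = try_reserve(64);  // fine: the result is consumed
  (void)ok;
}
// ---------------------------------------------------------------------------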
-# undef C10_NODISCARD -# define C10_NODISCARD [[clang::warn_unused_result]] -# endif +#if __has_cpp_attribute(clang::warn_unused_result) +// TODO: It's possible this is still triggering +// https://github.com/pytorch/pytorch/issues/13118 on Windows; if it is, better +// fix it. +#undef C10_NODISCARD +#define C10_NODISCARD [[clang::warn_unused_result]] +#endif #endif // suppress an unused variable. #if defined(_MSC_VER) && !defined(__clang__) -#define C10_UNUSED __pragma(warning(suppress: 4100 4101)) +#define C10_UNUSED __pragma(warning(suppress : 4100 4101)) #else #define C10_UNUSED __attribute__((__unused__)) #endif //_MSC_VER @@ -135,16 +135,28 @@ // Simply define the namespace, in case a dependent library want to refer to // the c10 namespace but not any nontrivial files. namespace c10 {} // namespace c10 -namespace c10 { namespace cuda {} } -namespace c10 { namespace hip {} } +namespace c10 { +namespace cuda {} +} // namespace c10 +namespace c10 { +namespace hip {} +} // namespace c10 // Since C10 is the core library for caffe2 (and aten), we will simply reroute // all abstractions defined in c10 to be available in caffe2 as well. // This is only for backwards compatibility. Please use the symbols from the // c10 namespace where possible. -namespace caffe2 { using namespace c10; } -namespace at { using namespace c10; } -namespace at { namespace cuda { using namespace c10::cuda; }} +namespace caffe2 { +using namespace c10; +} +namespace at { +using namespace c10; +} +namespace at { +namespace cuda { +using namespace c10::cuda; +} +} // namespace at // WARNING!!! THIS IS A GIANT HACK!!! // This line means you cannot simultaneously include c10/hip @@ -154,7 +166,11 @@ namespace at { namespace cuda { using namespace c10::cuda; }} // from at::cuda. This namespace makes that happen. When // HIPIFY is no longer out-of-place, we can switch the cuda // here to hip and everyone is happy. -namespace at { namespace cuda { using namespace c10::hip; }} +namespace at { +namespace cuda { +using namespace c10::hip; +} +} // namespace at // C10_LIKELY/C10_UNLIKELY // @@ -169,11 +185,11 @@ namespace at { namespace cuda { using namespace c10::hip; }} // without it. // #if defined(__GNUC__) || defined(__ICL) || defined(__clang__) -#define C10_LIKELY(expr) (__builtin_expect(static_cast(expr), 1)) -#define C10_UNLIKELY(expr) (__builtin_expect(static_cast(expr), 0)) +#define C10_LIKELY(expr) (__builtin_expect(static_cast(expr), 1)) +#define C10_UNLIKELY(expr) (__builtin_expect(static_cast(expr), 0)) #else -#define C10_LIKELY(expr) (expr) -#define C10_UNLIKELY(expr) (expr) +#define C10_LIKELY(expr) (expr) +#define C10_UNLIKELY(expr) (expr) #endif /// C10_NOINLINE - Functions whose declaration is annotated with this will not @@ -209,11 +225,13 @@ namespace at { namespace cuda { using namespace c10::hip; }} #define C10_HOST_DEVICE __host__ __device__ #define C10_DEVICE __device__ #define C10_HOST __host__ -// constants from (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications) -// The maximum number of threads per multiprocessor is 1024 for Turing architecture (7.5), -// 1536 for Geforce Ampere (8.6), -// and 2048 for all other architectures. You'll get warnings if you exceed these constants. -// Hence, the following macros adjust the input values from the user to resolve potential warnings. 
+// constants from +// (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications) +// The maximum number of threads per multiprocessor is 1024 for Turing +// architecture (7.5), 1536 for Geforce Ampere (8.6), and 2048 for all other +// architectures. You'll get warnings if you exceed these constants. Hence, the +// following macros adjust the input values from the user to resolve potential +// warnings. #if __CUDA_ARCH__ == 750 constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1024; #elif __CUDA_ARCH__ == 860 @@ -223,25 +241,39 @@ constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 2048; #endif // CUDA_MAX_THREADS_PER_BLOCK is same for all architectures currently constexpr uint32_t CUDA_MAX_THREADS_PER_BLOCK = 1024; -// CUDA_THREADS_PER_BLOCK_FALLBACK is the "canonical fallback" choice of block size. -// 256 is a good number for this fallback and should give good occupancy and -// versatility across all architectures. +// CUDA_THREADS_PER_BLOCK_FALLBACK is the "canonical fallback" choice of block +// size. 256 is a good number for this fallback and should give good occupancy +// and versatility across all architectures. constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256; // NOTE: if you are thinking of constexpr-ify the inputs to launch bounds, it // turns out that although __launch_bounds__ can take constexpr, it // can't take a constexpr that has anything to do with templates. // Currently we use launch_bounds that depend on template arguments in -// Loops.cuh, Reduce.cuh and LossCTC.cuh. Hence, C10_MAX_THREADS_PER_BLOCK and -// C10_MIN_BLOCKS_PER_SM are kept as macros. -// Suppose you were planning to write __launch_bounds__(a, b), based on your performance tuning on a modern GPU. -// Instead, you should write __launch_bounds__(C10_MAX_THREADS_PER_BLOCK(a), C10_MIN_BLOCKS_PER_SM(a, b)), +// Loops.cuh, Reduce.cuh and LossCTC.cuh. Hence, C10_MAX_THREADS_PER_BLOCK +// and C10_MIN_BLOCKS_PER_SM are kept as macros. +// Suppose you were planning to write __launch_bounds__(a, b), based on your +// performance tuning on a modern GPU. Instead, you should write +// __launch_bounds__(C10_MAX_THREADS_PER_BLOCK(a), C10_MIN_BLOCKS_PER_SM(a, b)), // which will also properly respect limits on old architectures. -#define C10_MAX_THREADS_PER_BLOCK(val) (((val) <= CUDA_MAX_THREADS_PER_BLOCK) ? (val) : CUDA_THREADS_PER_BLOCK_FALLBACK) -#define C10_MIN_BLOCKS_PER_SM(threads_per_block, blocks_per_sm) ((((threads_per_block)*(blocks_per_sm) <= CUDA_MAX_THREADS_PER_SM) ? (blocks_per_sm) : ((CUDA_MAX_THREADS_PER_SM + (threads_per_block) - 1) / (threads_per_block)))) +#define C10_MAX_THREADS_PER_BLOCK(val) \ + (((val) <= CUDA_MAX_THREADS_PER_BLOCK) ? (val) \ + : CUDA_THREADS_PER_BLOCK_FALLBACK) +#define C10_MIN_BLOCKS_PER_SM(threads_per_block, blocks_per_sm) \ + ((((threads_per_block) * (blocks_per_sm) <= CUDA_MAX_THREADS_PER_SM) \ + ? (blocks_per_sm) \ + : ((CUDA_MAX_THREADS_PER_SM + (threads_per_block)-1) / \ + (threads_per_block)))) // C10_LAUNCH_BOUNDS is analogous to __launch_bounds__ -#define C10_LAUNCH_BOUNDS_0 __launch_bounds__(256, 4) // default launch bounds that should give good occupancy and versatility across all architectures. 
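// ---------------------------------------------------------------------------
// Editorial aside, illustrative only -- not part of this patch. The pattern
// recommended above, for a kernel tuned to 512 threads per block and 4 blocks
// per SM: on Turing (CUDA_MAX_THREADS_PER_SM == 1024) the second bound quietly
// degrades to (1024 + 511) / 512 = 2 instead of producing a warning.
#include <cstdint>

C10_LAUNCH_BOUNDS_2(512, 4)
__global__ void fill_kernel(float* out, float value, int64_t n) {
  const int64_t i =
      blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i < n) {
    out[i] = value;
  }
}
// ---------------------------------------------------------------------------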
-#define C10_LAUNCH_BOUNDS_1(max_threads_per_block) __launch_bounds__((C10_MAX_THREADS_PER_BLOCK((max_threads_per_block)))) -#define C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm) __launch_bounds__((C10_MAX_THREADS_PER_BLOCK((max_threads_per_block))), (C10_MIN_BLOCKS_PER_SM((max_threads_per_block), (min_blocks_per_sm)))) +#define C10_LAUNCH_BOUNDS_0 \ + __launch_bounds__( \ + 256, 4) // default launch bounds that should give good occupancy and + // versatility across all architectures. +#define C10_LAUNCH_BOUNDS_1(max_threads_per_block) \ + __launch_bounds__((C10_MAX_THREADS_PER_BLOCK((max_threads_per_block)))) +#define C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm) \ + __launch_bounds__( \ + (C10_MAX_THREADS_PER_BLOCK((max_threads_per_block))), \ + (C10_MIN_BLOCKS_PER_SM((max_threads_per_block), (min_blocks_per_sm)))) #else #define C10_HOST_DEVICE #define C10_HOST @@ -273,14 +305,12 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256; #elif defined(_MSC_VER) #if defined(NDEBUG) extern "C" { - C10_IMPORT +C10_IMPORT #if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__) || defined(__HIP__) - __host__ __device__ +__host__ __device__ #endif // __CUDA_ARCH__ - void _wassert( - wchar_t const* _Message, - wchar_t const* _File, - unsigned _Line); + void + _wassert(wchar_t const* _Message, wchar_t const* _File, unsigned _Line); } #endif #define CUDA_KERNEL_ASSERT(cond) \ @@ -304,8 +334,8 @@ __host__ __device__ #endif // NDEBUG #define CUDA_KERNEL_ASSERT(cond) \ if (C10_UNLIKELY(!(cond))) { \ - __assert_fail(#cond, __FILE__, static_cast(__LINE__), \ - __func__); \ + __assert_fail( \ + #cond, __FILE__, static_cast(__LINE__), __func__); \ } #endif // __APPLE__ diff --git a/c10/mobile/CPUCachingAllocator.h b/c10/mobile/CPUCachingAllocator.h index c80fee0682e..d5f4303e1b3 100644 --- a/c10/mobile/CPUCachingAllocator.h +++ b/c10/mobile/CPUCachingAllocator.h @@ -19,16 +19,15 @@ * See below for more information. * Why? * It has been observed that some mobile platforms, such as pixel 3, return - * memory aggressively to the system. This results in page faults in some cases - * and ends up hurting performance. This caching allocator aims to address that. - * Furthermore it also allows users to specify their own allocator by implementing - * allocate/free virtual interfaces. - * What are the cons? - * There are some cons that were observed where use of caching allocator led to - * worse performance on some platforms. Reason being that the caching mechanism - * used by this allocator left us worse off compared to the corresponding platform's - * tuned memory allocator. In that case it seemed better to not use this allocator. - * Note there are some ideas to fix this in the works. + * memory aggressively to the system. This results in page faults in some + * cases and ends up hurting performance. This caching allocator aims to address + * that. Furthermore it also allows users to specify their own allocator by + * implementing allocate/free virtual interfaces. What are the cons? There are + * some cons that were observed where use of caching allocator led to worse + * performance on some platforms. Reason being that the caching mechanism used + * by this allocator left us worse off compared to the corresponding platform's + * tuned memory allocator. In that case it seemed better to not use this + * allocator. Note there are some ideas to fix this in the works. 
* * Usage: * Usage pattern: @@ -53,42 +52,44 @@ class C10_API CPUCachingAllocator { * What it does not do: * No speculative allocation for any future allocations. */ - private: - inline void* allocate_and_cache(const size_t bytes); - void free_cached(); - protected: - // Invariants. - // 1. If memory is ever allocated via this allocator then - // the pointer will exist in allocation_map_, unless the allocator - // returned the memory to OS via free_cached. - // 1.1. Therefore even when the said memory is "freed" via this - // allocator (and thus cached), it will continue to stay - // in allocation_map_. Furthermore it will also exist in - // available_map_. Thus an allocated memory pointer can be in both - // allocation_map_ and available_map_ simultaneously. - // 2. Memory pointer maybe removed from allocation_map_, when it - // is freed outside of the scope of this allocator, but was allocated - // by this allocator. - // 3. Available map only contains that memory which was allocated - // by this allocator and subsequently freed by this allocator. - // As a result of above invariants, allocated memory ptr cannot be in - // available_map_ unless it is in allocation_map_ as well. - ska::flat_hash_map> available_map_; - static ska::flat_hash_map allocation_map_; - // Since allocation_map, which is a global instance, is mutated/read via - // all public APIs we need a global mutex. - static std::mutex mutex_; - public: - static void record_free(void* ptr); - virtual ~CPUCachingAllocator(); - // Checks the cache to see if allocation of size bytes can be found. - // If so return cached memory, else - // allocates memory, records it for caching and returns. - virtual void* allocate(const size_t bytes); - // Checks if the memory being freed is was marked for allocation by - // an earlier call to allocate. If so cache the allocation. - // Otherwise free. - virtual void free(void* ptr); + private: + inline void* allocate_and_cache(const size_t bytes); + void free_cached(); + + protected: + // Invariants. + // 1. If memory is ever allocated via this allocator then + // the pointer will exist in allocation_map_, unless the allocator + // returned the memory to OS via free_cached. + // 1.1. Therefore even when the said memory is "freed" via this + // allocator (and thus cached), it will continue to stay + // in allocation_map_. Furthermore it will also exist in + // available_map_. Thus an allocated memory pointer can be in both + // allocation_map_ and available_map_ simultaneously. + // 2. Memory pointer maybe removed from allocation_map_, when it + // is freed outside of the scope of this allocator, but was allocated + // by this allocator. + // 3. Available map only contains that memory which was allocated + // by this allocator and subsequently freed by this allocator. + // As a result of above invariants, allocated memory ptr cannot be in + // available_map_ unless it is in allocation_map_ as well. + ska::flat_hash_map> available_map_; + static ska::flat_hash_map allocation_map_; + // Since allocation_map, which is a global instance, is mutated/read via + // all public APIs we need a global mutex. + static std::mutex mutex_; + + public: + static void record_free(void* ptr); + virtual ~CPUCachingAllocator(); + // Checks the cache to see if allocation of size bytes can be found. + // If so return cached memory, else + // allocates memory, records it for caching and returns. 
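// ---------------------------------------------------------------------------
// Editorial aside, illustrative only -- not part of this patch. A plausible
// usage pattern, assuming WithCPUCachingAllocatorGuard routes CPU allocations
// made in its scope through the given allocator; run_inference is a
// hypothetical fixed workload.
#include <c10/mobile/CPUCachingAllocator.h>

void run_inference(); // hypothetical

void inference_loop(int iterations) {
  c10::CPUCachingAllocator caching_allocator;
  for (int i = 0; i < iterations; ++i) {
    c10::WithCPUCachingAllocatorGuard guard(&caching_allocator);
    // Blocks freed on earlier iterations can be handed back from the cache
    // instead of being re-requested from the system allocator.
    run_inference();
  }
}
// ---------------------------------------------------------------------------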
+ virtual void* allocate(const size_t bytes); + // Checks if the memory being freed is was marked for allocation by + // an earlier call to allocate. If so cache the allocation. + // Otherwise free. + virtual void free(void* ptr); }; CPUCachingAllocator* GetDefaultCPUCachingAllocator(); @@ -97,11 +98,12 @@ bool ThreadLocalCachingAllocatorEnabled(); CPUCachingAllocator* GetThreadLocalCachingAllocator(); class C10_API WithCPUCachingAllocatorGuard { - public: - WithCPUCachingAllocatorGuard(CPUCachingAllocator* allocator); - ~WithCPUCachingAllocatorGuard(); - private: - CPUCachingAllocator* prev_caching_allocator_ptr_{nullptr}; + public: + WithCPUCachingAllocatorGuard(CPUCachingAllocator* allocator); + ~WithCPUCachingAllocatorGuard(); + + private: + CPUCachingAllocator* prev_caching_allocator_ptr_{nullptr}; }; } // namespace c10 diff --git a/c10/mobile/CPUProfilingAllocator.cpp b/c10/mobile/CPUProfilingAllocator.cpp index 5e8eb7df85a..1fc53860f0d 100644 --- a/c10/mobile/CPUProfilingAllocator.cpp +++ b/c10/mobile/CPUProfilingAllocator.cpp @@ -19,27 +19,23 @@ struct MemBlock { } }; -enum class EventType { - Allocate = 0, - Free, - Invalid -}; +enum class EventType { Allocate = 0, Free, Invalid }; struct MemEvent { uint64_t time; uint64_t allocation_id; uint64_t size; EventType type{EventType::Invalid}; - MemEvent(uint64_t t, uint64_t id, uint64_t s, EventType e) : - time(t), allocation_id(id), size(s), type(e) {} + MemEvent(uint64_t t, uint64_t id, uint64_t s, EventType e) + : time(t), allocation_id(id), size(s), type(e) {} }; bool overlaps(const MemBlock& a, const MemBlock& b) { // two blocks dont overlap if // |---a--------|--------------b--------| // strat_a end_a <= start_b end_b - return - !((a.end_offset <= b.start_offset) || (b.end_offset <= a.start_offset)); + return !( + (a.end_offset <= b.start_offset) || (b.end_offset <= a.start_offset)); } bool validate_allocation_plan( @@ -72,7 +68,8 @@ bool validate_allocation_plan( allocations.emplace(mem_block); } else if (event.type == EventType::Free) { auto it = allocations.find(mem_block); - TORCH_CHECK((*it).end_offset == end_offset, + TORCH_CHECK( + (*it).end_offset == end_offset, "Enf offset of allocation being freed must match the one recorded."); TORCH_CHECK( it != allocations.end(), @@ -98,13 +95,15 @@ std::vector create_and_sort_mem_events( continue; } events.emplace_back(i, i, allocation_sizes[i], EventType::Allocate); - events.emplace_back(allocation_lifetimes[i], i, allocation_sizes[i], EventType::Free); + events.emplace_back( + allocation_lifetimes[i], i, allocation_sizes[i], EventType::Free); } std::sort( events.begin(), events.end(), - [](const MemEvent& a, - const MemEvent& b) -> bool {return a.time < b.time;}); + [](const MemEvent& a, const MemEvent& b) -> bool { + return a.time < b.time; + }); return events; } @@ -132,8 +131,10 @@ std::vector formulate_greedy_allocation_plan( std::map free_size_to_offset; // This provides fast lookup when we want to insert freed block // back, especially when we want to merge blocks. 
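// ---------------------------------------------------------------------------
// Editorial aside, illustrative only -- not part of this patch. The two
// offset-keyed maps above exist so that a block being freed can find adjacent
// free blocks cheaply and be coalesced with them instead of fragmenting the
// blob. For example, with free blocks [0, 256) and [512, 768), freeing
// [256, 512) locates the left neighbour by its end offset (256) and the right
// neighbour by its start offset (512), allowing a single merged free block
// [0, 768) keyed by its size in free_size_to_offset. The exact merge order is
// an implementation detail of the code that follows.
// ---------------------------------------------------------------------------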
- ska::flat_hash_map::iterator> free_start_offset_to_size_iter; - ska::flat_hash_map::iterator> free_end_offset_to_size_iter; + ska::flat_hash_map::iterator> + free_start_offset_to_size_iter; + ska::flat_hash_map::iterator> + free_end_offset_to_size_iter; // Upon free end_ptr = offset + size // If end_ptr exists merge freed allocation // Also find corresponding offset in size_to_offset @@ -146,7 +147,8 @@ std::vector formulate_greedy_allocation_plan( std::vector allocation_offsets( allocation_sizes.size(), std::numeric_limits::max()); - auto mem_events = create_and_sort_mem_events(allocation_sizes, allocation_lifetimes); + auto mem_events = + create_and_sort_mem_events(allocation_sizes, allocation_lifetimes); uint64_t max_offset{0}; for (const auto& mem_event : mem_events) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) @@ -221,13 +223,14 @@ std::vector formulate_greedy_allocation_plan( free_start_offset_to_size_iter.erase(freed_offset); } auto freed_block_it = - free_size_to_offset.emplace(freed_size, freed_offset).first; + free_size_to_offset.emplace(freed_size, freed_offset).first; free_start_offset_to_size_iter.emplace(freed_offset, freed_block_it); free_end_offset_to_size_iter.emplace( freed_offset + freed_size, freed_block_it); } } - TORCH_CHECK(validate_allocation_plan(mem_events, allocation_offsets), + TORCH_CHECK( + validate_allocation_plan(mem_events, allocation_offsets), "ProfilingAllocator: Allocation plan invalid."); return allocation_offsets; } @@ -241,7 +244,8 @@ void AllocationPlan::clear() { } void AllocationPlanner::record_allocation( - const uint64_t size, const void* ptr) { + const uint64_t size, + const void* ptr) { if (validation_mode_) { validation_success = validation_success && validate_allocation(size, ptr); return; @@ -264,13 +268,15 @@ void AllocationPlanner::record_free(const void* ptr) { return; } auto id = it->second; - TORCH_CHECK(id < allocation_plan_->allocation_lifetimes.size(), + TORCH_CHECK( + id < allocation_plan_->allocation_lifetimes.size(), "Allocation must have been recorded during record_allocation."); allocation_plan_->allocation_lifetimes[id] = allocation_id_; } bool AllocationPlanner::validate_allocation( - const uint64_t size, const void* ptr) { + const uint64_t size, + const void* ptr) { if (allocation_id_ >= allocation_plan_->allocation_sizes.size() || allocation_plan_->allocation_sizes[allocation_id_] != size) { TORCH_WARN( @@ -286,7 +292,7 @@ bool AllocationPlanner::validate_allocation( return false; } - allocation_ptr_to_id_[ptr] = allocation_id_; + allocation_ptr_to_id_[ptr] = allocation_id_; allocation_id_++; return true; } @@ -298,24 +304,27 @@ bool AllocationPlanner::validate_free(const void* ptr) { return true; } auto id = (*it).second; - TORCH_CHECK(id < allocation_plan_->allocation_lifetimes.size(), + TORCH_CHECK( + id < allocation_plan_->allocation_lifetimes.size(), "Allocation must have been recorded during validate_allocation."); auto lifetime_id = allocation_plan_->allocation_lifetimes[id]; return (lifetime_id == allocation_id_); } void AllocationPlanner::formulate_plan() { - allocation_plan_->allocation_offsets = - formulate_greedy_allocation_plan( - allocation_plan_->allocation_sizes, allocation_plan_->allocation_lifetimes); + allocation_plan_->allocation_offsets = formulate_greedy_allocation_plan( + allocation_plan_->allocation_sizes, + allocation_plan_->allocation_lifetimes); allocation_plan_->total_size = 0; for (const auto i : c10::irange(allocation_plan_->allocation_sizes.size())) { if 
(allocation_plan_->allocation_lifetimes[i] == std::numeric_limits::max()) { continue; } - auto limit = allocation_plan_->allocation_offsets[i] + allocation_plan_->allocation_sizes[i]; - allocation_plan_->total_size = std::max(allocation_plan_->total_size, limit); + auto limit = allocation_plan_->allocation_offsets[i] + + allocation_plan_->allocation_sizes[i]; + allocation_plan_->total_size = + std::max(allocation_plan_->total_size, limit); } } @@ -344,7 +353,8 @@ void CPUProfilingAllocator::unset_plan() { } void* CPUProfilingAllocator::allocate(const size_t bytes) { - TORCH_CHECK(bytes == plan_->allocation_sizes[allocation_id_], + TORCH_CHECK( + bytes == plan_->allocation_sizes[allocation_id_], "Got allocation request that does not match with the plan."); if (plan_->allocation_lifetimes[allocation_id_] == std::numeric_limits::max()) { @@ -352,9 +362,8 @@ void* CPUProfilingAllocator::allocate(const size_t bytes) { allocation_id_++; return c10::alloc_cpu(bytes); } - void* ptr = - reinterpret_cast(blob_) + - plan_->allocation_offsets[allocation_id_]; + void* ptr = reinterpret_cast(blob_) + + plan_->allocation_offsets[allocation_id_]; allocation_ptr_to_id_[ptr] = allocation_id_; allocation_id_++; return ptr; @@ -364,8 +373,8 @@ void CPUProfilingAllocator::free(void* const ptr) { auto it = allocation_ptr_to_id_.find(ptr); if (it == allocation_ptr_to_id_.end()) { // Either - // 1. Allocation that was made outside the validation scope is being freed here - // or + // 1. Allocation that was made outside the validation scope is being freed + // here or // 2. Allocation that is not managed by profiling allocator is being freed. // Example of the second type // Tensor out; @@ -380,7 +389,8 @@ void CPUProfilingAllocator::free(void* const ptr) { return; } auto id = it->second; - TORCH_CHECK(id < plan_->allocation_lifetimes.size(), + TORCH_CHECK( + id < plan_->allocation_lifetimes.size(), "Freeing allocation that is not accordingly to the plan."); auto lifetime_id = plan_->allocation_lifetimes[id]; TORCH_CHECK( @@ -397,10 +407,10 @@ CPUProfilingAllocator::~CPUProfilingAllocator() { c10::free_cpu(blob_); } -WithProfileAllocationsGuard::WithProfileAllocationsGuard( - AllocationPlan* plan) { +WithProfileAllocationsGuard::WithProfileAllocationsGuard(AllocationPlan* plan) { // Nesting of allocation profiling does not seem meaningful. - TORCH_CHECK(allocation_planner == nullptr, + TORCH_CHECK( + allocation_planner == nullptr, "Nesting profiling allocations is not supported."); planner_ = std::make_unique(plan); planner_->clear(); @@ -413,9 +423,11 @@ WithProfileAllocationsGuard::~WithProfileAllocationsGuard() { } WithValidateAllocationPlanGuard::WithValidateAllocationPlanGuard( - AllocationPlan* plan, bool* success) { + AllocationPlan* plan, + bool* success) { // Nesting of allocation profiling does not seem meaningful. - TORCH_CHECK(allocation_planner == nullptr, + TORCH_CHECK( + allocation_planner == nullptr, "Nesting profiling allocations is not supported."); planner_ = std::make_unique(plan, true); success_ = success; @@ -432,9 +444,11 @@ AllocationPlanner* GetThreadLocalAllocationPlanner() { } WithProfilingAllocatorGuard::WithProfilingAllocatorGuard( - CPUProfilingAllocator* allocator, const AllocationPlan* plan) { + CPUProfilingAllocator* allocator, + const AllocationPlan* plan) { // Nesting of profiling allocator is not supported. 
- TORCH_CHECK(profiling_allocator == nullptr, + TORCH_CHECK( + profiling_allocator == nullptr, "Nesting profiling allocators is not supported."); profiling_allocator = allocator; profiling_allocator->set_plan(plan); diff --git a/c10/mobile/CPUProfilingAllocator.h b/c10/mobile/CPUProfilingAllocator.h index 4a7e79fe285..5112691a62d 100644 --- a/c10/mobile/CPUProfilingAllocator.h +++ b/c10/mobile/CPUProfilingAllocator.h @@ -16,29 +16,30 @@ namespace c10 { * Given a sequence of allocations in a thread, AllocationPlan records * 1. size of each allocation * 2. Lifetime of each allocation. - * 3. allocation offsets: Memory offset for each allocation in a single blob of memory + * 3. allocation offsets: Memory offset for each allocation in a single blob of + * memory * 4. Total size of a blob of memory required to satisfy all the allocations. */ class C10_API AllocationPlan { - private: - // Records size of each allocation by their sequential allocation ids. - std::vector allocation_sizes; - // This maps one allocation id (X) to another allocation id (Y). - // Allocation X is alive until allocation Y. From allocation Y onwards - // allocation X is not referenced. - // Thus Y is the id of the first allocation after X is freed. - // NB: When an allocation is recorded, along with recording its size, - // we also set the lifetime to be numeric_limits::max() - // This is to track allocations that are made during the scope of - // profiling but were not freed until after the scope ended. - // Such allocations are not managed by profiling allocator. - std::vector allocation_lifetimes; - // Maps an allocation to some offset in a blob of memory. - std::vector allocation_offsets; - uint64_t total_size{0}; - void clear(); - friend class AllocationPlanner; - friend class CPUProfilingAllocator; + private: + // Records size of each allocation by their sequential allocation ids. + std::vector allocation_sizes; + // This maps one allocation id (X) to another allocation id (Y). + // Allocation X is alive until allocation Y. From allocation Y onwards + // allocation X is not referenced. + // Thus Y is the id of the first allocation after X is freed. + // NB: When an allocation is recorded, along with recording its size, + // we also set the lifetime to be numeric_limits::max() + // This is to track allocations that are made during the scope of + // profiling but were not freed until after the scope ended. + // Such allocations are not managed by profiling allocator. + std::vector allocation_lifetimes; + // Maps an allocation to some offset in a blob of memory. + std::vector allocation_offsets; + uint64_t total_size{0}; + void clear(); + friend class AllocationPlanner; + friend class CPUProfilingAllocator; }; /* @@ -46,43 +47,45 @@ class C10_API AllocationPlan { * used to establish lifetime of allocations. */ class C10_API AllocationPlanner { - private: - AllocationPlan* allocation_plan_{nullptr}; - // Maps allocated ptr to its allocation id. - // This is used when freeing the memory to lookup the allocation id - // in order to establish the lifetime of a particular allocation. - ska::flat_hash_map allocation_ptr_to_id_; - uint64_t allocation_id_{0}; - bool validation_mode_{false}; + private: + AllocationPlan* allocation_plan_{nullptr}; + // Maps allocated ptr to its allocation id. + // This is used when freeing the memory to lookup the allocation id + // in order to establish the lifetime of a particular allocation. 
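// ---------------------------------------------------------------------------
// Editorial aside, illustrative only -- not part of this patch. A tiny worked
// example of the AllocationPlan encoding described above (numbers invented):
//   allocation id:          0      1      2
//   allocation_sizes:     256    512    256
//   allocation_lifetimes:   1      3      3   // id 0 is freed before id 1 is
//                                             // ever allocated
//   one valid allocation_offsets assignment:
//                           0      0    512   // ids 0 and 1 may share bytes,
//                                             // since their lifetimes do not
//                                             // overlap
//   total_size = max(offset + size) = 512 + 256 = 768
// ---------------------------------------------------------------------------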
+ ska::flat_hash_map allocation_ptr_to_id_; + uint64_t allocation_id_{0}; + bool validation_mode_{false}; - bool validate_allocation(const uint64_t size, const void* ptr); - bool validate_free(const void* ptr); - public: - bool validation_success{true}; + bool validate_allocation(const uint64_t size, const void* ptr); + bool validate_free(const void* ptr); - AllocationPlanner() = delete; - AllocationPlanner(AllocationPlan* plan, bool validate = false) : - allocation_plan_(plan), validation_mode_(validate) {} - void record_allocation(const uint64_t size, const void* ptr); - void record_free(const void* ptr); - void formulate_plan(); - void clear(); + public: + bool validation_success{true}; + + AllocationPlanner() = delete; + AllocationPlanner(AllocationPlan* plan, bool validate = false) + : allocation_plan_(plan), validation_mode_(validate) {} + void record_allocation(const uint64_t size, const void* ptr); + void record_free(const void* ptr); + void formulate_plan(); + void clear(); }; // NOT THREAD SAFE profiling allocator. class C10_API CPUProfilingAllocator { - private: - const AllocationPlan* plan_{nullptr}; - uint64_t allocation_id_{0}; - uint64_t current_size_{0}; - void* blob_{nullptr}; - ska::flat_hash_map allocation_ptr_to_id_; - public: - ~CPUProfilingAllocator(); - void set_plan(const AllocationPlan* plan); - void unset_plan(); - void* allocate(const size_t bytes); - void free(void* const ptr); + private: + const AllocationPlan* plan_{nullptr}; + uint64_t allocation_id_{0}; + uint64_t current_size_{0}; + void* blob_{nullptr}; + ska::flat_hash_map allocation_ptr_to_id_; + + public: + ~CPUProfilingAllocator(); + void set_plan(const AllocationPlan* plan); + void unset_plan(); + void* allocate(const size_t bytes); + void free(void* const ptr); }; /* @@ -95,11 +98,12 @@ class C10_API CPUProfilingAllocator { * plan now contains allocation plan. */ class C10_API WithProfileAllocationsGuard { - public: - WithProfileAllocationsGuard(AllocationPlan* plan); - ~WithProfileAllocationsGuard(); - private: - std::unique_ptr planner_; + public: + WithProfileAllocationsGuard(AllocationPlan* plan); + ~WithProfileAllocationsGuard(); + + private: + std::unique_ptr planner_; }; /* @@ -115,12 +119,13 @@ class C10_API WithProfileAllocationsGuard { * else for some inputs allocation pattern changed. 
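// ---------------------------------------------------------------------------
// Editorial aside, illustrative only -- not part of this patch. The intended
// three-phase flow for these guards, assuming a workload with a fixed
// allocation pattern; run_model_once is a hypothetical stand-in.
#include <c10/mobile/CPUProfilingAllocator.h>

void run_model_once(); // hypothetical fixed-shape workload

void profiled_inference() {
  c10::AllocationPlan plan;
  {
    // Phase 1: record size and lifetime of every allocation in this scope.
    c10::WithProfileAllocationsGuard profile_guard(&plan);
    run_model_once();
  }
  bool plan_valid = false;
  {
    // Phase 2 (optional): re-run and confirm the pattern still matches.
    c10::WithValidateAllocationPlanGuard validate_guard(&plan, &plan_valid);
    run_model_once();
  }
  if (plan_valid) {
    // Phase 3: serve the recorded allocation sequence from one blob.
    c10::CPUProfilingAllocator profiling_allocator;
    c10::WithProfilingAllocatorGuard allocator_guard(
        &profiling_allocator, &plan);
    run_model_once();
  }
}
// ---------------------------------------------------------------------------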
*/ class C10_API WithValidateAllocationPlanGuard { - public: - WithValidateAllocationPlanGuard(AllocationPlan* plan, bool* success); - ~WithValidateAllocationPlanGuard(); - private: - std::unique_ptr planner_; - bool* success_; + public: + WithValidateAllocationPlanGuard(AllocationPlan* plan, bool* success); + ~WithValidateAllocationPlanGuard(); + + private: + std::unique_ptr planner_; + bool* success_; }; AllocationPlanner* GetThreadLocalAllocationPlanner(); @@ -138,10 +143,11 @@ AllocationPlanner* GetThreadLocalAllocationPlanner(); * } */ class C10_API WithProfilingAllocatorGuard { - public: - WithProfilingAllocatorGuard( - CPUProfilingAllocator* allocator, const AllocationPlan* plan); - ~WithProfilingAllocatorGuard(); + public: + WithProfilingAllocatorGuard( + CPUProfilingAllocator* allocator, + const AllocationPlan* plan); + ~WithProfilingAllocatorGuard(); }; CPUProfilingAllocator* GetThreadLocalProfilingAllocator(); diff --git a/c10/test/core/CompileTimeFunctionPointer_test.cpp b/c10/test/core/CompileTimeFunctionPointer_test.cpp index 64de518cc4f..3caf8cc4915 100644 --- a/c10/test/core/CompileTimeFunctionPointer_test.cpp +++ b/c10/test/core/CompileTimeFunctionPointer_test.cpp @@ -5,63 +5,71 @@ namespace test_is_compile_time_function_pointer { static_assert(!c10::is_compile_time_function_pointer::value, ""); void dummy() {} -static_assert(c10::is_compile_time_function_pointer::value, ""); -} +static_assert( + c10::is_compile_time_function_pointer::value, + ""); +} // namespace test_is_compile_time_function_pointer namespace test_access_through_type { - void dummy() {} - using dummy_ptr = TORCH_FN_TYPE(dummy); - static_assert(c10::is_compile_time_function_pointer::value, ""); - static_assert(dummy_ptr::func_ptr() == &dummy, ""); - static_assert(std::is_same::value, ""); -} +void dummy() {} +using dummy_ptr = TORCH_FN_TYPE(dummy); +static_assert(c10::is_compile_time_function_pointer::value, ""); +static_assert(dummy_ptr::func_ptr() == &dummy, ""); +static_assert(std::is_same::value, ""); +} // namespace test_access_through_type namespace test_access_through_value { - void dummy() {} - constexpr auto dummy_ptr = TORCH_FN(dummy); - static_assert(dummy_ptr.func_ptr() == &dummy, ""); - static_assert(std::is_same::value, ""); -} +void dummy() {} +constexpr auto dummy_ptr = TORCH_FN(dummy); +static_assert(dummy_ptr.func_ptr() == &dummy, ""); +static_assert(std::is_same::value, ""); +} // namespace test_access_through_value namespace test_access_through_type_also_works_if_specified_as_pointer { - void dummy() {} - using dummy_ptr = TORCH_FN_TYPE(&dummy); - static_assert(c10::is_compile_time_function_pointer::value, ""); - static_assert(dummy_ptr::func_ptr() == &dummy, ""); - static_assert(std::is_same::value, ""); -} +void dummy() {} +using dummy_ptr = TORCH_FN_TYPE(&dummy); +static_assert(c10::is_compile_time_function_pointer::value, ""); +static_assert(dummy_ptr::func_ptr() == &dummy, ""); +static_assert(std::is_same::value, ""); +} // namespace test_access_through_type_also_works_if_specified_as_pointer namespace test_access_through_value_also_works_if_specified_as_pointer { - void dummy() {} - constexpr auto dummy_ptr = TORCH_FN(&dummy); - static_assert(dummy_ptr.func_ptr() == &dummy, ""); - static_assert(std::is_same::value, ""); -} +void dummy() {} +constexpr auto dummy_ptr = TORCH_FN(&dummy); +static_assert(dummy_ptr.func_ptr() == &dummy, ""); +static_assert(std::is_same::value, ""); +} // namespace test_access_through_value_also_works_if_specified_as_pointer namespace 
test_run_through_type { - int add(int a, int b) {return a + b;} - using Add = TORCH_FN_TYPE(add); - template struct Executor { - int execute(int a, int b) { - return Func::func_ptr()(a, b); - } - }; - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(CompileTimeFunctionPointerTest, runFunctionThroughType) { - Executor executor; - EXPECT_EQ(3, executor.execute(1, 2)); - } +int add(int a, int b) { + return a + b; } +using Add = TORCH_FN_TYPE(add); +template +struct Executor { + int execute(int a, int b) { + return Func::func_ptr()(a, b); + } +}; + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(CompileTimeFunctionPointerTest, runFunctionThroughType) { + Executor executor; + EXPECT_EQ(3, executor.execute(1, 2)); +} +} // namespace test_run_through_type namespace test_run_through_value { - int add(int a, int b) {return a + b;} - template int execute(Func, int a, int b) { - return Func::func_ptr()(a, b); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(CompileTimeFunctionPointerTest, runFunctionThroughValue) { - EXPECT_EQ(3, execute(TORCH_FN(add), 1, 2)); - } +int add(int a, int b) { + return a + b; } +template +int execute(Func, int a, int b) { + return Func::func_ptr()(a, b); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(CompileTimeFunctionPointerTest, runFunctionThroughValue) { + EXPECT_EQ(3, execute(TORCH_FN(add), 1, 2)); +} +} // namespace test_run_through_value diff --git a/c10/test/core/DispatchKeySet_test.cpp b/c10/test/core/DispatchKeySet_test.cpp index 1d931df033c..57f4f596b1e 100644 --- a/c10/test/core/DispatchKeySet_test.cpp +++ b/c10/test/core/DispatchKeySet_test.cpp @@ -9,7 +9,8 @@ using namespace c10; // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(DispatchKeySet, Empty) { DispatchKeySet empty_set; - for (uint8_t i = 1; i < static_cast(DispatchKey::NumDispatchKeys); i++) { + for (uint8_t i = 1; i < static_cast(DispatchKey::NumDispatchKeys); + i++) { auto tid = static_cast(i); ASSERT_FALSE(empty_set.has(tid)); } @@ -21,7 +22,8 @@ TEST(DispatchKeySet, Empty) { // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(DispatchKeySet, Singleton) { - for (uint8_t i = 1; i < static_cast(DispatchKey::NumDispatchKeys); i++) { + for (uint8_t i = 1; i < static_cast(DispatchKey::NumDispatchKeys); + i++) { auto tid = static_cast(i); DispatchKeySet sing(tid); ASSERT_EQ(sing, sing); @@ -37,8 +39,11 @@ TEST(DispatchKeySet, Singleton) { // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(DispatchKeySet, Doubleton) { - for (uint8_t i = 1; i < static_cast(DispatchKey::NumDispatchKeys); i++) { - for (uint8_t j = i + 1; j < static_cast(DispatchKey::NumDispatchKeys); j++) { + for (uint8_t i = 1; i < static_cast(DispatchKey::NumDispatchKeys); + i++) { + for (uint8_t j = i + 1; + j < static_cast(DispatchKey::NumDispatchKeys); + j++) { ASSERT_LT(i, j); auto tid1 = static_cast(i); auto tid2 = static_cast(j); @@ -46,7 +51,7 @@ TEST(DispatchKeySet, Doubleton) { ASSERT_EQ(doub, DispatchKeySet(tid1) | DispatchKeySet(tid2)); ASSERT_TRUE(doub.has(tid1)); ASSERT_TRUE(doub.has(tid2)); - ASSERT_EQ(doub.highestPriorityTypeId(), tid2); // relies on i < j + ASSERT_EQ(doub.highestPriorityTypeId(), tid2); // relies on i < j } } } @@ -54,7 +59,8 @@ TEST(DispatchKeySet, Doubleton) { // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(DispatchKeySet, Full) { DispatchKeySet full(DispatchKeySet::FULL); - 
for (uint8_t i = 1; i < static_cast(DispatchKey::NumDispatchKeys); i++) { + for (uint8_t i = 1; i < static_cast(DispatchKey::NumDispatchKeys); + i++) { auto tid = static_cast(i); ASSERT_TRUE(full.has(tid)); } @@ -103,13 +109,12 @@ TEST(DispatchKeySet, IteratorFull) { ASSERT_EQ(i, static_cast(DispatchKey::NumDispatchKeys) - 1); } - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(DispatchKeySet, IteratorRangeFull) { DispatchKeySet full_set(DispatchKeySet::FULL); uint8_t i = 0; - for (DispatchKey dispatch_key: full_set) { + for (DispatchKey dispatch_key : full_set) { i++; ASSERT_TRUE(dispatch_key == static_cast(i)); } @@ -126,17 +131,20 @@ TEST(DispatchKeySet, SpecificKeys) { static_cast(10), // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) static_cast(15), - }); + }); std::unordered_set visited_keys; - for(DispatchKey key: keyset) { + for (DispatchKey key : keyset) { visited_keys.insert(key); } ASSERT_EQ(visited_keys.size(), 3); - ASSERT_TRUE(visited_keys.find(static_cast(4)) != visited_keys.end()); - ASSERT_TRUE(visited_keys.find(static_cast(10)) != visited_keys.end()); - ASSERT_TRUE(visited_keys.find(static_cast(15)) != visited_keys.end()); + ASSERT_TRUE( + visited_keys.find(static_cast(4)) != visited_keys.end()); + ASSERT_TRUE( + visited_keys.find(static_cast(10)) != visited_keys.end()); + ASSERT_TRUE( + visited_keys.find(static_cast(15)) != visited_keys.end()); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) @@ -145,9 +153,8 @@ TEST(DispatchKeySet, FailAtEndIterator) { uint64_t raw_repr = full_set.raw_repr(); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - EXPECT_THROW(DispatchKeySet::iterator( - &raw_repr, - static_cast(DispatchKey::NumDispatchKeys) + 1 - ), - c10::Error); + EXPECT_THROW( + DispatchKeySet::iterator( + &raw_repr, static_cast(DispatchKey::NumDispatchKeys) + 1), + c10::Error); } diff --git a/c10/test/core/impl/InlineDeviceGuard_test.cpp b/c10/test/core/impl/InlineDeviceGuard_test.cpp index f74648bc359..f0b88acc5f2 100644 --- a/c10/test/core/impl/InlineDeviceGuard_test.cpp +++ b/c10/test/core/impl/InlineDeviceGuard_test.cpp @@ -1,7 +1,7 @@ #include -#include #include +#include using namespace c10; using namespace c10::impl; @@ -55,8 +55,8 @@ TEST(InlineDeviceGuard, Constructor) { // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(InlineDeviceGuard, ConstructorError) { // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - EXPECT_ANY_THROW(InlineDeviceGuard> - g(Device(DeviceType::HIP, 1))); + EXPECT_ANY_THROW(InlineDeviceGuard> g( + Device(DeviceType::HIP, 1))); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) @@ -110,7 +110,8 @@ TEST(InlineDeviceGuard, SetIndex) { ASSERT_EQ(TestGuardImpl::getDeviceIndex(), i2); } -// -- InlineOptionalDeviceGuard -------------------------------------------------- +// -- InlineOptionalDeviceGuard +// -------------------------------------------------- using MaybeTestGuard = InlineOptionalDeviceGuard; diff --git a/c10/test/core/impl/InlineStreamGuard_test.cpp b/c10/test/core/impl/InlineStreamGuard_test.cpp index d3f8def9a5c..ef811b7e4ac 100644 --- a/c10/test/core/impl/InlineStreamGuard_test.cpp +++ b/c10/test/core/impl/InlineStreamGuard_test.cpp @@ -1,7 +1,7 @@ #include -#include #include +#include using namespace c10; using namespace c10::impl; @@ -40,7 +40,6 @@ TEST(InlineStreamGuard, Constructor) { ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(0), 0); } - // 
NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(InlineStreamGuard, ResetStreamSameSameDevice) { TestGuardImpl::setDeviceIndex(0); @@ -101,7 +100,8 @@ TEST(InlineStreamGuard, ResetStreamDifferentDevice) { ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(0), 0); } -// -- OptionalInlineStreamGuard ------------------------------------------------------- +// -- OptionalInlineStreamGuard +// ------------------------------------------------------- using OptionalTestGuard = InlineOptionalStreamGuard; @@ -180,7 +180,8 @@ TEST(InlineOptionalStreamGuard, ResetStreamDifferentDevice) { ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(0), 0); } -// -- InlineMultiStreamGuard ------------------------------------------------------- +// -- InlineMultiStreamGuard +// ------------------------------------------------------- using MultiTestGuard = InlineMultiStreamGuard; diff --git a/c10/test/core/impl/SizesAndStrides_test.cpp b/c10/test/core/impl/SizesAndStrides_test.cpp index 744c83da76f..4df51271883 100644 --- a/c10/test/core/impl/SizesAndStrides_test.cpp +++ b/c10/test/core/impl/SizesAndStrides_test.cpp @@ -6,12 +6,16 @@ using namespace c10; using namespace c10::impl; -static void checkData(const SizesAndStrides& sz, IntArrayRef sizes, IntArrayRef strides) { - EXPECT_EQ(sizes.size(), strides.size()) << "bad test case: size() of sizes and strides don't match"; +static void checkData( + const SizesAndStrides& sz, + IntArrayRef sizes, + IntArrayRef strides) { + EXPECT_EQ(sizes.size(), strides.size()) + << "bad test case: size() of sizes and strides don't match"; EXPECT_EQ(sz.size(), sizes.size()); int idx = 0; - for (auto x: sizes) { + for (auto x : sizes) { EXPECT_EQ(sz.size_at_unchecked(idx), x) << "index: " << idx; EXPECT_EQ(sz.size_at(idx), x) << "index: " << idx; EXPECT_EQ(sz.sizes_data()[idx], x) << "index: " << idx; @@ -21,7 +25,7 @@ static void checkData(const SizesAndStrides& sz, IntArrayRef sizes, IntArrayRef EXPECT_EQ(sz.sizes_arrayref(), sizes); idx = 0; - for (auto x: strides) { + for (auto x : strides) { EXPECT_EQ(sz.stride_at_unchecked(idx), x) << "index: " << idx; EXPECT_EQ(sz.stride_at(idx), x) << "index: " << idx; EXPECT_EQ(sz.strides_data()[idx], x) << "index: " << idx; diff --git a/c10/test/util/Array_test.cpp b/c10/test/util/Array_test.cpp index 8f7365fe0d7..be265132d43 100644 --- a/c10/test/util/Array_test.cpp +++ b/c10/test/util/Array_test.cpp @@ -6,90 +6,92 @@ using c10::guts::to_array; namespace { namespace test_equals { - static_assert(array{{}} == array{{}}, ""); - static_assert(array{{2, 3, 4}} == array{{2, 3, 4}}, ""); - static_assert(!(array{{2, 3, 4}} == array{{1, 3, 4}}), ""); - static_assert(!(array{{2, 3, 4}} == array{{2, 1, 4}}), ""); - static_assert(!(array{{2, 3, 4}} == array{{2, 3, 1}}), ""); -} +static_assert(array{{}} == array{{}}, ""); +static_assert(array{{2, 3, 4}} == array{{2, 3, 4}}, ""); +static_assert(!(array{{2, 3, 4}} == array{{1, 3, 4}}), ""); +static_assert(!(array{{2, 3, 4}} == array{{2, 1, 4}}), ""); +static_assert(!(array{{2, 3, 4}} == array{{2, 3, 1}}), ""); +} // namespace test_equals namespace test_notequals { - static_assert(!(array{{}} != array{{}}), ""); - static_assert(!(array{{2, 3, 4}} != array{{2, 3, 4}}), ""); - static_assert(array{{2, 3, 4}} != array{{1, 3, 4}}, ""); - static_assert(array{{2, 3, 4}} != array{{2, 1, 4}}, ""); - static_assert(array{{2, 3, 4}} != array{{2, 3, 1}}, ""); -} +static_assert(!(array{{}} != array{{}}), ""); +static_assert(!(array{{2, 3, 4}} != array{{2, 3, 4}}), ""); +static_assert(array{{2, 3, 4}} != 
array{{1, 3, 4}}, ""); +static_assert(array{{2, 3, 4}} != array{{2, 1, 4}}, ""); +static_assert(array{{2, 3, 4}} != array{{2, 3, 1}}, ""); +} // namespace test_notequals namespace test_lessthan { - static_assert(!(array{{}} < array{{}}), ""); - static_assert(!(array{{2}} < array{{1}}), ""); - static_assert(array{{1}} < array{{2}}, ""); - static_assert(!(array{{1, 2, 3}} < array{{1, 2, 3}}), ""); - static_assert(array{{1, 2, 3}} < array{{2, 2, 3}}, ""); - static_assert(!(array{{1, 2, 3}} < array{{0, 2, 3}}), ""); - static_assert(array{{1, 2, 3}} < array{{1, 3, 3}}, ""); - static_assert(!(array{{1, 2, 3}} < array{{1, 1, 3}}), ""); - static_assert(array{{1, 2, 3}} < array{{1, 2, 4}}, ""); - static_assert(!(array{{1, 2, 3}} < array{{1, 2, 2}}), ""); -} +static_assert(!(array{{}} < array{{}}), ""); +static_assert(!(array{{2}} < array{{1}}), ""); +static_assert(array{{1}} < array{{2}}, ""); +static_assert(!(array{{1, 2, 3}} < array{{1, 2, 3}}), ""); +static_assert(array{{1, 2, 3}} < array{{2, 2, 3}}, ""); +static_assert(!(array{{1, 2, 3}} < array{{0, 2, 3}}), ""); +static_assert(array{{1, 2, 3}} < array{{1, 3, 3}}, ""); +static_assert(!(array{{1, 2, 3}} < array{{1, 1, 3}}), ""); +static_assert(array{{1, 2, 3}} < array{{1, 2, 4}}, ""); +static_assert(!(array{{1, 2, 3}} < array{{1, 2, 2}}), ""); +} // namespace test_lessthan namespace test_greaterthan { - static_assert(!(array{{}} > array{{}}), ""); - static_assert(!(array{{1}} > array{{2}}), ""); - static_assert(array{{2}} > array{{1}}, ""); - static_assert(!(array{{1, 2, 3}} > array{{1, 2, 3}}), ""); - static_assert(array{{2, 2, 3}} > array{{1, 2, 3}}, ""); - static_assert(!(array{{0, 2, 3}} > array{{1, 2, 3}}), ""); - static_assert(array{{1, 3, 3}} > array{{1, 2, 3}}, ""); - static_assert(!(array{{1, 1, 3}} > array{{1, 2, 3}}), ""); - static_assert(array{{1, 2, 4}} > array{{1, 2, 3}}, ""); - static_assert(!(array{{1, 2, 2}} > array{{1, 2, 3}}), ""); -} +static_assert(!(array{{}} > array{{}}), ""); +static_assert(!(array{{1}} > array{{2}}), ""); +static_assert(array{{2}} > array{{1}}, ""); +static_assert(!(array{{1, 2, 3}} > array{{1, 2, 3}}), ""); +static_assert(array{{2, 2, 3}} > array{{1, 2, 3}}, ""); +static_assert(!(array{{0, 2, 3}} > array{{1, 2, 3}}), ""); +static_assert(array{{1, 3, 3}} > array{{1, 2, 3}}, ""); +static_assert(!(array{{1, 1, 3}} > array{{1, 2, 3}}), ""); +static_assert(array{{1, 2, 4}} > array{{1, 2, 3}}, ""); +static_assert(!(array{{1, 2, 2}} > array{{1, 2, 3}}), ""); +} // namespace test_greaterthan namespace test_lessequals { - static_assert(array{{}} <= array{{}}, ""); - static_assert(!(array{{2}} <= array{{1}}), ""); - static_assert(array{{1}} <= array{{2}}, ""); - static_assert(array{{1, 2, 3}} <= array{{1, 2, 3}}, ""); - static_assert(array{{1, 2, 3}} <= array{{2, 2, 3}}, ""); - static_assert(!(array{{1, 2, 3}} <= array{{0, 2, 3}}), ""); - static_assert(array{{1, 2, 3}} <= array{{1, 3, 3}}, ""); - static_assert(!(array{{1, 2, 3}} <= array{{1, 1, 3}}), ""); - static_assert(array{{1, 2, 3}} <= array{{1, 2, 4}}, ""); - static_assert(!(array{{1, 2, 3}} <= array{{1, 2, 2}}), ""); -} +static_assert(array{{}} <= array{{}}, ""); +static_assert(!(array{{2}} <= array{{1}}), ""); +static_assert(array{{1}} <= array{{2}}, ""); +static_assert(array{{1, 2, 3}} <= array{{1, 2, 3}}, ""); +static_assert(array{{1, 2, 3}} <= array{{2, 2, 3}}, ""); +static_assert(!(array{{1, 2, 3}} <= array{{0, 2, 3}}), ""); +static_assert(array{{1, 2, 3}} <= array{{1, 3, 3}}, ""); +static_assert(!(array{{1, 2, 3}} <= array{{1, 1, 3}}), ""); 
+static_assert(array{{1, 2, 3}} <= array{{1, 2, 4}}, ""); +static_assert(!(array{{1, 2, 3}} <= array{{1, 2, 2}}), ""); +} // namespace test_lessequals namespace test_greaterequals { - static_assert(array{{}} >= array{{}}, ""); - static_assert(!(array{{1}} >= array{{2}}), ""); - static_assert(array{{2}} >= array{{1}}, ""); - static_assert(array{{1, 2, 3}} >= array{{1, 2, 3}}, ""); - static_assert(array{{2, 2, 3}} >= array{{1, 2, 3}}, ""); - static_assert(!(array{{0, 2, 3}} >= array{{1, 2, 3}}), ""); - static_assert(array{{1, 3, 3}} >= array{{1, 2, 3}}, ""); - static_assert(!(array{{1, 1, 3}} >= array{{1, 2, 3}}), ""); - static_assert(array{{1, 2, 4}} >= array{{1, 2, 3}}, ""); - static_assert(!(array{{1, 2, 2}} >= array{{1, 2, 3}}), ""); -} +static_assert(array{{}} >= array{{}}, ""); +static_assert(!(array{{1}} >= array{{2}}), ""); +static_assert(array{{2}} >= array{{1}}, ""); +static_assert(array{{1, 2, 3}} >= array{{1, 2, 3}}, ""); +static_assert(array{{2, 2, 3}} >= array{{1, 2, 3}}, ""); +static_assert(!(array{{0, 2, 3}} >= array{{1, 2, 3}}), ""); +static_assert(array{{1, 3, 3}} >= array{{1, 2, 3}}, ""); +static_assert(!(array{{1, 1, 3}} >= array{{1, 2, 3}}), ""); +static_assert(array{{1, 2, 4}} >= array{{1, 2, 3}}, ""); +static_assert(!(array{{1, 2, 2}} >= array{{1, 2, 3}}), ""); +} // namespace test_greaterequals namespace test_tail { - static_assert(array < int, 2 > {{3, 4}} == tail(array < int, 3 > {{2, 3, 4}}), ""); - static_assert(array < int, 0 > {{}} == tail(array < int, 1 > {{3}}), ""); -} +static_assert(array{{3, 4}} == tail(array{{2, 3, 4}}), ""); +static_assert(array{{}} == tail(array{{3}}), ""); +} // namespace test_tail namespace test_prepend { - static_assert(array < int, 3 > {{2, 3, 4}} == prepend(2, array < int, 2 > {{3, 4}}), ""); - static_assert(array < int, 1 > {{3}} == prepend(3, array < int, 0 > {{}}), ""); -} +static_assert( + array{{2, 3, 4}} == prepend(2, array{{3, 4}}), + ""); +static_assert(array{{3}} == prepend(3, array{{}}), ""); +} // namespace test_prepend namespace test_to_std_array { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - constexpr int obj2[3] = {3, 5, 6}; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - static_assert(array < int, 3 > {{3, 5, 6}} == to_array(obj2), ""); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - static_assert(array < int, 3 > {{3, 5, 6}} == to_array({3, 5, 6}), ""); -} +// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) +constexpr int obj2[3] = {3, 5, 6}; +// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) +static_assert(array{{3, 5, 6}} == to_array(obj2), ""); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) +static_assert(array{{3, 5, 6}} == to_array({3, 5, 6}), ""); +} // namespace test_to_std_array -} +} // namespace diff --git a/c10/test/util/C++17_test.cpp b/c10/test/util/C++17_test.cpp index 5f6cde46e9a..5a10d12f56c 100644 --- a/c10/test/util/C++17_test.cpp +++ b/c10/test/util/C++17_test.cpp @@ -12,7 +12,7 @@ static_assert(min(5, 3) == 3, ""); static_assert(min(3, 3) == 3, ""); // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) static_assert(min(3.0, 3.1) == 3.0, ""); -} +} // namespace test_min namespace test_max { using c10::guts::max; @@ -23,7 +23,7 @@ static_assert(max(5, 3) == 5, ""); static_assert(max(3, 3) == 3, ""); // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) static_assert(max(3.0, 3.1) == 3.1, ""); -} +} // namespace test_max namespace test_if_constexpr { @@ -31,129 +31,125 @@ using 
c10::guts::if_constexpr; // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(if_constexpr, whenIsTrue_thenReturnsTrueCase) { - EXPECT_EQ(4, if_constexpr([](auto) { return 4; }, [](auto) { return 5; })); + EXPECT_EQ( + 4, if_constexpr([](auto) { return 4; }, [](auto) { return 5; })); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(if_constexpr, whenIsFalse_thenReturnsFalseCase) { - EXPECT_EQ(5, if_constexpr([](auto) { return 4; }, [](auto) { return 5; })); + EXPECT_EQ( + 5, if_constexpr([](auto) { return 4; }, [](auto) { return 5; })); } struct MovableOnly final { - int value; + int value; - MovableOnly(int v) : value(v) {} - MovableOnly(MovableOnly&&) = default; - MovableOnly(const MovableOnly&) = delete; - MovableOnly& operator=(MovableOnly&&) = default; - MovableOnly& operator=(const MovableOnly&) = delete; + MovableOnly(int v) : value(v) {} + MovableOnly(MovableOnly&&) = default; + MovableOnly(const MovableOnly&) = delete; + MovableOnly& operator=(MovableOnly&&) = default; + MovableOnly& operator=(const MovableOnly&) = delete; }; // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(if_constexpr, worksWithMovableOnlyTypes_withIdentityArg) { - EXPECT_EQ( - 4, - if_constexpr([](auto) { return MovableOnly(4); }, [](auto) { return MovableOnly(5); }) - .value); - EXPECT_EQ( - 5, - if_constexpr([](auto) { return MovableOnly(4); }, [](auto) { return MovableOnly(5); }) - .value); + EXPECT_EQ( + 4, + if_constexpr( + [](auto) { return MovableOnly(4); }, + [](auto) { return MovableOnly(5); }) + .value); + EXPECT_EQ( + 5, + if_constexpr( + [](auto) { return MovableOnly(4); }, + [](auto) { return MovableOnly(5); }) + .value); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(if_constexpr, worksWithMovableOnlyTypes_withoutIdentityArg) { - EXPECT_EQ( - 4, - if_constexpr([] { return MovableOnly(4); }, [] { return MovableOnly(5); }) - .value); - EXPECT_EQ( - 5, - if_constexpr([] { return MovableOnly(4); }, [] { return MovableOnly(5); }) - .value); + EXPECT_EQ( + 4, + if_constexpr( + [] { return MovableOnly(4); }, [] { return MovableOnly(5); }) + .value); + EXPECT_EQ( + 5, + if_constexpr( + [] { return MovableOnly(4); }, [] { return MovableOnly(5); }) + .value); } struct MyClass1 { - int value; + int value; }; struct MyClass2 { - int val; + int val; }; -template +template int func(T t) { - return if_constexpr::value>( - [&](auto _) { return _(t).value; }, // this code is invalid for T == MyClass2 - [&](auto _) { return _(t).val; } // this code is invalid for T == MyClass1 - ); + return if_constexpr::value>( + [&](auto _) { + return _(t).value; + }, // this code is invalid for T == MyClass2 + [&](auto _) { return _(t).val; } // this code is invalid for T == MyClass1 + ); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(if_constexpr, otherCaseCanHaveInvalidCode) { - EXPECT_EQ(8, func(MyClass1{/* .value = */ 8})); - EXPECT_EQ(4, func(MyClass2{/* .val = */ 4})); + EXPECT_EQ(8, func(MyClass1{/* .value = */ 8})); + EXPECT_EQ(4, func(MyClass2{/* .val = */ 4})); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(if_constexpr, worksWithoutElseCase_withIdentityArg) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - int var = 5; - if_constexpr( - [&](auto) { var = 3; } - ); - EXPECT_EQ(5, var); - if_constexpr( - [&](auto) { var = 3; } - ); - EXPECT_EQ(3, var); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + int var = 5; + 
if_constexpr([&](auto) { var = 3; }); + EXPECT_EQ(5, var); + if_constexpr([&](auto) { var = 3; }); + EXPECT_EQ(3, var); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(if_constexpr, worksWithoutElseCase_withoutIdentityArg) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - int var = 5; - if_constexpr( - [&] { var = 3; } - ); - EXPECT_EQ(5, var); - if_constexpr( - [&] { var = 3; } - ); - EXPECT_EQ(3, var); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + int var = 5; + if_constexpr([&] { var = 3; }); + EXPECT_EQ(5, var); + if_constexpr([&] { var = 3; }); + EXPECT_EQ(3, var); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(if_constexpr, returnTypeCanDiffer_withIdentityArg) { - auto a_string = if_constexpr( - [&](auto) -> int64_t { return 3; }, - [&](auto) -> std::string { return "3"; } - ); - static_assert(std::is_same::value, ""); + auto a_string = if_constexpr( + [&](auto) -> int64_t { return 3; }, + [&](auto) -> std::string { return "3"; }); + static_assert(std::is_same::value, ""); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto an_int = if_constexpr( - [&](auto) -> int64_t { return 3; }, - [&](auto) -> std::string { return "3"; } - ); - static_assert(std::is_same::value, ""); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto an_int = if_constexpr( + [&](auto) -> int64_t { return 3; }, + [&](auto) -> std::string { return "3"; }); + static_assert(std::is_same::value, ""); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(if_constexpr, returnTypeCanDiffer_withoutIdentityArg) { - auto a_string = if_constexpr( - [&] () -> int64_t { return 3; }, - [&] () -> std::string { return "3"; } - ); - static_assert(std::is_same::value, ""); + auto a_string = if_constexpr( + [&]() -> int64_t { return 3; }, [&]() -> std::string { return "3"; }); + static_assert(std::is_same::value, ""); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto an_int = if_constexpr( - [&] () -> int64_t { return 3; }, - [&] () -> std::string { return "3"; } - ); - static_assert(std::is_same::value, ""); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto an_int = if_constexpr( + [&]() -> int64_t { return 3; }, [&]() -> std::string { return "3"; }); + static_assert(std::is_same::value, ""); } -} -} +} // namespace test_if_constexpr +} // namespace diff --git a/c10/test/util/LeftRight_test.cpp b/c10/test/util/LeftRight_test.cpp index 7cdd06f9b8a..e113e9cb65f 100644 --- a/c10/test/util/LeftRight_test.cpp +++ b/c10/test/util/LeftRight_test.cpp @@ -10,259 +10,257 @@ TEST(LeftRightTest, givenInt_whenWritingAndReading_thenChangesArePresent) { LeftRight obj; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - obj.write([] (int& obj) {obj = 5;}); - int read = obj.read([] (const int& obj) {return obj;}); + obj.write([](int& obj) { obj = 5; }); + int read = obj.read([](const int& obj) { return obj; }); EXPECT_EQ(5, read); // check changes are also present in background copy - obj.write([] (int&) {}); // this switches to the background copy - read = obj.read([] (const int& obj) {return obj;}); + obj.write([](int&) {}); // this switches to the background copy + read = obj.read([](const int& obj) { return obj; }); EXPECT_EQ(5, read); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(LeftRightTest, givenVector_whenWritingAndReading_thenChangesArePresent) { - LeftRight> obj; + LeftRight> obj; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - 
obj.write([] (vector& obj) {obj.push_back(5);}); - vector read = obj.read([] (const vector& obj) {return obj;}); - EXPECT_EQ((vector{5}), read); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + obj.write([](vector& obj) { obj.push_back(5); }); + vector read = obj.read([](const vector& obj) { return obj; }); + EXPECT_EQ((vector{5}), read); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - obj.write([] (vector& obj) {obj.push_back(6);}); - read = obj.read([] (const vector& obj) {return obj;}); - EXPECT_EQ((vector{5, 6}), read); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + obj.write([](vector& obj) { obj.push_back(6); }); + read = obj.read([](const vector& obj) { return obj; }); + EXPECT_EQ((vector{5, 6}), read); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(LeftRightTest, givenVector_whenWritingReturnsValue_thenValueIsReturned) { - LeftRight> obj; + LeftRight> obj; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto a = obj.write([] (vector&) -> int {return 5;}); - static_assert(std::is_same::value, ""); - EXPECT_EQ(5, a); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto a = obj.write([](vector&) -> int { return 5; }); + static_assert(std::is_same::value, ""); + EXPECT_EQ(5, a); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(LeftRightTest, readsCanBeConcurrent) { - LeftRight obj; - std::atomic num_running_readers{0}; + LeftRight obj; + std::atomic num_running_readers{0}; - std::thread reader1([&] () { - obj.read([&] (const int&) { - ++num_running_readers; - while(num_running_readers.load() < 2) {} - }); + std::thread reader1([&]() { + obj.read([&](const int&) { + ++num_running_readers; + while (num_running_readers.load() < 2) { + } }); + }); - std::thread reader2([&] () { - obj.read([&] (const int&) { - ++num_running_readers; - while(num_running_readers.load() < 2) {} - }); + std::thread reader2([&]() { + obj.read([&](const int&) { + ++num_running_readers; + while (num_running_readers.load() < 2) { + } }); + }); - // the threads only finish after both entered the read function. - // if LeftRight didn't allow concurrency, this would cause a deadlock. - reader1.join(); - reader2.join(); + // the threads only finish after both entered the read function. + // if LeftRight didn't allow concurrency, this would cause a deadlock. + reader1.join(); + reader2.join(); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(LeftRightTest, writesCanBeConcurrentWithReads_readThenWrite) { - LeftRight obj; - std::atomic reader_running{false}; - std::atomic writer_running{false}; + LeftRight obj; + std::atomic reader_running{false}; + std::atomic writer_running{false}; - std::thread reader([&] () { - obj.read([&] (const int&) { - reader_running = true; - while(!writer_running.load()) {} - }); + std::thread reader([&]() { + obj.read([&](const int&) { + reader_running = true; + while (!writer_running.load()) { + } }); + }); - std::thread writer([&] () { - // run read first, write second - while (!reader_running.load()) {} + std::thread writer([&]() { + // run read first, write second + while (!reader_running.load()) { + } - obj.write([&] (int&) { - writer_running = true; - }); - }); + obj.write([&](int&) { writer_running = true; }); + }); - // the threads only finish after both entered the read function. - // if LeftRight didn't allow concurrency, this would cause a deadlock. 
- reader.join(); - writer.join(); + // the threads only finish after both entered the read function. + // if LeftRight didn't allow concurrency, this would cause a deadlock. + reader.join(); + writer.join(); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(LeftRightTest, writesCanBeConcurrentWithReads_writeThenRead) { - LeftRight obj; - std::atomic writer_running{false}; - std::atomic reader_running{false}; + LeftRight obj; + std::atomic writer_running{false}; + std::atomic reader_running{false}; - std::thread writer([&] () { - obj.read([&] (const int&) { - writer_running = true; - while(!reader_running.load()) {} - }); + std::thread writer([&]() { + obj.read([&](const int&) { + writer_running = true; + while (!reader_running.load()) { + } }); + }); - std::thread reader([&] () { - // run write first, read second - while (!writer_running.load()) {} + std::thread reader([&]() { + // run write first, read second + while (!writer_running.load()) { + } - obj.read([&] (const int&) { - reader_running = true; - }); - }); + obj.read([&](const int&) { reader_running = true; }); + }); - // the threads only finish after both entered the read function. - // if LeftRight didn't allow concurrency, this would cause a deadlock. - writer.join(); - reader.join(); + // the threads only finish after both entered the read function. + // if LeftRight didn't allow concurrency, this would cause a deadlock. + writer.join(); + reader.join(); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(LeftRightTest, writesCannotBeConcurrentWithWrites) { - LeftRight obj; - std::atomic first_writer_started{false}; - std::atomic first_writer_finished{false}; + LeftRight obj; + std::atomic first_writer_started{false}; + std::atomic first_writer_finished{false}; - std::thread writer1([&] () { - obj.write([&] (int&) { - first_writer_started = true; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - std::this_thread::sleep_for(std::chrono::milliseconds(50)); - first_writer_finished = true; - }); + std::thread writer1([&]() { + obj.write([&](int&) { + first_writer_started = true; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + first_writer_finished = true; }); + }); - std::thread writer2([&] () { - // make sure the other writer runs first - while (!first_writer_started.load()) {} + std::thread writer2([&]() { + // make sure the other writer runs first + while (!first_writer_started.load()) { + } - obj.write([&] (int&) { - // expect the other writer finished before this one starts - EXPECT_TRUE(first_writer_finished.load()); - }); + obj.write([&](int&) { + // expect the other writer finished before this one starts + EXPECT_TRUE(first_writer_finished.load()); }); + }); - writer1.join(); - writer2.join(); + writer1.join(); + writer2.join(); } namespace { class MyException : public std::exception {}; -} +} // namespace // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(LeftRightTest, whenReadThrowsException_thenThrowsThrough) { - LeftRight obj; + LeftRight obj; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - EXPECT_THROW( - obj.read([](const int&) {throw MyException();}), - MyException - ); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + EXPECT_THROW(obj.read([](const int&) { throw MyException(); }), MyException); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(LeftRightTest, 
whenWriteThrowsException_thenThrowsThrough) { - LeftRight obj; + LeftRight obj; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - EXPECT_THROW( - obj.write([](int&) {throw MyException();}), - MyException - ); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + EXPECT_THROW(obj.write([](int&) { throw MyException(); }), MyException); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -TEST(LeftRightTest, givenInt_whenWriteThrowsExceptionOnFirstCall_thenResetsToOldState) { - LeftRight obj; +TEST( + LeftRightTest, + givenInt_whenWriteThrowsExceptionOnFirstCall_thenResetsToOldState) { + LeftRight obj; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - obj.write([](int& obj) {obj = 5;}); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + obj.write([](int& obj) { obj = 5; }); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - EXPECT_THROW( - obj.write([](int& obj) { - obj = 6; - throw MyException(); - }), - MyException - ); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + EXPECT_THROW( + obj.write([](int& obj) { + obj = 6; + throw MyException(); + }), + MyException); - // check reading it returns old value - int read = obj.read([] (const int& obj) {return obj;}); - EXPECT_EQ(5, read); + // check reading it returns old value + int read = obj.read([](const int& obj) { return obj; }); + EXPECT_EQ(5, read); - // check changes are also present in background copy - obj.write([] (int&) {}); // this switches to the background copy - read = obj.read([] (const int& obj) {return obj;}); - EXPECT_EQ(5, read); + // check changes are also present in background copy + obj.write([](int&) {}); // this switches to the background copy + read = obj.read([](const int& obj) { return obj; }); + EXPECT_EQ(5, read); } // note: each write is executed twice, on the foreground and background copy. // We need to test a thrown exception in either call is handled correctly. 
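// Illustrative sketch (outside the patch), tying the note above to a minimal
// use of the class under test. It assumes the LeftRight<T> interface exercised
// by these tests (read()/write() taking callables, from c10/util/LeftRight.h);
// the template argument was stripped by text extraction, so LeftRight<int> is
// inferred from the int& callbacks in the surrounding tests.
#include <c10/util/LeftRight.h>

inline int left_right_sketch() {
  c10::LeftRight<int> obj;
  // As the note above says, a write is executed twice, once per internal copy,
  // which is why the tests exercise an exception thrown on either execution.
  obj.write([](int& v) { v = 5; });
  // read() may run concurrently with other reads.
  return obj.read([](const int& v) { return v; }); // 5
}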
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -TEST(LeftRightTest, givenInt_whenWriteThrowsExceptionOnSecondCall_thenKeepsNewState) { - LeftRight obj; +TEST( + LeftRightTest, + givenInt_whenWriteThrowsExceptionOnSecondCall_thenKeepsNewState) { + LeftRight obj; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - obj.write([](int& obj) {obj = 5;}); - bool write_called = false; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + obj.write([](int& obj) { obj = 5; }); + bool write_called = false; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - EXPECT_THROW( - obj.write([&](int& obj) { - obj = 6; - if (write_called) { - // this is the second time the write callback is executed - throw MyException(); - } else { - write_called = true; - } - }), - MyException - ); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + EXPECT_THROW( + obj.write([&](int& obj) { + obj = 6; + if (write_called) { + // this is the second time the write callback is executed + throw MyException(); + } else { + write_called = true; + } + }), + MyException); - // check reading it returns new value - int read = obj.read([] (const int& obj) {return obj;}); - EXPECT_EQ(6, read); + // check reading it returns new value + int read = obj.read([](const int& obj) { return obj; }); + EXPECT_EQ(6, read); - // check changes are also present in background copy - obj.write([] (int&) {}); // this switches to the background copy - read = obj.read([] (const int& obj) {return obj;}); - EXPECT_EQ(6, read); + // check changes are also present in background copy + obj.write([](int&) {}); // this switches to the background copy + read = obj.read([](const int& obj) { return obj; }); + EXPECT_EQ(6, read); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(LeftRightTest, givenVector_whenWriteThrowsException_thenResetsToOldState) { - LeftRight> obj; + LeftRight> obj; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - obj.write([](vector& obj) {obj.push_back(5);}); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + obj.write([](vector& obj) { obj.push_back(5); }); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - EXPECT_THROW( - obj.write([](vector& obj) { - obj.push_back(6); - throw MyException(); - }), - MyException - ); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + EXPECT_THROW( + obj.write([](vector& obj) { + obj.push_back(6); + throw MyException(); + }), + MyException); - // check reading it returns old value - vector read = obj.read([] (const vector& obj) {return obj;}); - EXPECT_EQ((vector{5}), read); + // check reading it returns old value + vector read = obj.read([](const vector& obj) { return obj; }); + EXPECT_EQ((vector{5}), read); - // check changes are also present in background copy - obj.write([] (vector&) {}); // this switches to the background copy - read = obj.read([] (const vector& obj) {return obj;}); - EXPECT_EQ((vector{5}), read); + // check changes are also present in background copy + obj.write([](vector&) {}); // this switches to the background copy + read = obj.read([](const vector& obj) { return obj; }); + EXPECT_EQ((vector{5}), read); } diff --git a/c10/test/util/Metaprogramming_test.cpp b/c10/test/util/Metaprogramming_test.cpp index f6cfa715be2..bcce22d634b 100644 --- a/c10/test/util/Metaprogramming_test.cpp +++ b/c10/test/util/Metaprogramming_test.cpp @@ -1,580 +1,820 @@ -#include #include +#include #include #include - using namespace c10::guts; namespace { 
namespace test_function_traits { - static_assert(std::is_same::return_type>::value, ""); - static_assert(std::is_same::return_type>::value, ""); - static_assert(std::is_same, typename function_traits::parameter_types>::value, ""); - static_assert(std::is_same, typename function_traits::parameter_types>::value, ""); +static_assert( + std::is_same< + void, + typename function_traits::return_type>::value, + ""); +static_assert( + std::is_same::return_type>:: + value, + ""); +static_assert( + std::is_same< + typelist::typelist, + typename function_traits::parameter_types>::value, + ""); +static_assert( + std::is_same< + typelist::typelist, + typename function_traits::parameter_types>::value, + ""); - static_assert(std::is_same>::return_type>::value, ""); - static_assert(std::is_same>::return_type>::value, ""); - static_assert(std::is_same, typename make_function_traits_t>::parameter_types>::value, ""); - static_assert(std::is_same, typename make_function_traits_t>::parameter_types>::value, ""); - static_assert(std::is_same>::func_type>::value, ""); - static_assert(std::is_same>::func_type>::value, ""); -} +static_assert( + std::is_same< + bool, + typename make_function_traits_t>:: + return_type>::value, + ""); +static_assert( + std::is_same< + void, + typename make_function_traits_t>:: + return_type>::value, + ""); +static_assert( + std::is_same< + typelist::typelist, + typename make_function_traits_t>:: + parameter_types>::value, + ""); +static_assert( + std::is_same< + typelist::typelist, + typename make_function_traits_t>:: + parameter_types>::value, + ""); +static_assert( + std::is_same< + bool(int, float), + typename make_function_traits_t>:: + func_type>::value, + ""); +static_assert( + std::is_same< + void(int, float), + typename make_function_traits_t>:: + func_type>::value, + ""); +} // namespace test_function_traits struct MovableOnly { - constexpr MovableOnly(int val_): val(val_) {/* no default constructor */} - MovableOnly(const MovableOnly&) = delete; - MovableOnly(MovableOnly&&) = default; - MovableOnly& operator=(const MovableOnly&) = delete; - MovableOnly& operator=(MovableOnly&&) = default; + constexpr MovableOnly(int val_) : val(val_) { /* no default constructor */ + } + MovableOnly(const MovableOnly&) = delete; + MovableOnly(MovableOnly&&) = default; + MovableOnly& operator=(const MovableOnly&) = delete; + MovableOnly& operator=(MovableOnly&&) = default; - friend bool operator==(const MovableOnly& lhs, const MovableOnly& rhs) {return lhs.val == rhs.val;} -private: - int val; + friend bool operator==(const MovableOnly& lhs, const MovableOnly& rhs) { + return lhs.val == rhs.val; + } + + private: + int val; }; -template using is_my_movable_only_class = std::is_same>>; +template +using is_my_movable_only_class = + std::is_same>>; struct CopyCounting { - int move_count; - int copy_count; + int move_count; + int copy_count; - CopyCounting(): move_count(0), copy_count(0) {} - CopyCounting(const CopyCounting& rhs): move_count(rhs.move_count), copy_count(rhs.copy_count + 1) {} - CopyCounting(CopyCounting&& rhs): move_count(rhs.move_count + 1), copy_count(rhs.copy_count) {} - CopyCounting& operator=(const CopyCounting& rhs) { - move_count = rhs.move_count; - copy_count = rhs.copy_count + 1; - return *this; - } - CopyCounting& operator=(CopyCounting&& rhs) { - move_count = rhs.move_count + 1; - copy_count = rhs.copy_count; - return *this; - } + CopyCounting() : move_count(0), copy_count(0) {} + CopyCounting(const CopyCounting& rhs) + : move_count(rhs.move_count), 
copy_count(rhs.copy_count + 1) {} + CopyCounting(CopyCounting&& rhs) + : move_count(rhs.move_count + 1), copy_count(rhs.copy_count) {} + CopyCounting& operator=(const CopyCounting& rhs) { + move_count = rhs.move_count; + copy_count = rhs.copy_count + 1; + return *this; + } + CopyCounting& operator=(CopyCounting&& rhs) { + move_count = rhs.move_count + 1; + copy_count = rhs.copy_count; + return *this; + } }; -template using is_my_copy_counting_class = std::is_same>>; +template +using is_my_copy_counting_class = + std::is_same>>; namespace test_extract_arg_by_filtered_index { - class MyClass {}; +class MyClass {}; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, ExtractArgByFilteredIndex) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto a1 = extract_arg_by_filtered_index(3, "bla", MyClass(), 4, nullptr, 5); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto a2 = extract_arg_by_filtered_index(3, "bla", MyClass(), 4, nullptr, 5); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto a3 = extract_arg_by_filtered_index(3, "bla", MyClass(), 4, nullptr, 5); - EXPECT_EQ(3, a1); - EXPECT_EQ(4, a2); - EXPECT_EQ(5, a3); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, ExtractArgByFilteredIndex_singleInput) { - auto a1 = extract_arg_by_filtered_index(3); - EXPECT_EQ(3, a1); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, ExtractArgByFilteredIndex_movableOnly) { - MovableOnly a1 = extract_arg_by_filtered_index(3, MovableOnly(3), "test", MovableOnly(1)); - MovableOnly a2 = extract_arg_by_filtered_index(3, MovableOnly(3), "test", MovableOnly(1)); - EXPECT_EQ(MovableOnly(3), a1); - EXPECT_EQ(MovableOnly(1), a2); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, ExtractArgByFilteredIndex_onlyCopiesIfNecessary) { - CopyCounting source; - CopyCounting source2; - CopyCounting a1 = extract_arg_by_filtered_index(3, CopyCounting(), "test", source, std::move(source2)); - // NOLINTNEXTLINE(bugprone-use-after-move) - CopyCounting a2 = extract_arg_by_filtered_index(3, CopyCounting(), "test", source, std::move(source2)); - // NOLINTNEXTLINE(bugprone-use-after-move) - CopyCounting a3 = extract_arg_by_filtered_index(3, CopyCounting(), "test", source, std::move(source2)); - EXPECT_EQ(1, a1.move_count); - EXPECT_EQ(0, a1.copy_count); - EXPECT_EQ(0, a2.move_count); - EXPECT_EQ(1, a3.move_count); - EXPECT_EQ(0, a3.copy_count); - EXPECT_EQ(1, a2.copy_count); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, ExtractArgByFilteredIndex_onlyMovesIfNecessary) { - CopyCounting source; - CopyCounting source2; - CopyCounting&& a1 = extract_arg_by_filtered_index(3, std::move(source), "test", std::move(source2)); - // NOLINTNEXTLINE(bugprone-use-after-move) - CopyCounting a2 = extract_arg_by_filtered_index(3, std::move(source), "test", std::move(source2)); - EXPECT_EQ(0, a1.move_count); - EXPECT_EQ(0, a1.copy_count); - EXPECT_EQ(1, a2.move_count); - EXPECT_EQ(0, a2.copy_count); - } - - template using is_true = std::true_type; - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, ExtractArgByFilteredIndex_keepsLValueReferencesIntact) { - MyClass obj; - MyClass& a1 = extract_arg_by_filtered_index(3, obj, "test", obj); - EXPECT_EQ(&obj, &a1); - } +// 
NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, ExtractArgByFilteredIndex) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto a1 = extract_arg_by_filtered_index( + 3, "bla", MyClass(), 4, nullptr, 5); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto a2 = extract_arg_by_filtered_index( + 3, "bla", MyClass(), 4, nullptr, 5); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto a3 = extract_arg_by_filtered_index( + 3, "bla", MyClass(), 4, nullptr, 5); + EXPECT_EQ(3, a1); + EXPECT_EQ(4, a2); + EXPECT_EQ(5, a3); } +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, ExtractArgByFilteredIndex_singleInput) { + auto a1 = extract_arg_by_filtered_index(3); + EXPECT_EQ(3, a1); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, ExtractArgByFilteredIndex_movableOnly) { + MovableOnly a1 = extract_arg_by_filtered_index( + 3, MovableOnly(3), "test", MovableOnly(1)); + MovableOnly a2 = extract_arg_by_filtered_index( + 3, MovableOnly(3), "test", MovableOnly(1)); + EXPECT_EQ(MovableOnly(3), a1); + EXPECT_EQ(MovableOnly(1), a2); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, ExtractArgByFilteredIndex_onlyCopiesIfNecessary) { + CopyCounting source; + CopyCounting source2; + CopyCounting a1 = extract_arg_by_filtered_index( + 3, CopyCounting(), "test", source, std::move(source2)); + // NOLINTNEXTLINE(bugprone-use-after-move) + CopyCounting a2 = extract_arg_by_filtered_index( + 3, CopyCounting(), "test", source, std::move(source2)); + // NOLINTNEXTLINE(bugprone-use-after-move) + CopyCounting a3 = extract_arg_by_filtered_index( + 3, CopyCounting(), "test", source, std::move(source2)); + EXPECT_EQ(1, a1.move_count); + EXPECT_EQ(0, a1.copy_count); + EXPECT_EQ(0, a2.move_count); + EXPECT_EQ(1, a3.move_count); + EXPECT_EQ(0, a3.copy_count); + EXPECT_EQ(1, a2.copy_count); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, ExtractArgByFilteredIndex_onlyMovesIfNecessary) { + CopyCounting source; + CopyCounting source2; + CopyCounting&& a1 = + extract_arg_by_filtered_index( + 3, std::move(source), "test", std::move(source2)); + // NOLINTNEXTLINE(bugprone-use-after-move) + CopyCounting a2 = extract_arg_by_filtered_index( + 3, std::move(source), "test", std::move(source2)); + EXPECT_EQ(0, a1.move_count); + EXPECT_EQ(0, a1.copy_count); + EXPECT_EQ(1, a2.move_count); + EXPECT_EQ(0, a2.copy_count); +} + +template +using is_true = std::true_type; + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST( + MetaprogrammingTest, + ExtractArgByFilteredIndex_keepsLValueReferencesIntact) { + MyClass obj; + MyClass& a1 = extract_arg_by_filtered_index(3, obj, "test", obj); + EXPECT_EQ(&obj, &a1); +} +} // namespace test_extract_arg_by_filtered_index + namespace test_filter_map { - class MyClass {}; +class MyClass {}; - struct map_to_double { - template constexpr double operator()(T a) const { - return static_cast(a); - } - }; +struct map_to_double { + template + constexpr double operator()(T a) const { + return static_cast(a); + } +}; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, FilterMap) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto result = filter_map(map_to_double(), 3, "bla", MyClass(), 4, nullptr, 5); - static_assert(std::is_same, 
decltype(result)>::value, ""); - constexpr array expected{{3.0, 4.0, 5.0}}; - EXPECT_EQ(expected, result); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, FilterMap_emptyInput) { - auto result = filter_map(map_to_double()); - static_assert(std::is_same, decltype(result)>::value, ""); - constexpr array expected{{}}; - EXPECT_EQ(expected, result); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, FilterMap_emptyOutput) { - auto result = filter_map(map_to_double(), "bla", MyClass(), nullptr); - static_assert(std::is_same, decltype(result)>::value, ""); - constexpr array expected{{}}; - EXPECT_EQ(expected, result); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, FilterMap_movableOnly_byRValue) { - struct map_movable_by_rvalue { - MovableOnly operator()(MovableOnly&& a) const { - return std::move(a); - } - }; - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto result = filter_map(map_movable_by_rvalue(), MovableOnly(5), "bla", nullptr, 3, MovableOnly(2)); - static_assert(std::is_same, decltype(result)>::value, ""); - constexpr array expected {{MovableOnly(5), MovableOnly(2)}}; - EXPECT_EQ(expected, result); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, FilterMap_movableOnly_byValue) { - struct map_movable_by_lvalue { - MovableOnly operator()(MovableOnly a) const { - return a; - } - }; - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto result = filter_map(map_movable_by_lvalue(), MovableOnly(5), "bla", nullptr, 3, MovableOnly(2)); - static_assert(std::is_same, decltype(result)>::value, ""); - constexpr array expected {{MovableOnly(5), MovableOnly(2)}}; - EXPECT_EQ(expected, result); - } - - // See https://github.com/pytorch/pytorch/issues/35546 - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, DISABLED_ON_WINDOWS(FilterMap_onlyCopiesIfNecessary)) { - struct map_copy_counting_by_copy { - CopyCounting operator()(CopyCounting v) const { - return v; - } - }; - - CopyCounting source; - CopyCounting source2; - auto result = filter_map(map_copy_counting_by_copy(), CopyCounting(), "bla", nullptr, 3, source, std::move(source2)); - static_assert(std::is_same, decltype(result)>::value, ""); - EXPECT_EQ(0, result[0].copy_count); - EXPECT_EQ(2, result[0].move_count); - EXPECT_EQ(1, result[1].copy_count); - EXPECT_EQ(1, result[1].move_count); - EXPECT_EQ(0, result[2].copy_count); - EXPECT_EQ(2, result[2].move_count); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, DISABLED_ON_WINDOWS(FilterMap_onlyMovesIfNecessary_1)) { - struct map_copy_counting_by_move { - CopyCounting operator()(CopyCounting&& v) const { - return std::move(v); - } - }; - - CopyCounting source; - auto result = filter_map(map_copy_counting_by_move(), CopyCounting(), "bla", nullptr, 3, std::move(source)); - static_assert(std::is_same, decltype(result)>::value, ""); - EXPECT_EQ(0, result[0].copy_count); - EXPECT_EQ(1, result[0].move_count); - EXPECT_EQ(0, result[1].copy_count); - EXPECT_EQ(1, result[1].move_count); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, FilterMap_onlyMovesIfNecessary_2) { - struct map_copy_counting_by_pointer { - const CopyCounting* operator()(const CopyCounting& v) const { - return &v; - } - }; 
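// Illustrative sketch (outside the patch): filter_map keeps only the
// arguments whose type satisfies a condition and maps them into a
// c10::guts::array. The explicit template arguments (result element type and
// filtering trait) were stripped from this extract, so the spelling below is
// an assumption consistent with the expected array<double, 3> results above.
#include <c10/util/Metaprogramming.h>
#include <type_traits>

struct to_double_sketch {
  template <class T>
  constexpr double operator()(T v) const {
    return static_cast<double>(v);
  }
};

inline void filter_map_sketch() {
  // Keeps 3, 4, 5 (integral), drops the string literal and nullptr.
  auto r = c10::guts::filter_map<double, std::is_integral>(
      to_double_sketch(), 3, "skipped", 4, nullptr, 5);
  static_assert(
      std::is_same<c10::guts::array<double, 3>, decltype(r)>::value, "");
}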
- - CopyCounting source1; - CopyCounting source2; - auto result = filter_map(map_copy_counting_by_pointer(), "bla", nullptr, 3, source1, std::move(source2)); - static_assert(std::is_same, decltype(result)>::value, ""); - EXPECT_EQ(0, result[0]->copy_count); - EXPECT_EQ(0, result[0]->move_count); - EXPECT_EQ(0, result[1]->copy_count); - EXPECT_EQ(0, result[1]->move_count); - } +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, FilterMap) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto result = filter_map( + map_to_double(), 3, "bla", MyClass(), 4, nullptr, 5); + static_assert(std::is_same, decltype(result)>::value, ""); + constexpr array expected{{3.0, 4.0, 5.0}}; + EXPECT_EQ(expected, result); } +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, FilterMap_emptyInput) { + auto result = filter_map(map_to_double()); + static_assert(std::is_same, decltype(result)>::value, ""); + constexpr array expected{{}}; + EXPECT_EQ(expected, result); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, FilterMap_emptyOutput) { + auto result = filter_map( + map_to_double(), "bla", MyClass(), nullptr); + static_assert(std::is_same, decltype(result)>::value, ""); + constexpr array expected{{}}; + EXPECT_EQ(expected, result); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, FilterMap_movableOnly_byRValue) { + struct map_movable_by_rvalue { + MovableOnly operator()(MovableOnly&& a) const { + return std::move(a); + } + }; + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto result = filter_map( + map_movable_by_rvalue(), + MovableOnly(5), + "bla", + nullptr, + 3, + MovableOnly(2)); + static_assert( + std::is_same, decltype(result)>::value, ""); + constexpr array expected{{MovableOnly(5), MovableOnly(2)}}; + EXPECT_EQ(expected, result); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, FilterMap_movableOnly_byValue) { + struct map_movable_by_lvalue { + MovableOnly operator()(MovableOnly a) const { + return a; + } + }; + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto result = filter_map( + map_movable_by_lvalue(), + MovableOnly(5), + "bla", + nullptr, + 3, + MovableOnly(2)); + static_assert( + std::is_same, decltype(result)>::value, ""); + constexpr array expected{{MovableOnly(5), MovableOnly(2)}}; + EXPECT_EQ(expected, result); +} + +// See https://github.com/pytorch/pytorch/issues/35546 +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST( + MetaprogrammingTest, + DISABLED_ON_WINDOWS(FilterMap_onlyCopiesIfNecessary)) { + struct map_copy_counting_by_copy { + CopyCounting operator()(CopyCounting v) const { + return v; + } + }; + + CopyCounting source; + CopyCounting source2; + auto result = filter_map( + map_copy_counting_by_copy(), + CopyCounting(), + "bla", + nullptr, + 3, + source, + std::move(source2)); + static_assert( + std::is_same, decltype(result)>::value, ""); + EXPECT_EQ(0, result[0].copy_count); + EXPECT_EQ(2, result[0].move_count); + EXPECT_EQ(1, result[1].copy_count); + EXPECT_EQ(1, result[1].move_count); + EXPECT_EQ(0, result[2].copy_count); + EXPECT_EQ(2, result[2].move_count); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST( + MetaprogrammingTest, + DISABLED_ON_WINDOWS(FilterMap_onlyMovesIfNecessary_1)) { + struct map_copy_counting_by_move 
{ + CopyCounting operator()(CopyCounting&& v) const { + return std::move(v); + } + }; + + CopyCounting source; + auto result = filter_map( + map_copy_counting_by_move(), + CopyCounting(), + "bla", + nullptr, + 3, + std::move(source)); + static_assert( + std::is_same, decltype(result)>::value, ""); + EXPECT_EQ(0, result[0].copy_count); + EXPECT_EQ(1, result[0].move_count); + EXPECT_EQ(0, result[1].copy_count); + EXPECT_EQ(1, result[1].move_count); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, FilterMap_onlyMovesIfNecessary_2) { + struct map_copy_counting_by_pointer { + const CopyCounting* operator()(const CopyCounting& v) const { + return &v; + } + }; + + CopyCounting source1; + CopyCounting source2; + auto result = filter_map( + map_copy_counting_by_pointer(), + "bla", + nullptr, + 3, + source1, + std::move(source2)); + static_assert( + std::is_same, decltype(result)>::value, ""); + EXPECT_EQ(0, result[0]->copy_count); + EXPECT_EQ(0, result[0]->move_count); + EXPECT_EQ(0, result[1]->copy_count); + EXPECT_EQ(0, result[1]->move_count); +} +} // namespace test_filter_map + namespace test_tuple_elements { - // note: not testing empty selection, as some compilers will raise - // "parameter set but not used" in tuple_elements(). a good example - // of the friction that comes with using these tools +// note: not testing empty selection, as some compilers will raise +// "parameter set but not used" in tuple_elements(). a good example +// of the friction that comes with using these tools - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleElements_subsetSelection) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto x = std::make_tuple(0, "HEY", 2.0); - auto y = tuple_elements(x, std::index_sequence<0, 2>()); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto z = std::make_tuple(0, 2.0); - EXPECT_EQ(y, z); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleElements_reorderSelection) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto x = std::make_tuple(0, "HEY", 2.0); - auto y = tuple_elements(x, std::index_sequence<0, 2, 1>()); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto z = std::make_tuple(0, 2.0, "HEY"); - EXPECT_EQ(y, z); - } +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, TupleElements_subsetSelection) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto x = std::make_tuple(0, "HEY", 2.0); + auto y = tuple_elements(x, std::index_sequence<0, 2>()); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto z = std::make_tuple(0, 2.0); + EXPECT_EQ(y, z); } +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, TupleElements_reorderSelection) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto x = std::make_tuple(0, "HEY", 2.0); + auto y = tuple_elements(x, std::index_sequence<0, 2, 1>()); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto z = std::make_tuple(0, 2.0, "HEY"); + EXPECT_EQ(y, z); +} +} // namespace test_tuple_elements + namespace test_tuple_take { - // note: not testing empty prefix, see note on empty selection above. +// note: not testing empty prefix, see note on empty selection above. 
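// Illustrative sketch (outside the patch): tuple_elements picks (and can
// reorder) tuple members by index, as in the selection tests earlier in this
// file; tuple_take builds on it to take a prefix (or, with a negative count,
// a suffix), which is what the tests in this namespace cover.
#include <c10/util/Metaprogramming.h>
#include <tuple>
#include <type_traits>
#include <utility>

inline void tuple_elements_sketch() {
  auto x = std::make_tuple(0, "HEY", 2.0);
  // Select members 0 and 2, in that order.
  auto y = c10::guts::tuple_elements(x, std::index_sequence<0, 2>());
  static_assert(std::is_same<std::tuple<int, double>, decltype(y)>::value, "");
}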
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleTake_nonemptyPrefix) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto x = std::make_tuple(0, "HEY", 2.0); - auto y = tuple_take(x); - auto z = std::make_tuple(0, "HEY"); - EXPECT_EQ(y, z); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleTake_fullPrefix) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto x = std::make_tuple(0, "HEY", 2.0); - auto y = tuple_take(x); - EXPECT_EQ(x, y); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleTake_negative) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto x = std::make_tuple(0, "HEY", 2.0); - auto y = tuple_take(x); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto z = std::make_tuple("HEY", 2.0); - EXPECT_EQ(y, z); - } +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, TupleTake_nonemptyPrefix) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto x = std::make_tuple(0, "HEY", 2.0); + auto y = tuple_take(x); + auto z = std::make_tuple(0, "HEY"); + EXPECT_EQ(y, z); } +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, TupleTake_fullPrefix) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto x = std::make_tuple(0, "HEY", 2.0); + auto y = tuple_take(x); + EXPECT_EQ(x, y); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, TupleTake_negative) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto x = std::make_tuple(0, "HEY", 2.0); + auto y = tuple_take(x); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto z = std::make_tuple("HEY", 2.0); + EXPECT_EQ(y, z); +} +} // namespace test_tuple_take + namespace test_tuple_slice { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleSlice_middle) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto x = std::make_tuple(0, "HEY", 2.0, false); - auto y = tuple_slice(x); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto z = std::make_tuple("HEY", 2.0); - EXPECT_EQ(y, z); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleSlice_full) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto x = std::make_tuple(0, "HEY", 2.0); - auto y = tuple_slice(x); - EXPECT_EQ(x, y); - } +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, TupleSlice_middle) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto x = std::make_tuple(0, "HEY", 2.0, false); + auto y = tuple_slice(x); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto z = std::make_tuple("HEY", 2.0); + EXPECT_EQ(y, z); } +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, TupleSlice_full) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto x = std::make_tuple(0, "HEY", 2.0); + auto y = tuple_slice(x); + EXPECT_EQ(x, y); +} +} // namespace test_tuple_slice + namespace test_tuple_map { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleMap_simple) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto result = 
tuple_map(std::tuple(3, 4, 5), [] (int32_t a) -> int16_t {return a+1;}); - static_assert(std::is_same, decltype(result)>::value, ""); - EXPECT_EQ(4, std::get<0>(result)); - EXPECT_EQ(5, std::get<1>(result)); - EXPECT_EQ(6, std::get<2>(result)); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleMap_mapperTakesDifferentButConvertibleType) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto result = tuple_map(std::tuple(3, 4, 5), [] (int64_t a) -> int16_t {return a+1;}); - static_assert(std::is_same, decltype(result)>::value, ""); - EXPECT_EQ(4, std::get<0>(result)); - EXPECT_EQ(5, std::get<1>(result)); - EXPECT_EQ(6, std::get<2>(result)); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleMap_mapperTakesConstRef) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto result = tuple_map(std::tuple(3, 4, 5), [] (const int32_t& a) -> int16_t {return a+1;}); - static_assert(std::is_same, decltype(result)>::value, ""); - EXPECT_EQ(4, std::get<0>(result)); - EXPECT_EQ(5, std::get<1>(result)); - EXPECT_EQ(6, std::get<2>(result)); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleMap_mapsToDifferentTypes) { - struct Mapper { - std::string operator()(int32_t a) const { - return std::to_string(a); - } - int32_t operator()(const std::string& a) const { - return atoi(a.c_str()); - } - }; - auto result = tuple_map(std::tuple(3, "4"), Mapper()); - static_assert(std::is_same, decltype(result)>::value, ""); - EXPECT_EQ("3", std::get<0>(result)); - EXPECT_EQ(4, std::get<1>(result)); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleMap_differentiatesLRValueReferences) { - struct Mapper { - std::string operator()(std::string&& a) const { - return "moved"; - } - std::string operator()(const std::string& a) const { - return "copied"; - } - }; - std::string str1, str2; - auto result = tuple_map(std::tuple(str1, std::move(str2)), Mapper()); - static_assert(std::is_same, decltype(result)>::value, ""); - EXPECT_EQ("copied", std::get<0>(result)); - EXPECT_EQ("moved", std::get<1>(result)); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleMap_canWorkWithMovableOnlyType) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto result = tuple_map(std::tuple(MovableOnly(7)), [] (MovableOnly a) { return a; }); - static_assert(std::is_same, decltype(result)>::value, ""); - EXPECT_EQ(MovableOnly(7), std::get<0>(result)); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleMap_doesntUnecessarilyCopyValues) { - auto result = tuple_map(std::tuple(CopyCounting()), [] (CopyCounting a) { return a; }); - static_assert(std::is_same, decltype(result)>::value, ""); - EXPECT_EQ(4, std::get<0>(result).move_count); - EXPECT_EQ(0, std::get<0>(result).copy_count); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleMap_doesntUnecessarilyMoveValues) { - CopyCounting a; - auto result = tuple_map(std::tuple(std::move(a)), [] (CopyCounting&& a) -> CopyCounting&& { return std::move(a); }); - static_assert(std::is_same, decltype(result)>::value, ""); - EXPECT_EQ(&a, &std::get<0>(result)); - EXPECT_EQ(0, std::get<0>(result).move_count); - EXPECT_EQ(0, std::get<0>(result).copy_count); - } - - 
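// Illustrative sketch (outside the patch): tuple_map applies a callable to
// every tuple element and returns a tuple of the results. The element types
// of the input tuples above were stripped by text extraction; the spelling
// below is an assumption consistent with the asserted int16_t results.
#include <c10/util/Metaprogramming.h>
#include <cstdint>
#include <tuple>
#include <type_traits>

inline void tuple_map_sketch() {
  auto r = c10::guts::tuple_map(
      std::tuple<int32_t, int32_t>(3, 4),
      [](int32_t a) -> int16_t { return a + 1; });
  static_assert(
      std::is_same<std::tuple<int16_t, int16_t>, decltype(r)>::value, "");
  // r holds (4, 5).
}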
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleMap_canBeUsedWithAutoLambdas) { - struct A final { - int32_t func() { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - return 5; - } - }; - struct B final { - std::string func() { - return "5"; - } - }; - auto result = tuple_map(std::make_tuple(A(), B()), [] (auto a) { return a.func(); }); - static_assert(std::is_same, decltype(result)>::value, ""); - EXPECT_EQ(5, std::get<0>(result)); - EXPECT_EQ("5", std::get<1>(result)); - } +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, TupleMap_simple) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto result = tuple_map( + std::tuple(3, 4, 5), + [](int32_t a) -> int16_t { return a + 1; }); + static_assert( + std::is_same, decltype(result)>:: + value, + ""); + EXPECT_EQ(4, std::get<0>(result)); + EXPECT_EQ(5, std::get<1>(result)); + EXPECT_EQ(6, std::get<2>(result)); } +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, TupleMap_mapperTakesDifferentButConvertibleType) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto result = tuple_map( + std::tuple(3, 4, 5), + [](int64_t a) -> int16_t { return a + 1; }); + static_assert( + std::is_same, decltype(result)>:: + value, + ""); + EXPECT_EQ(4, std::get<0>(result)); + EXPECT_EQ(5, std::get<1>(result)); + EXPECT_EQ(6, std::get<2>(result)); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, TupleMap_mapperTakesConstRef) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto result = tuple_map( + std::tuple(3, 4, 5), + [](const int32_t& a) -> int16_t { return a + 1; }); + static_assert( + std::is_same, decltype(result)>:: + value, + ""); + EXPECT_EQ(4, std::get<0>(result)); + EXPECT_EQ(5, std::get<1>(result)); + EXPECT_EQ(6, std::get<2>(result)); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, TupleMap_mapsToDifferentTypes) { + struct Mapper { + std::string operator()(int32_t a) const { + return std::to_string(a); + } + int32_t operator()(const std::string& a) const { + return atoi(a.c_str()); + } + }; + auto result = tuple_map(std::tuple(3, "4"), Mapper()); + static_assert( + std::is_same, decltype(result)>::value, + ""); + EXPECT_EQ("3", std::get<0>(result)); + EXPECT_EQ(4, std::get<1>(result)); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, TupleMap_differentiatesLRValueReferences) { + struct Mapper { + std::string operator()(std::string&& a) const { + return "moved"; + } + std::string operator()(const std::string& a) const { + return "copied"; + } + }; + std::string str1, str2; + auto result = tuple_map( + std::tuple(str1, std::move(str2)), + Mapper()); + static_assert( + std::is_same, decltype(result)>:: + value, + ""); + EXPECT_EQ("copied", std::get<0>(result)); + EXPECT_EQ("moved", std::get<1>(result)); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, TupleMap_canWorkWithMovableOnlyType) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto result = tuple_map( + std::tuple(MovableOnly(7)), [](MovableOnly a) { return a; }); + static_assert( + std::is_same, decltype(result)>::value, ""); + EXPECT_EQ(MovableOnly(7), std::get<0>(result)); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) 
+TEST(MetaprogrammingTest, TupleMap_doesntUnecessarilyCopyValues) { + auto result = tuple_map( + std::tuple(CopyCounting()), + [](CopyCounting a) { return a; }); + static_assert( + std::is_same, decltype(result)>::value, ""); + EXPECT_EQ(4, std::get<0>(result).move_count); + EXPECT_EQ(0, std::get<0>(result).copy_count); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, TupleMap_doesntUnecessarilyMoveValues) { + CopyCounting a; + auto result = tuple_map( + std::tuple(std::move(a)), + [](CopyCounting&& a) -> CopyCounting&& { return std::move(a); }); + static_assert( + std::is_same, decltype(result)>::value, ""); + EXPECT_EQ(&a, &std::get<0>(result)); + EXPECT_EQ(0, std::get<0>(result).move_count); + EXPECT_EQ(0, std::get<0>(result).copy_count); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, TupleMap_canBeUsedWithAutoLambdas) { + struct A final { + int32_t func() { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + return 5; + } + }; + struct B final { + std::string func() { + return "5"; + } + }; + auto result = + tuple_map(std::make_tuple(A(), B()), [](auto a) { return a.func(); }); + static_assert( + std::is_same, decltype(result)>::value, + ""); + EXPECT_EQ(5, std::get<0>(result)); + EXPECT_EQ("5", std::get<1>(result)); +} +} // namespace test_tuple_map + namespace test_tuple_concat { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleConcat_zerotuples) { - auto result = tuple_concat(); - static_assert(std::is_same, decltype(result)>::value, ""); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleConcat_oneemptytuple) { - auto result = tuple_concat(std::tuple<>()); - static_assert(std::is_same, decltype(result)>::value, ""); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleConcat_onenonemptytuple) { - auto result = tuple_concat(std::tuple(3)); - static_assert(std::is_same, decltype(result)>::value, ""); - EXPECT_EQ(3, std::get<0>(result)); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleConcat_twotuples) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto result = tuple_concat(std::tuple(3, "4"), std::tuple(2.3, 15)); - static_assert(std::is_same, decltype(result)>::value, ""); - EXPECT_EQ(3, std::get<0>(result)); - EXPECT_EQ("4", std::get<1>(result)); - EXPECT_EQ(2.3, std::get<2>(result)); - EXPECT_EQ(15, std::get<3>(result)); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleConcat_threetuples) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto result = tuple_concat(std::tuple(3, "4"), std::tuple(2.3, 15), std::tuple("5", 3.2)); - static_assert(std::is_same, decltype(result)>::value, ""); - EXPECT_EQ(3, std::get<0>(result)); - EXPECT_EQ("4", std::get<1>(result)); - EXPECT_EQ(2.3, std::get<2>(result)); - EXPECT_EQ(15, std::get<3>(result)); - EXPECT_EQ("5", std::get<4>(result)); - EXPECT_EQ(static_cast(3.2), std::get<5>(result)); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleConcat_emptytupleatbeginning) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto result = tuple_concat(std::tuple<>(), std::tuple(2.3, 15), std::tuple("5", 3.2)); - static_assert(std::is_same, 
decltype(result)>::value, ""); - EXPECT_EQ(2.3, std::get<0>(result)); - EXPECT_EQ(15, std::get<1>(result)); - EXPECT_EQ("5", std::get<2>(result)); - EXPECT_EQ(static_cast(3.2), std::get<3>(result)); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleConcat_emptytupleinmiddle) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto result = tuple_concat(std::tuple(2.3, 15), std::tuple<>(), std::tuple("5", 3.2)); - static_assert(std::is_same, decltype(result)>::value, ""); - EXPECT_EQ(2.3, std::get<0>(result)); - EXPECT_EQ(15, std::get<1>(result)); - EXPECT_EQ("5", std::get<2>(result)); - EXPECT_EQ(static_cast(3.2), std::get<3>(result)); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleConcat_emptytupleatend) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto result = tuple_concat(std::tuple(2.3, 15), std::tuple("5", 3.2), std::tuple<>()); - static_assert(std::is_same, decltype(result)>::value, ""); - EXPECT_EQ(2.3, std::get<0>(result)); - EXPECT_EQ(15, std::get<1>(result)); - EXPECT_EQ("5", std::get<2>(result)); - EXPECT_EQ(static_cast(3.2), std::get<3>(result)); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleConcat_workswithreferencesandpointers) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - double val1 = 2.3; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - int16_t val2 = 15; - std::string val3 = "hello"; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - float val4 = 3.2; - auto result = tuple_concat(std::tuple(val1, val2), std::tuple(std::move(val3), &val4)); - static_assert(std::is_same, decltype(result)>::value, ""); - EXPECT_EQ(2.3, std::get<0>(result)); - EXPECT_EQ(&val1, &std::get<0>(result)); - EXPECT_EQ(15, std::get<1>(result)); - EXPECT_EQ(&val2, &std::get<1>(result)); - EXPECT_EQ("hello", std::get<2>(result)); - EXPECT_EQ(&val3, &std::get<2>(result)); - EXPECT_EQ(static_cast(3.2), *std::get<3>(result)); - EXPECT_EQ(&val4, std::get<3>(result)); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleConcat_worksWithMovableOnlyTypes) { - auto result = tuple_concat(std::tuple(1, 2), std::tuple(3)); - static_assert(std::is_same, decltype(result)>::value, ""); - EXPECT_EQ(MovableOnly(1), std::get<0>(result)); - EXPECT_EQ(MovableOnly(2), std::get<1>(result)); - EXPECT_EQ(MovableOnly(3), std::get<2>(result)); - } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(MetaprogrammingTest, TupleConcat_doesntCopyMoreThanNecessary) { - auto result = tuple_concat(std::tuple(CopyCounting(), CopyCounting()), std::tuple(CopyCounting()), std::tuple(CopyCounting())); - static_assert(std::is_same, decltype(result)>::value, ""); - EXPECT_EQ(0, std::get<0>(result).copy_count); - EXPECT_EQ(0, std::get<1>(result).copy_count); - EXPECT_EQ(0, std::get<2>(result).copy_count); - EXPECT_EQ(0, std::get<3>(result).copy_count); - EXPECT_EQ(2, std::get<0>(result).move_count); - EXPECT_EQ(2, std::get<1>(result).move_count); - EXPECT_EQ(2, std::get<2>(result).move_count); - EXPECT_EQ(2, std::get<3>(result).move_count); - } +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, TupleConcat_zerotuples) { + auto result = tuple_concat(); + static_assert(std::is_same, decltype(result)>::value, ""); } +// 
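// Illustrative sketch (outside the patch): tuple_concat joins several tuples
// into one, preserving element types (including references, as the tests
// below check). The element types of the tuples in the surrounding tests were
// stripped by text extraction; the ones used here are an assumption matching
// the asserted values.
#include <c10/util/Metaprogramming.h>
#include <string>
#include <tuple>
#include <type_traits>

inline void tuple_concat_sketch() {
  auto r = c10::guts::tuple_concat(
      std::tuple<int, std::string>(3, "4"), std::tuple<double>(2.3));
  static_assert(
      std::is_same<std::tuple<int, std::string, double>, decltype(r)>::value,
      "");
  // r holds (3, "4", 2.3).
}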
NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, TupleConcat_oneemptytuple) { + auto result = tuple_concat(std::tuple<>()); + static_assert(std::is_same, decltype(result)>::value, ""); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, TupleConcat_onenonemptytuple) { + auto result = tuple_concat(std::tuple(3)); + static_assert(std::is_same, decltype(result)>::value, ""); + EXPECT_EQ(3, std::get<0>(result)); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, TupleConcat_twotuples) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto result = tuple_concat( + std::tuple(3, "4"), + std::tuple(2.3, 15)); + static_assert( + std::is_same< + std::tuple, + decltype(result)>::value, + ""); + EXPECT_EQ(3, std::get<0>(result)); + EXPECT_EQ("4", std::get<1>(result)); + EXPECT_EQ(2.3, std::get<2>(result)); + EXPECT_EQ(15, std::get<3>(result)); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, TupleConcat_threetuples) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto result = tuple_concat( + std::tuple(3, "4"), + std::tuple(2.3, 15), + std::tuple("5", 3.2)); + static_assert( + std::is_same< + std::tuple, + decltype(result)>::value, + ""); + EXPECT_EQ(3, std::get<0>(result)); + EXPECT_EQ("4", std::get<1>(result)); + EXPECT_EQ(2.3, std::get<2>(result)); + EXPECT_EQ(15, std::get<3>(result)); + EXPECT_EQ("5", std::get<4>(result)); + EXPECT_EQ(static_cast(3.2), std::get<5>(result)); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, TupleConcat_emptytupleatbeginning) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto result = tuple_concat( + std::tuple<>(), + std::tuple(2.3, 15), + std::tuple("5", 3.2)); + static_assert( + std::is_same< + std::tuple, + decltype(result)>::value, + ""); + EXPECT_EQ(2.3, std::get<0>(result)); + EXPECT_EQ(15, std::get<1>(result)); + EXPECT_EQ("5", std::get<2>(result)); + EXPECT_EQ(static_cast(3.2), std::get<3>(result)); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, TupleConcat_emptytupleinmiddle) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto result = tuple_concat( + std::tuple(2.3, 15), + std::tuple<>(), + std::tuple("5", 3.2)); + static_assert( + std::is_same< + std::tuple, + decltype(result)>::value, + ""); + EXPECT_EQ(2.3, std::get<0>(result)); + EXPECT_EQ(15, std::get<1>(result)); + EXPECT_EQ("5", std::get<2>(result)); + EXPECT_EQ(static_cast(3.2), std::get<3>(result)); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, TupleConcat_emptytupleatend) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto result = tuple_concat( + std::tuple(2.3, 15), + std::tuple("5", 3.2), + std::tuple<>()); + static_assert( + std::is_same< + std::tuple, + decltype(result)>::value, + ""); + EXPECT_EQ(2.3, std::get<0>(result)); + EXPECT_EQ(15, std::get<1>(result)); + EXPECT_EQ("5", std::get<2>(result)); + EXPECT_EQ(static_cast(3.2), std::get<3>(result)); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, TupleConcat_workswithreferencesandpointers) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + double val1 = 2.3; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + int16_t val2 = 
15; + std::string val3 = "hello"; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + float val4 = 3.2; + auto result = tuple_concat( + std::tuple(val1, val2), + std::tuple(std::move(val3), &val4)); + static_assert( + std::is_same< + std::tuple, + decltype(result)>::value, + ""); + EXPECT_EQ(2.3, std::get<0>(result)); + EXPECT_EQ(&val1, &std::get<0>(result)); + EXPECT_EQ(15, std::get<1>(result)); + EXPECT_EQ(&val2, &std::get<1>(result)); + EXPECT_EQ("hello", std::get<2>(result)); + EXPECT_EQ(&val3, &std::get<2>(result)); + EXPECT_EQ(static_cast(3.2), *std::get<3>(result)); + EXPECT_EQ(&val4, std::get<3>(result)); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, TupleConcat_worksWithMovableOnlyTypes) { + auto result = tuple_concat( + std::tuple(1, 2), std::tuple(3)); + static_assert( + std::is_same< + std::tuple, + decltype(result)>::value, + ""); + EXPECT_EQ(MovableOnly(1), std::get<0>(result)); + EXPECT_EQ(MovableOnly(2), std::get<1>(result)); + EXPECT_EQ(MovableOnly(3), std::get<2>(result)); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MetaprogrammingTest, TupleConcat_doesntCopyMoreThanNecessary) { + auto result = tuple_concat( + std::tuple(CopyCounting(), CopyCounting()), + std::tuple(CopyCounting()), + std::tuple(CopyCounting())); + static_assert( + std::is_same< + std::tuple, + decltype(result)>::value, + ""); + EXPECT_EQ(0, std::get<0>(result).copy_count); + EXPECT_EQ(0, std::get<1>(result).copy_count); + EXPECT_EQ(0, std::get<2>(result).copy_count); + EXPECT_EQ(0, std::get<3>(result).copy_count); + EXPECT_EQ(2, std::get<0>(result).move_count); + EXPECT_EQ(2, std::get<1>(result).move_count); + EXPECT_EQ(2, std::get<2>(result).move_count); + EXPECT_EQ(2, std::get<3>(result).move_count); +} +} // namespace test_tuple_concat + namespace test_concat_iseq { - using std::index_sequence; - using std::integer_sequence; - static_assert(std::is_same, concat_iseq_t<>>::value, ""); - static_assert(std::is_same, concat_iseq_t>>::value, ""); - static_assert(std::is_same, concat_iseq_t, index_sequence<>>>::value, ""); - static_assert(std::is_same, concat_iseq_t>>::value, ""); - static_assert(std::is_same, concat_iseq_t, index_sequence<>>>::value, ""); - static_assert(std::is_same, concat_iseq_t, index_sequence<4>>>::value, ""); - static_assert(std::is_same, concat_iseq_t, index_sequence<4>, index_sequence<>>>::value, ""); - static_assert(std::is_same, concat_iseq_t, index_sequence<2>>>::value, ""); - static_assert(std::is_same, concat_iseq_t, index_sequence<4, 2>, index_sequence<>>>::value, ""); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - static_assert(std::is_same, concat_iseq_t, index_sequence<4, 2>, index_sequence<9>>>::value, ""); +using std::index_sequence; +using std::integer_sequence; +static_assert(std::is_same, concat_iseq_t<>>::value, ""); +static_assert( + std::is_same, concat_iseq_t>>::value, + ""); +static_assert( + std::is_same< + index_sequence<>, + concat_iseq_t, index_sequence<>>>::value, + ""); +static_assert( + std::is_same, concat_iseq_t>>::value, + ""); +static_assert( + std::is_same< + index_sequence<4>, + concat_iseq_t, index_sequence<>>>::value, + ""); +static_assert( + std::is_same< + index_sequence<4>, + concat_iseq_t, index_sequence<4>>>::value, + ""); +static_assert( + std::is_same< + index_sequence<4>, + concat_iseq_t, index_sequence<4>, index_sequence<>>>:: + value, + ""); +static_assert( + std::is_same< + index_sequence<4, 2>, + concat_iseq_t, 
index_sequence<2>>>::value, + ""); +static_assert( + std::is_same< + index_sequence<4, 2>, + concat_iseq_t< + index_sequence<>, + index_sequence<4, 2>, + index_sequence<>>>::value, + ""); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) +static_assert( + std::is_same< + index_sequence<4, 2, 9>, + concat_iseq_t< + index_sequence<>, + index_sequence<4, 2>, + index_sequence<9>>>::value, + ""); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - static_assert(std::is_same, concat_iseq_t, integer_sequence>>::value, ""); -} +// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) +static_assert( + std::is_same< + integer_sequence, + concat_iseq_t< + integer_sequence, + integer_sequence>>::value, + ""); +} // namespace test_concat_iseq - -} +} // namespace diff --git a/c10/test/util/TypeIndex_test.cpp b/c10/test/util/TypeIndex_test.cpp index 085be909113..2f0c757a8a3 100644 --- a/c10/test/util/TypeIndex_test.cpp +++ b/c10/test/util/TypeIndex_test.cpp @@ -61,12 +61,10 @@ static_assert( #endif // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(TypeIndex, TopLevelName) { - EXPECT_NE( - string_view::npos, - get_fully_qualified_type_name().find("Dummy") - ); -} + EXPECT_NE( + string_view::npos, get_fully_qualified_type_name().find("Dummy")); } +} // namespace test_top_level_name namespace test_nested_name { struct Dummy final {}; @@ -79,10 +77,9 @@ static_assert( #endif // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(TypeIndex, NestedName) { - EXPECT_NE( - string_view::npos, - get_fully_qualified_type_name().find("test_nested_name::Dummy") - ); + EXPECT_NE( + string_view::npos, + get_fully_qualified_type_name().find("test_nested_name::Dummy")); } } // namespace test_nested_name @@ -105,16 +102,14 @@ static_assert( #endif // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(TypeIndex, TypeTemplateParameter) { - EXPECT_NE( - string_view::npos, - get_fully_qualified_type_name>().find( - "test_type_template_parameter::Outer") - ); - EXPECT_NE( - string_view::npos, - get_fully_qualified_type_name>().find( - "test_type_template_parameter::Inner") - ); + EXPECT_NE( + string_view::npos, + get_fully_qualified_type_name>().find( + "test_type_template_parameter::Outer")); + EXPECT_NE( + string_view::npos, + get_fully_qualified_type_name>().find( + "test_type_template_parameter::Inner")); } } // namespace test_type_template_parameter @@ -131,10 +126,9 @@ static_assert( #endif // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(TypeIndex, NonTypeTemplateParameter) { - EXPECT_NE( - string_view::npos, - get_fully_qualified_type_name>().find("38474355") - ); + EXPECT_NE( + string_view::npos, + get_fully_qualified_type_name>().find("38474355")); } } // namespace test_nontype_template_parameter @@ -164,21 +158,18 @@ static_assert( #endif // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(TypeIndex, TypeComputationsAreResolved) { - EXPECT_NE( - string_view::npos, - get_fully_qualified_type_name::type>().find("int") - ); - EXPECT_NE( - string_view::npos, - get_fully_qualified_type_name::type>().find("*") - ); - // but with remove_pointer applied, there is no '*' in the type name anymore - EXPECT_EQ( - string_view::npos, - get_fully_qualified_type_name< - typename std::remove_pointer::type>::type>() - .find("*") - ); + EXPECT_NE( + string_view::npos, + get_fully_qualified_type_name::type>().find("int")); + EXPECT_NE( + string_view::npos, + 
get_fully_qualified_type_name::type>().find("*")); + // but with remove_pointer applied, there is no '*' in the type name anymore + EXPECT_EQ( + string_view::npos, + get_fully_qualified_type_name< + typename std::remove_pointer::type>::type>() + .find("*")); } struct Functor final { @@ -193,11 +184,10 @@ static_assert( #endif // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(TypeIndex, FunctionTypeComputationsAreResolved) { - EXPECT_EQ( - get_fully_qualified_type_name&)>(), - get_fully_qualified_type_name< - typename c10::guts::infer_function_traits_t::func_type>() - ); + EXPECT_EQ( + get_fully_qualified_type_name&)>(), + get_fully_qualified_type_name< + typename c10::guts::infer_function_traits_t::func_type>()); } } // namespace test_type_computations_are_resolved @@ -218,16 +208,14 @@ static_assert( #endif // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(TypeIndex, FunctionArgumentsAndReturns) { - EXPECT_NE( - string_view::npos, - get_fully_qualified_type_name().find( - "test_function_arguments_and_returns::Dummy") - ); - EXPECT_NE( - string_view::npos, - get_fully_qualified_type_name().find( - "test_function_arguments_and_returns::Dummy") - ); + EXPECT_NE( + string_view::npos, + get_fully_qualified_type_name().find( + "test_function_arguments_and_returns::Dummy")); + EXPECT_NE( + string_view::npos, + get_fully_qualified_type_name().find( + "test_function_arguments_and_returns::Dummy")); } } // namespace test_function_arguments_and_returns } // namespace diff --git a/c10/test/util/TypeList_test.cpp b/c10/test/util/TypeList_test.cpp index f1468e824bf..0b38aeb3862 100644 --- a/c10/test/util/TypeList_test.cpp +++ b/c10/test/util/TypeList_test.cpp @@ -5,204 +5,382 @@ using namespace c10::guts::typelist; namespace test_size { - class MyClass {}; - static_assert(0 == size>::value, ""); - static_assert(1 == size>::value, ""); - static_assert(3 == size>::value, ""); -} +class MyClass {}; +static_assert(0 == size>::value, ""); +static_assert(1 == size>::value, ""); +static_assert(3 == size>::value, ""); +} // namespace test_size namespace test_from_tuple { - class MyClass {}; - static_assert(std::is_same, from_tuple_t>>::value, ""); - static_assert(std::is_same, from_tuple_t>>::value, ""); -} +class MyClass {}; +static_assert( + std::is_same< + typelist, + from_tuple_t>>::value, + ""); +static_assert(std::is_same, from_tuple_t>>::value, ""); +} // namespace test_from_tuple namespace test_to_tuple { - class MyClass {}; - static_assert(std::is_same, to_tuple_t>>::value, ""); - static_assert(std::is_same, to_tuple_t>>::value, ""); -} +class MyClass {}; +static_assert( + std::is_same< + std::tuple, + to_tuple_t>>::value, + ""); +static_assert(std::is_same, to_tuple_t>>::value, ""); +} // namespace test_to_tuple namespace test_concat { - class MyClass {}; - static_assert(std::is_same, concat_t<>>::value, ""); - static_assert(std::is_same, concat_t>>::value, ""); - static_assert(std::is_same, concat_t, typelist<>>>::value, ""); - static_assert(std::is_same, concat_t>>::value, ""); - static_assert(std::is_same, concat_t, typelist<>>>::value, ""); - static_assert(std::is_same, concat_t, typelist>>::value, ""); - static_assert(std::is_same, concat_t, typelist, typelist<>>>::value, ""); - static_assert(std::is_same, concat_t, typelist>>::value, ""); - static_assert(std::is_same, concat_t, typelist, typelist<>>>::value, ""); - static_assert(std::is_same, concat_t, typelist, typelist>>::value, ""); -} +class MyClass {}; +static_assert(std::is_same, 
concat_t<>>::value, ""); +static_assert(std::is_same, concat_t>>::value, ""); +static_assert( + std::is_same, concat_t, typelist<>>>::value, + ""); +static_assert(std::is_same, concat_t>>::value, ""); +static_assert( + std::is_same, concat_t, typelist<>>>::value, + ""); +static_assert( + std::is_same, concat_t, typelist>>::value, + ""); +static_assert( + std::is_same< + typelist, + concat_t, typelist, typelist<>>>::value, + ""); +static_assert( + std::is_same< + typelist, + concat_t, typelist>>::value, + ""); +static_assert( + std::is_same< + typelist, + concat_t, typelist, typelist<>>>::value, + ""); +static_assert( + std::is_same< + typelist, + concat_t< + typelist<>, + typelist, + typelist>>::value, + ""); +} // namespace test_concat namespace test_filter { - class MyClass {}; - static_assert(std::is_same, filter_t>>::value, ""); - static_assert(std::is_same, filter_t>>::value, ""); - static_assert(std::is_same, filter_t>>::value, ""); -} +class MyClass {}; +static_assert( + std::is_same, filter_t>>::value, + ""); +static_assert( + std::is_same< + typelist<>, + filter_t>>:: + value, + ""); +static_assert( + std::is_same< + typelist, + filter_t< + std::is_reference, + typelist>>::value, + ""); +} // namespace test_filter namespace test_count_if { - class MyClass final {}; - static_assert(count_if>::value == 2, ""); - static_assert(count_if>::value == 0, ""); - static_assert(count_if>::value == 0, ""); -} +class MyClass final {}; +static_assert( + count_if< + std::is_reference, + typelist>::value == 2, + ""); +static_assert(count_if>::value == 0, ""); +static_assert(count_if>::value == 0, ""); +} // namespace test_count_if namespace test_true_for_each_type { - template class Test; - class MyClass {}; - static_assert(all>::value, ""); - static_assert(!all>::value, ""); - static_assert(all>::value, ""); -} +template +class Test; +class MyClass {}; +static_assert( + all>::value, + ""); +static_assert( + !all>::value, + ""); +static_assert(all>::value, ""); +} // namespace test_true_for_each_type namespace test_true_for_any_type { - template class Test; - class MyClass {}; - static_assert(true_for_any_type>::value, ""); - static_assert(true_for_any_type>::value, ""); - static_assert(!true_for_any_type>::value, ""); - static_assert(!true_for_any_type>::value, ""); -} +template +class Test; +class MyClass {}; +static_assert( + true_for_any_type< + std::is_reference, + typelist>::value, + ""); +static_assert( + true_for_any_type< + std::is_reference, + typelist>::value, + ""); +static_assert( + !true_for_any_type< + std::is_reference, + typelist>::value, + ""); +static_assert(!true_for_any_type>::value, ""); +} // namespace test_true_for_any_type namespace test_map { - class MyClass {}; - static_assert(std::is_same, map_t>>::value, ""); - static_assert(std::is_same, map_t>>::value, ""); - static_assert(std::is_same, map_t>>::value, ""); -} +class MyClass {}; +static_assert( + std::is_same, map_t>>:: + value, + ""); +static_assert( + std::is_same< + typelist, + map_t>>::value, + ""); +static_assert( + std::is_same< + typelist, + map_t< + std::add_lvalue_reference_t, + typelist>>::value, + ""); +} // namespace test_map namespace test_head { - class MyClass {}; - static_assert(std::is_same>>::value, ""); - static_assert(std::is_same>>::value, ""); - static_assert(std::is_same>>::value, ""); - static_assert(std::is_same>>::value, ""); -} +class MyClass {}; +static_assert(std::is_same>>::value, ""); +static_assert( + std::is_same>>:: + value, + ""); +static_assert( + std::is_same>>::value, + ""); 
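// For reference, the typelist checks in this hunk exercise metafunctions like
// size<> and head_t<>. A minimal self-contained sketch of that style of
// typelist (names in the `sketch` namespace are illustrative assumptions, not
// the actual c10::guts::typelist implementation):
#include <cstddef>
#include <type_traits>

namespace sketch {
template <class... Ts>
struct typelist {};

// size: number of types in the list.
template <class List>
struct size;
template <class... Ts>
struct size<typelist<Ts...>>
    : std::integral_constant<std::size_t, sizeof...(Ts)> {};

// head_t: the first element of a non-empty list.
template <class List>
struct head;
template <class Head, class... Tail>
struct head<typelist<Head, Tail...>> {
  using type = Head;
};
template <class List>
using head_t = typename head<List>::type;

static_assert(size<typelist<int, float>>::value == 2, "");
static_assert(std::is_same<int, head_t<typelist<int, float>>>::value, "");
} // namespace sketch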
+static_assert(std::is_same>>::value, ""); +} // namespace test_head namespace test_head_with_default { - class MyClass {}; - static_assert(std::is_same>>::value, ""); - static_assert(std::is_same>>::value, ""); - static_assert(std::is_same>>::value, ""); - static_assert(std::is_same>>::value, ""); - static_assert(std::is_same>>::value, ""); -} +class MyClass {}; +static_assert( + std::is_same>>::value, + ""); +static_assert( + std::is_same< + const MyClass&, + head_with_default_t>>::value, + ""); +static_assert( + std::is_same< + MyClass&&, + head_with_default_t>>::value, + ""); +static_assert( + std::is_same>>::value, + ""); +static_assert( + std::is_same>>::value, + ""); +} // namespace test_head_with_default namespace test_reverse { - class MyClass {}; - static_assert(std::is_same< - typelist, - reverse_t> - >::value, ""); - static_assert(std::is_same< - typelist<>, - reverse_t> - >::value, ""); -} +class MyClass {}; +static_assert( + std::is_same< + typelist, + reverse_t>>::value, + ""); +static_assert(std::is_same, reverse_t>>::value, ""); +} // namespace test_reverse namespace test_map_types_to_values { - struct map_to_size { - template constexpr size_t operator()(T) const {return sizeof(typename T::type);} - }; +struct map_to_size { + template + constexpr size_t operator()(T) const { + return sizeof(typename T::type); + } +}; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(TypeListTest, MapTypesToValues_sametype) { - auto sizes = - map_types_to_values>(map_to_size()); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - std::tuple expected(8, 1, 4); - static_assert(std::is_same::value, ""); - EXPECT_EQ(expected, sizes); - } - - struct map_make_shared { - template std::shared_ptr operator()(T) { - return std::make_shared(); - } - }; - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(TypeListTest, MapTypesToValues_differenttypes) { - auto shared_ptrs = - map_types_to_values>(map_make_shared()); - static_assert(std::is_same, std::shared_ptr>, decltype(shared_ptrs)>::value, ""); - } - - struct Class1 {static int func() {return 3;}}; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - struct Class2 {static double func() {return 2.0;}}; - - struct mapper_call_func { - template decltype(auto) operator()(T) { return T::type::func(); } - }; - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(TypeListTest, MapTypesToValues_members) { - auto result = - map_types_to_values>(mapper_call_func()); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - std::tuple expected(3, 2.0); - static_assert(std::is_same::value, ""); - EXPECT_EQ(expected, result); - } - - struct mapper_call_nonexistent_function { - template decltype(auto) operator()(T) { return T::type::this_doesnt_exist(); } - }; - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(TypeListTest, MapTypesToValues_empty) { - auto result = - map_types_to_values>(mapper_call_nonexistent_function()); - std::tuple<> expected; - static_assert(std::is_same::value, ""); - EXPECT_EQ(expected, result); - } +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(TypeListTest, MapTypesToValues_sametype) { + auto sizes = + map_types_to_values>(map_to_size()); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + std::tuple expected(8, 1, 4); + static_assert(std::is_same::value, ""); + EXPECT_EQ(expected, sizes); } +struct map_make_shared { + template + std::shared_ptr operator()(T) { + 
return std::make_shared(); + } +}; + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(TypeListTest, MapTypesToValues_differenttypes) { + auto shared_ptrs = + map_types_to_values>(map_make_shared()); + static_assert( + std::is_same< + std::tuple, std::shared_ptr>, + decltype(shared_ptrs)>::value, + ""); +} + +struct Class1 { + static int func() { + return 3; + } +}; +// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) +struct Class2 { + static double func() { + return 2.0; + } +}; + +struct mapper_call_func { + template + decltype(auto) operator()(T) { + return T::type::func(); + } +}; + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(TypeListTest, MapTypesToValues_members) { + auto result = + map_types_to_values>(mapper_call_func()); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + std::tuple expected(3, 2.0); + static_assert(std::is_same::value, ""); + EXPECT_EQ(expected, result); +} + +struct mapper_call_nonexistent_function { + template + decltype(auto) operator()(T) { + return T::type::this_doesnt_exist(); + } +}; + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(TypeListTest, MapTypesToValues_empty) { + auto result = + map_types_to_values>(mapper_call_nonexistent_function()); + std::tuple<> expected; + static_assert(std::is_same::value, ""); + EXPECT_EQ(expected, result); +} +} // namespace test_map_types_to_values + namespace test_find_if { - static_assert(0 == find_if, std::is_reference>::value, ""); - static_assert(0 == find_if, std::is_reference>::value, ""); - static_assert(2 == find_if, std::is_reference>::value, ""); - static_assert(3 == find_if, std::is_reference>::value, ""); -} +static_assert(0 == find_if, std::is_reference>::value, ""); +static_assert( + 0 == find_if, std::is_reference>::value, + ""); +static_assert( + 2 == find_if, std::is_reference>::value, + ""); +static_assert( + 3 == find_if, std::is_reference>::value, + ""); +} // namespace test_find_if namespace test_contains { - static_assert(contains, double>::value, ""); - static_assert(contains, double>::value, ""); - static_assert(!contains, float>::value, ""); - static_assert(!contains, double>::value, ""); -} +static_assert(contains, double>::value, ""); +static_assert(contains, double>::value, ""); +static_assert(!contains, float>::value, ""); +static_assert(!contains, double>::value, ""); +} // namespace test_contains namespace test_take { - static_assert(std::is_same, take_t, 0>>::value, ""); - static_assert(std::is_same, take_t, 0>>::value, ""); - static_assert(std::is_same, take_t, 1>>::value, ""); - static_assert(std::is_same, take_t, 0>>::value, ""); - static_assert(std::is_same, take_t, 1>>::value, ""); - static_assert(std::is_same, take_t, 2>>::value, ""); -} +static_assert(std::is_same, take_t, 0>>::value, ""); +static_assert( + std::is_same, take_t, 0>>::value, + ""); +static_assert( + std::is_same, take_t, 1>>::value, + ""); +static_assert( + std::is_same, take_t, 0>>::value, + ""); +static_assert( + std::is_same, take_t, 1>>:: + value, + ""); +static_assert( + std::is_same< + typelist, + take_t, 2>>::value, + ""); +} // namespace test_take namespace test_drop { - static_assert(std::is_same, drop_t, 0>>::value, ""); - static_assert(std::is_same, drop_t, 0>>::value, ""); - static_assert(std::is_same, drop_t, 1>>::value, ""); - static_assert(std::is_same, drop_t, 0>>::value, ""); - static_assert(std::is_same, drop_t, 1>>::value, ""); - static_assert(std::is_same, drop_t, 2>>::value, ""); -} 
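// For reference, the find_if/contains/take/drop assertions nearby all lean on
// the same recursive-partial-specialization pattern. A minimal self-contained
// sketch of a contains-style query in that pattern (illustrative only, not
// the actual c10::guts::typelist implementation):
#include <type_traits>

namespace sketch {
template <class... Ts>
struct typelist {};

// Primary template: an empty (or non-matching) list contains nothing.
template <class List, class T>
struct contains : std::false_type {};

// Recursive case: match the head, otherwise keep searching the tail.
template <class T, class Head, class... Tail>
struct contains<typelist<Head, Tail...>, T>
    : std::conditional<
          std::is_same<T, Head>::value,
          std::true_type,
          contains<typelist<Tail...>, T>>::type {};

static_assert(contains<typelist<int, double>, double>::value, "");
static_assert(!contains<typelist<int, double>, float>::value, "");
} // namespace sketch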
+static_assert(std::is_same, drop_t, 0>>::value, ""); +static_assert( + std::is_same, drop_t, 0>>::value, + ""); +static_assert( + std::is_same, drop_t, 1>>::value, + ""); +static_assert( + std::is_same< + typelist, + drop_t, 0>>::value, + ""); +static_assert( + std::is_same, drop_t, 1>>:: + value, + ""); +static_assert( + std::is_same, drop_t, 2>>::value, + ""); +} // namespace test_drop namespace test_drop_if_nonempty { - static_assert(std::is_same, drop_if_nonempty_t, 0>>::value, ""); - static_assert(std::is_same, drop_if_nonempty_t, 0>>::value, ""); - static_assert(std::is_same, drop_if_nonempty_t, 1>>::value, ""); - static_assert(std::is_same, drop_if_nonempty_t, 0>>::value, ""); - static_assert(std::is_same, drop_if_nonempty_t, 1>>::value, ""); - static_assert(std::is_same, drop_if_nonempty_t, 2>>::value, ""); - static_assert(std::is_same, drop_if_nonempty_t, 1>>::value, ""); - static_assert(std::is_same, drop_if_nonempty_t, 3>>::value, ""); -} +static_assert( + std::is_same, drop_if_nonempty_t, 0>>::value, + ""); +static_assert( + std::is_same, drop_if_nonempty_t, 0>>:: + value, + ""); +static_assert( + std::is_same, drop_if_nonempty_t, 1>>::value, + ""); +static_assert( + std::is_same< + typelist, + drop_if_nonempty_t, 0>>::value, + ""); +static_assert( + std::is_same< + typelist, + drop_if_nonempty_t, 1>>::value, + ""); +static_assert( + std::is_same< + typelist<>, + drop_if_nonempty_t, 2>>::value, + ""); +static_assert( + std::is_same, drop_if_nonempty_t, 1>>::value, + ""); +static_assert( + std::is_same< + typelist<>, + drop_if_nonempty_t, 3>>::value, + ""); +} // namespace test_drop_if_nonempty diff --git a/c10/test/util/TypeTraits_test.cpp b/c10/test/util/TypeTraits_test.cpp index 580f6cacfd6..356b4758564 100644 --- a/c10/test/util/TypeTraits_test.cpp +++ b/c10/test/util/TypeTraits_test.cpp @@ -6,152 +6,190 @@ using namespace c10::guts; namespace { namespace test_is_equality_comparable { - class NotEqualityComparable {}; - class EqualityComparable {}; +class NotEqualityComparable {}; +class EqualityComparable {}; - inline bool operator==(const EqualityComparable &, const EqualityComparable &) { return false; } - - static_assert(!is_equality_comparable::value, ""); - static_assert(is_equality_comparable::value, ""); - static_assert(is_equality_comparable::value, ""); - - // v_ just exists to silence a compiler warning about operator==(EqualityComparable, EqualityComparable) not being needed - const bool v_ = EqualityComparable() == EqualityComparable(); +inline bool operator==(const EqualityComparable&, const EqualityComparable&) { + return false; } +static_assert(!is_equality_comparable::value, ""); +static_assert(is_equality_comparable::value, ""); +static_assert(is_equality_comparable::value, ""); + +// v_ just exists to silence a compiler warning about +// operator==(EqualityComparable, EqualityComparable) not being needed +const bool v_ = EqualityComparable() == EqualityComparable(); +} // namespace test_is_equality_comparable + namespace test_is_hashable { - class NotHashable {}; - class Hashable {}; -} -} +class NotHashable {}; +class Hashable {}; +} // namespace test_is_hashable +} // namespace namespace std { - template<> struct hash final { - size_t operator()(const test_is_hashable::Hashable &) { return 0; } - }; -} +template <> +struct hash final { + size_t operator()(const test_is_hashable::Hashable&) { + return 0; + } +}; +} // namespace std namespace { namespace test_is_hashable { - static_assert(is_hashable::value, ""); - static_assert(is_hashable::value, ""); - 
static_assert(!is_hashable::value, ""); -} +static_assert(is_hashable::value, ""); +static_assert(is_hashable::value, ""); +static_assert(!is_hashable::value, ""); +} // namespace test_is_hashable namespace test_is_function_type { - class MyClass {}; - struct Functor { - void operator()() {} - }; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - auto lambda = [] () {}; - // func() and func__ just exists to silence a compiler warning about lambda being unused - bool func() { - lambda(); - return true; - } - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - bool func__ = func(); - - static_assert(is_function_type::value, ""); - static_assert(is_function_type::value, ""); - static_assert(is_function_type::value, ""); - static_assert(is_function_type::value, ""); - static_assert(is_function_type::value, ""); - static_assert(is_function_type::value, ""); - static_assert(is_function_type::value, ""); - static_assert(is_function_type::value, ""); - static_assert(is_function_type::value, ""); - static_assert(is_function_type::value, ""); - static_assert(is_function_type::value, ""); - static_assert(is_function_type::value, ""); - static_assert(is_function_type::value, ""); - - static_assert(!is_function_type::value, ""); - static_assert(!is_function_type::value, ""); - static_assert(!is_function_type::value, ""); - static_assert(!is_function_type::value, ""); - static_assert(!is_function_type::value, ""); - static_assert(!is_function_type::value, ""); - - static_assert(!is_function_type::value, "function pointers aren't plain functions"); - static_assert(!is_function_type::value, "Functors aren't plain functions"); - static_assert(!is_function_type::value, "Lambdas aren't plain functions"); +class MyClass {}; +struct Functor { + void operator()() {} +}; +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +auto lambda = []() {}; +// func() and func__ just exists to silence a compiler warning about lambda +// being unused +bool func() { + lambda(); + return true; } +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +bool func__ = func(); + +static_assert(is_function_type::value, ""); +static_assert(is_function_type::value, ""); +static_assert(is_function_type::value, ""); +static_assert(is_function_type::value, ""); +static_assert(is_function_type::value, ""); +static_assert(is_function_type::value, ""); +static_assert(is_function_type::value, ""); +static_assert(is_function_type::value, ""); +static_assert(is_function_type::value, ""); +static_assert(is_function_type::value, ""); +static_assert(is_function_type < MyClass && () > ::value, ""); +static_assert(is_function_type < MyClass && (MyClass &&) > ::value, ""); +static_assert(is_function_type::value, ""); + +static_assert(!is_function_type::value, ""); +static_assert(!is_function_type::value, ""); +static_assert(!is_function_type::value, ""); +static_assert(!is_function_type::value, ""); +static_assert(!is_function_type::value, ""); +static_assert(!is_function_type::value, ""); + +static_assert( + !is_function_type::value, + "function pointers aren't plain functions"); +static_assert( + !is_function_type::value, + "Functors aren't plain functions"); +static_assert( + !is_function_type::value, + "Lambdas aren't plain functions"); +} // namespace test_is_function_type namespace test_is_instantiation_of { - class MyClass {}; - template class Single {}; - template class Double {}; - template class Multiple {}; +class MyClass {}; +template +class Single {}; 
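// For readers unfamiliar with the trait tested in this namespace, here is a
// minimal sketch of an is_instantiation_of-style check (illustrative only;
// not necessarily the c10::guts definition):
#include <type_traits>
#include <vector>

namespace sketch {
template <template <class...> class Template, class T>
struct is_instantiation_of : std::false_type {};

// Matches exactly when T is Template<Args...> for some argument pack.
template <template <class...> class Template, class... Args>
struct is_instantiation_of<Template, Template<Args...>> : std::true_type {};

static_assert(is_instantiation_of<std::vector, std::vector<int>>::value, "");
static_assert(!is_instantiation_of<std::vector, int>::value, "");
} // namespace sketch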
+template +class Double {}; +template +class Multiple {}; - static_assert(is_instantiation_of>::value, ""); - static_assert(is_instantiation_of>::value, ""); - static_assert(is_instantiation_of>::value, ""); - static_assert(is_instantiation_of>::value, ""); - static_assert(is_instantiation_of>::value, ""); - static_assert(is_instantiation_of>::value, ""); - static_assert(is_instantiation_of>::value, ""); - static_assert(is_instantiation_of>::value, ""); - static_assert(is_instantiation_of>::value, ""); - static_assert(is_instantiation_of>::value, ""); - static_assert(is_instantiation_of>::value, ""); - static_assert(is_instantiation_of>::value, ""); - static_assert(is_instantiation_of>::value, ""); - static_assert(is_instantiation_of>::value, ""); +static_assert(is_instantiation_of>::value, ""); +static_assert(is_instantiation_of>::value, ""); +static_assert(is_instantiation_of>::value, ""); +static_assert(is_instantiation_of>::value, ""); +static_assert(is_instantiation_of>::value, ""); +static_assert(is_instantiation_of>::value, ""); +static_assert(is_instantiation_of>::value, ""); +static_assert(is_instantiation_of>::value, ""); +static_assert( + is_instantiation_of>::value, + ""); +static_assert(is_instantiation_of>::value, ""); +static_assert(is_instantiation_of>::value, ""); +static_assert( + is_instantiation_of>::value, + ""); +static_assert( + is_instantiation_of>::value, + ""); +static_assert( + is_instantiation_of>:: + value, + ""); - static_assert(!is_instantiation_of>::value, ""); - static_assert(!is_instantiation_of>::value, ""); - static_assert(!is_instantiation_of>::value, ""); - static_assert(!is_instantiation_of>::value, ""); - static_assert(!is_instantiation_of>::value, ""); - static_assert(!is_instantiation_of>::value, ""); - static_assert(!is_instantiation_of>::value, ""); - static_assert(!is_instantiation_of>::value, ""); -} +static_assert(!is_instantiation_of>::value, ""); +static_assert(!is_instantiation_of>::value, ""); +static_assert(!is_instantiation_of>::value, ""); +static_assert(!is_instantiation_of>::value, ""); +static_assert(!is_instantiation_of>::value, ""); +static_assert(!is_instantiation_of>::value, ""); +static_assert(!is_instantiation_of>::value, ""); +static_assert(!is_instantiation_of>::value, ""); +} // namespace test_is_instantiation_of namespace test_is_type_condition { - template class NotATypeCondition {}; - static_assert(is_type_condition::value, ""); - static_assert(!is_type_condition::value, ""); -} -} +template +class NotATypeCondition {}; +static_assert(is_type_condition::value, ""); +static_assert(!is_type_condition::value, ""); +} // namespace test_is_type_condition +} // namespace namespace test_lambda_is_stateless { - template - struct MyStatelessFunctor final { - Result operator()(Args...) {} - }; +template +struct MyStatelessFunctor final { + Result operator()(Args...) {} +}; - template - struct MyStatelessConstFunctor final { - Result operator()(Args...) const {} - }; +template +struct MyStatelessConstFunctor final { + Result operator()(Args...) 
const {} +}; - void func() { - auto stateless_lambda = [] (int a) {return a;}; - static_assert(is_stateless_lambda::value, ""); +void func() { + auto stateless_lambda = [](int a) { return a; }; + static_assert(is_stateless_lambda::value, ""); - int b = 4; - auto stateful_lambda_1 = [&] (int a) {return a + b;}; - static_assert(!is_stateless_lambda::value, ""); + int b = 4; + auto stateful_lambda_1 = [&](int a) { return a + b; }; + static_assert(!is_stateless_lambda::value, ""); - auto stateful_lambda_2 = [=] (int a) {return a + b;}; - static_assert(!is_stateless_lambda::value, ""); + auto stateful_lambda_2 = [=](int a) { return a + b; }; + static_assert(!is_stateless_lambda::value, ""); - auto stateful_lambda_3 = [b] (int a) {return a + b;}; - static_assert(!is_stateless_lambda::value, ""); + auto stateful_lambda_3 = [b](int a) { return a + b; }; + static_assert(!is_stateless_lambda::value, ""); - static_assert(!is_stateless_lambda>::value, "even if stateless, a functor is not a lambda, so it's false"); - static_assert(!is_stateless_lambda>::value, "even if stateless, a functor is not a lambda, so it's false"); - static_assert(!is_stateless_lambda>::value, "even if stateless, a functor is not a lambda, so it's false"); - static_assert(!is_stateless_lambda>::value, "even if stateless, a functor is not a lambda, so it's false"); + static_assert( + !is_stateless_lambda>::value, + "even if stateless, a functor is not a lambda, so it's false"); + static_assert( + !is_stateless_lambda>::value, + "even if stateless, a functor is not a lambda, so it's false"); + static_assert( + !is_stateless_lambda>::value, + "even if stateless, a functor is not a lambda, so it's false"); + static_assert( + !is_stateless_lambda>::value, + "even if stateless, a functor is not a lambda, so it's false"); - class Dummy final {}; - static_assert(!is_stateless_lambda::value, "A non-functor type is also not a lambda"); + class Dummy final {}; + static_assert( + !is_stateless_lambda::value, + "A non-functor type is also not a lambda"); - static_assert(!is_stateless_lambda::value, "An int is not a lambda"); + static_assert(!is_stateless_lambda::value, "An int is not a lambda"); - using Func = int(int); - static_assert(!is_stateless_lambda::value, "A function is not a lambda"); - static_assert(!is_stateless_lambda::value, "A function pointer is not a lambda"); - } + using Func = int(int); + static_assert( + !is_stateless_lambda::value, "A function is not a lambda"); + static_assert( + !is_stateless_lambda::value, "A function pointer is not a lambda"); } +} // namespace test_lambda_is_stateless diff --git a/c10/test/util/accumulate_test.cpp b/c10/test/util/accumulate_test.cpp index 1b41aebc24f..02de9ea950b 100644 --- a/c10/test/util/accumulate_test.cpp +++ b/c10/test/util/accumulate_test.cpp @@ -11,70 +11,73 @@ using namespace ::testing; // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(accumulate_test, vector_test) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - std::vector ints = {1, 2, 3, 4, 5}; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + std::vector ints = {1, 2, 3, 4, 5}; - EXPECT_EQ(c10::sum_integers(ints), 1+2+3+4+5); - EXPECT_EQ(c10::multiply_integers(ints), 1*2*3*4*5); + EXPECT_EQ(c10::sum_integers(ints), 1 + 2 + 3 + 4 + 5); + EXPECT_EQ(c10::multiply_integers(ints), 1 * 2 * 3 * 4 * 5); - EXPECT_EQ(c10::sum_integers(ints.begin(), ints.end()), 1+2+3+4+5); - EXPECT_EQ(c10::multiply_integers(ints.begin(), ints.end()), 1*2*3*4*5); + 
EXPECT_EQ(c10::sum_integers(ints.begin(), ints.end()), 1 + 2 + 3 + 4 + 5); + EXPECT_EQ( + c10::multiply_integers(ints.begin(), ints.end()), 1 * 2 * 3 * 4 * 5); - EXPECT_EQ(c10::sum_integers(ints.begin()+1, ints.end()-1), 2+3+4); - EXPECT_EQ(c10::multiply_integers(ints.begin()+1, ints.end()-1), 2*3*4); + EXPECT_EQ(c10::sum_integers(ints.begin() + 1, ints.end() - 1), 2 + 3 + 4); + EXPECT_EQ( + c10::multiply_integers(ints.begin() + 1, ints.end() - 1), 2 * 3 * 4); - EXPECT_EQ(c10::numelements_from_dim(2, ints), 3*4*5); - EXPECT_EQ(c10::numelements_to_dim(3, ints), 1*2*3); - EXPECT_EQ(c10::numelements_between_dim(2, 4, ints), 3*4); - EXPECT_EQ(c10::numelements_between_dim(4, 2, ints), 3*4); + EXPECT_EQ(c10::numelements_from_dim(2, ints), 3 * 4 * 5); + EXPECT_EQ(c10::numelements_to_dim(3, ints), 1 * 2 * 3); + EXPECT_EQ(c10::numelements_between_dim(2, 4, ints), 3 * 4); + EXPECT_EQ(c10::numelements_between_dim(4, 2, ints), 3 * 4); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(accumulate_test, list_test) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - std::list ints = {1, 2, 3, 4, 5}; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + std::list ints = {1, 2, 3, 4, 5}; - EXPECT_EQ(c10::sum_integers(ints), 1+2+3+4+5); - EXPECT_EQ(c10::multiply_integers(ints), 1*2*3*4*5); + EXPECT_EQ(c10::sum_integers(ints), 1 + 2 + 3 + 4 + 5); + EXPECT_EQ(c10::multiply_integers(ints), 1 * 2 * 3 * 4 * 5); - EXPECT_EQ(c10::sum_integers(ints.begin(), ints.end()), 1+2+3+4+5); - EXPECT_EQ(c10::multiply_integers(ints.begin(), ints.end()), 1*2*3*4*5); + EXPECT_EQ(c10::sum_integers(ints.begin(), ints.end()), 1 + 2 + 3 + 4 + 5); + EXPECT_EQ( + c10::multiply_integers(ints.begin(), ints.end()), 1 * 2 * 3 * 4 * 5); - EXPECT_EQ(c10::numelements_from_dim(2, ints), 3*4*5); - EXPECT_EQ(c10::numelements_to_dim(3, ints), 1*2*3); - EXPECT_EQ(c10::numelements_between_dim(2, 4, ints), 3*4); - EXPECT_EQ(c10::numelements_between_dim(4, 2, ints), 3*4); + EXPECT_EQ(c10::numelements_from_dim(2, ints), 3 * 4 * 5); + EXPECT_EQ(c10::numelements_to_dim(3, ints), 1 * 2 * 3); + EXPECT_EQ(c10::numelements_between_dim(2, 4, ints), 3 * 4); + EXPECT_EQ(c10::numelements_between_dim(4, 2, ints), 3 * 4); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(accumulate_test, base_cases) { - std::vector ints = {}; + std::vector ints = {}; - EXPECT_EQ(c10::sum_integers(ints), 0); - EXPECT_EQ(c10::multiply_integers(ints), 1); + EXPECT_EQ(c10::sum_integers(ints), 0); + EXPECT_EQ(c10::multiply_integers(ints), 1); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(accumulate_test, errors) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - std::vector ints = {1,2,3,4,5}; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + std::vector ints = {1, 2, 3, 4, 5}; - #ifndef NDEBUG - EXPECT_THROW(c10::numelements_from_dim(-1, ints), c10::Error); - #endif +#ifndef NDEBUG + EXPECT_THROW(c10::numelements_from_dim(-1, ints), c10::Error); +#endif - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW(c10::numelements_to_dim(-1, ints), c10::Error); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW(c10::numelements_between_dim(-1, 10, ints), c10::Error); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW(c10::numelements_between_dim(10, -1, ints), c10::Error); + // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) + 
EXPECT_THROW(c10::numelements_to_dim(-1, ints), c10::Error); + // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) + EXPECT_THROW(c10::numelements_between_dim(-1, 10, ints), c10::Error); + // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) + EXPECT_THROW(c10::numelements_between_dim(10, -1, ints), c10::Error); - EXPECT_EQ(c10::numelements_from_dim(10, ints),1); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW(c10::numelements_to_dim(10, ints), c10::Error); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW(c10::numelements_between_dim(10, 4, ints), c10::Error); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW(c10::numelements_between_dim(4, 10, ints), c10::Error); + EXPECT_EQ(c10::numelements_from_dim(10, ints), 1); + // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) + EXPECT_THROW(c10::numelements_to_dim(10, ints), c10::Error); + // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) + EXPECT_THROW(c10::numelements_between_dim(10, 4, ints), c10::Error); + // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) + EXPECT_THROW(c10::numelements_between_dim(4, 10, ints), c10::Error); } diff --git a/c10/test/util/bfloat16_test.cpp b/c10/test/util/bfloat16_test.cpp index e6ee94c2f76..0cde27587fa 100644 --- a/c10/test/util/bfloat16_test.cpp +++ b/c10/test/util/bfloat16_test.cpp @@ -2,196 +2,194 @@ #include namespace { - float float_from_bytes( - uint32_t sign, - uint32_t exponent, - uint32_t fraction - ) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - uint32_t bytes; - bytes = 0; - bytes |= sign; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - bytes <<= 8; - bytes |= exponent; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - bytes <<= 23; - bytes |= fraction; +float float_from_bytes(uint32_t sign, uint32_t exponent, uint32_t fraction) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + uint32_t bytes; + bytes = 0; + bytes |= sign; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + bytes <<= 8; + bytes |= exponent; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + bytes <<= 23; + bytes |= fraction; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - float res; - std::memcpy(&res, &bytes, sizeof(res)); - return res; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + float res; + std::memcpy(&res, &bytes, sizeof(res)); + return res; +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(BFloat16Conversion, FloatToBFloat16AndBack) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays) + float in[100]; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + for (int i = 0; i < 100; ++i) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers) + in[i] = i + 1.25; } - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(BFloat16Conversion, FloatToBFloat16AndBack) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays) - float in[100]; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - for (int i = 0; i < 100; ++i) { - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers) - in[i] = i + 1.25; - } + // 
NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays) + c10::BFloat16 bfloats[100]; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays) + float out[100]; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays) - c10::BFloat16 bfloats[100]; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays) - float out[100]; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + for (int i = 0; i < 100; ++i) { + bfloats[i].x = c10::detail::bits_from_f32(in[i]); + out[i] = c10::detail::f32_from_bits(bfloats[i].x); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - for (int i = 0; i < 100; ++i) { - bfloats[i].x = c10::detail::bits_from_f32(in[i]); - out[i] = c10::detail::f32_from_bits(bfloats[i].x); + // The relative error should be less than 1/(2^7) since BFloat16 + // has 7 bits mantissa. + EXPECT_LE(fabs(out[i] - in[i]) / in[i], 1.0 / 128); + } +} - // The relative error should be less than 1/(2^7) since BFloat16 - // has 7 bits mantissa. - EXPECT_LE(fabs(out[i] - in[i]) / in[i], 1.0 / 128); - } +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(BFloat16Conversion, FloatToBFloat16RNEAndBack) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays) + float in[100]; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + for (int i = 0; i < 100; ++i) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers) + in[i] = i + 1.25; } - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(BFloat16Conversion, FloatToBFloat16RNEAndBack) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays) - float in[100]; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - for (int i = 0; i < 100; ++i) { - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers) - in[i] = i + 1.25; - } + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays) + c10::BFloat16 bfloats[100]; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays) + float out[100]; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays) - c10::BFloat16 bfloats[100]; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays) - float out[100]; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + for (int i = 0; i < 100; ++i) { + bfloats[i].x = c10::detail::round_to_nearest_even(in[i]); + out[i] = c10::detail::f32_from_bits(bfloats[i].x); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - for (int i = 0; i < 100; ++i) { - bfloats[i].x = c10::detail::round_to_nearest_even(in[i]); - out[i] = c10::detail::f32_from_bits(bfloats[i].x); - - // The relative error should be less than 1/(2^7) since BFloat16 - // has 7 bits mantissa. - EXPECT_LE(fabs(out[i] - in[i]) / in[i], 1.0 / 128); - } + // The relative error should be less than 1/(2^7) since BFloat16 + // has 7 bits mantissa. 
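// Rationale for the 1/128 bound: a float32 significand is 1.m with
// 1 <= 1.m < 2, so reducing its 23 fraction bits to BFloat16's 7 fraction
// bits perturbs the value by at most one unit in the 7th fraction bit, i.e.
// by less than 2^-7 of the leading 1. The relative error is therefore below
// 2^-7 = 1/128 (round-to-nearest-even stays within half of that, 2^-8, so
// the bound is conservative).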
+ EXPECT_LE(fabs(out[i] - in[i]) / in[i], 1.0 / 128); } +} - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(BFloat16Conversion, NaN) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - float inNaN = float_from_bytes(0, 0xFF, 0x7FFFFF); - EXPECT_TRUE(std::isnan(inNaN)); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(BFloat16Conversion, NaN) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + float inNaN = float_from_bytes(0, 0xFF, 0x7FFFFF); + EXPECT_TRUE(std::isnan(inNaN)); - c10::BFloat16 a = c10::BFloat16(inNaN); - float out = c10::detail::f32_from_bits(a.x); + c10::BFloat16 a = c10::BFloat16(inNaN); + float out = c10::detail::f32_from_bits(a.x); - EXPECT_TRUE(std::isnan(out)); - } + EXPECT_TRUE(std::isnan(out)); +} - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(BFloat16Conversion, Inf) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - float inInf = float_from_bytes(0, 0xFF, 0); - EXPECT_TRUE(std::isinf(inInf)); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(BFloat16Conversion, Inf) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + float inInf = float_from_bytes(0, 0xFF, 0); + EXPECT_TRUE(std::isinf(inInf)); - c10::BFloat16 a = c10::BFloat16(inInf); - float out = c10::detail::f32_from_bits(a.x); + c10::BFloat16 a = c10::BFloat16(inInf); + float out = c10::detail::f32_from_bits(a.x); - EXPECT_TRUE(std::isinf(out)); - } + EXPECT_TRUE(std::isinf(out)); +} - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(BFloat16Conversion, SmallestDenormal) { - float in = std::numeric_limits::denorm_min(); // The smallest non-zero subnormal number - c10::BFloat16 a = c10::BFloat16(in); - float out = c10::detail::f32_from_bits(a.x); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(BFloat16Conversion, SmallestDenormal) { + float in = std::numeric_limits::denorm_min(); // The smallest non-zero + // subnormal number + c10::BFloat16 a = c10::BFloat16(in); + float out = c10::detail::f32_from_bits(a.x); - EXPECT_FLOAT_EQ(in, out); - } + EXPECT_FLOAT_EQ(in, out); +} - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(BFloat16Math, Addition) { - // This test verifies that if only first 7 bits of float's mantissa are - // changed after addition, we should have no loss in precision. +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(BFloat16Math, Addition) { + // This test verifies that if only first 7 bits of float's mantissa are + // changed after addition, we should have no loss in precision. 
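// Worked example for the case below: 3.125 is 1.1001_2 x 2^1, which needs
// only 4 fraction bits, so it fits exactly in BFloat16's 7-bit fraction.
// Adding it to itself only bumps the exponent (6.25 = 1.1001_2 x 2^2), so
// the sum is also exactly representable and the round trip loses nothing.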
- // input bits - // S | Exponent | Mantissa - // 0 | 10000000 | 10010000000000000000000 = 3.125 - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - float input = float_from_bytes(0, 0, 0x40480000); + // input bits + // S | Exponent | Mantissa + // 0 | 10000000 | 10010000000000000000000 = 3.125 + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + float input = float_from_bytes(0, 0, 0x40480000); - // expected bits - // S | Exponent | Mantissa - // 0 | 10000001 | 10010000000000000000000 = 6.25 - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - float expected = float_from_bytes(0, 0, 0x40c80000); + // expected bits + // S | Exponent | Mantissa + // 0 | 10000001 | 10010000000000000000000 = 6.25 + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + float expected = float_from_bytes(0, 0, 0x40c80000); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - c10::BFloat16 b; - b.x = c10::detail::bits_from_f32(input); - b = b + b; + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + c10::BFloat16 b; + b.x = c10::detail::bits_from_f32(input); + b = b + b; - float res = c10::detail::f32_from_bits(b.x); - EXPECT_EQ(res, expected); - } + float res = c10::detail::f32_from_bits(b.x); + EXPECT_EQ(res, expected); +} - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST(BFloat16Math, Subtraction) { - // This test verifies that if only first 7 bits of float's mantissa are - // changed after subtraction, we should have no loss in precision. +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(BFloat16Math, Subtraction) { + // This test verifies that if only first 7 bits of float's mantissa are + // changed after subtraction, we should have no loss in precision. - // input bits - // S | Exponent | Mantissa - // 0 | 10000001 | 11101000000000000000000 = 7.625 - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - float input = float_from_bytes(0, 0, 0x40f40000); + // input bits + // S | Exponent | Mantissa + // 0 | 10000001 | 11101000000000000000000 = 7.625 + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + float input = float_from_bytes(0, 0, 0x40f40000); - // expected bits - // S | Exponent | Mantissa - // 0 | 10000000 | 01010000000000000000000 = 2.625 - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - float expected = float_from_bytes(0, 0, 0x40280000); + // expected bits + // S | Exponent | Mantissa + // 0 | 10000000 | 01010000000000000000000 = 2.625 + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + float expected = float_from_bytes(0, 0, 0x40280000); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - c10::BFloat16 b; - b.x = c10::detail::bits_from_f32(input); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - b = b - 5; + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + c10::BFloat16 b; + b.x = c10::detail::bits_from_f32(input); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + b = b - 5; - float res = c10::detail::f32_from_bits(b.x); - EXPECT_EQ(res, expected); - } + float res = c10::detail::f32_from_bits(b.x); + EXPECT_EQ(res, expected); +} - float BinaryToFloat(uint32_t bytes) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - float res; - std::memcpy(&res, &bytes, sizeof(res)); - return res; - } +float BinaryToFloat(uint32_t bytes) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + float res; + std::memcpy(&res, &bytes, sizeof(res)); + return res; +} - struct BFloat16TestParam { - uint32_t input; - uint16_t rne; - 
}; +struct BFloat16TestParam { + uint32_t input; + uint16_t rne; +}; - class BFloat16Test : public ::testing::Test, - public ::testing::WithParamInterface { - }; +class BFloat16Test : public ::testing::Test, + public ::testing::WithParamInterface {}; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - TEST_P(BFloat16Test, BFloat16RNETest) { - float value = BinaryToFloat(GetParam().input); - uint16_t rounded = c10::detail::round_to_nearest_even(value); - EXPECT_EQ(GetParam().rne, rounded); - } +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST_P(BFloat16Test, BFloat16RNETest) { + float value = BinaryToFloat(GetParam().input); + uint16_t rounded = c10::detail::round_to_nearest_even(value); + EXPECT_EQ(GetParam().rne, rounded); +} - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - INSTANTIATE_TEST_CASE_P( - BFloat16Test_Instantiation, BFloat16Test, - ::testing::Values(BFloat16TestParam{0x3F848000, 0x3F84}, - BFloat16TestParam{0x3F848010, 0x3F85}, - BFloat16TestParam{0x3F850000, 0x3F85}, - BFloat16TestParam{0x3F858000, 0x3F86}, - BFloat16TestParam{0x3FFF8000, 0x4000})); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +INSTANTIATE_TEST_CASE_P( + BFloat16Test_Instantiation, + BFloat16Test, + ::testing::Values( + BFloat16TestParam{0x3F848000, 0x3F84}, + BFloat16TestParam{0x3F848010, 0x3F85}, + BFloat16TestParam{0x3F850000, 0x3F85}, + BFloat16TestParam{0x3F858000, 0x3F86}, + BFloat16TestParam{0x3FFF8000, 0x4000})); } // namespace diff --git a/c10/test/util/complex_math_test_common.h b/c10/test/util/complex_math_test_common.h index 4715c37c7f9..15addf68785 100644 --- a/c10/test/util/complex_math_test_common.h +++ b/c10/test/util/complex_math_test_common.h @@ -1,4 +1,5 @@ -// Warning: this file is included twice in aten/src/ATen/test/cuda_complex_math_test.cu +// Warning: this file is included twice in +// aten/src/ATen/test/cuda_complex_math_test.cu #include #include @@ -16,152 +17,152 @@ C10_DEFINE_TEST(TestExponential, IPi) { // exp(i*pi) = -1 { - c10::complex e_i_pi = std::exp(c10::complex(0, float(PI))); - C10_ASSERT_NEAR(e_i_pi.real(), -1, tol); - C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol); + c10::complex e_i_pi = std::exp(c10::complex(0, float(PI))); + C10_ASSERT_NEAR(e_i_pi.real(), -1, tol); + C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol); } { - c10::complex e_i_pi = ::exp(c10::complex(0, float(PI))); - C10_ASSERT_NEAR(e_i_pi.real(), -1, tol); - C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol); + c10::complex e_i_pi = ::exp(c10::complex(0, float(PI))); + C10_ASSERT_NEAR(e_i_pi.real(), -1, tol); + C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol); } { - c10::complex e_i_pi = std::exp(c10::complex(0, PI)); - C10_ASSERT_NEAR(e_i_pi.real(), -1, tol); - C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol); + c10::complex e_i_pi = std::exp(c10::complex(0, PI)); + C10_ASSERT_NEAR(e_i_pi.real(), -1, tol); + C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol); } { - c10::complex e_i_pi = ::exp(c10::complex(0, PI)); - C10_ASSERT_NEAR(e_i_pi.real(), -1, tol); - C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol); + c10::complex e_i_pi = ::exp(c10::complex(0, PI)); + C10_ASSERT_NEAR(e_i_pi.real(), -1, tol); + C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol); } } C10_DEFINE_TEST(TestExponential, EulerFormula) { // exp(ix) = cos(x) + i * sin(x) { - c10::complex x(0.1, 1.2); - c10::complex e = std::exp(x); - float expected_real = std::exp(x.real()) * std::cos(x.imag()); - float expected_imag = std::exp(x.real()) * std::sin(x.imag()); - C10_ASSERT_NEAR(e.real(), expected_real, tol); - 
C10_ASSERT_NEAR(e.imag(), expected_imag, tol); + c10::complex x(0.1, 1.2); + c10::complex e = std::exp(x); + float expected_real = std::exp(x.real()) * std::cos(x.imag()); + float expected_imag = std::exp(x.real()) * std::sin(x.imag()); + C10_ASSERT_NEAR(e.real(), expected_real, tol); + C10_ASSERT_NEAR(e.imag(), expected_imag, tol); } { - c10::complex x(0.1, 1.2); - c10::complex e = ::exp(x); - float expected_real = ::exp(x.real()) * ::cos(x.imag()); - float expected_imag = ::exp(x.real()) * ::sin(x.imag()); - C10_ASSERT_NEAR(e.real(), expected_real, tol); - C10_ASSERT_NEAR(e.imag(), expected_imag, tol); + c10::complex x(0.1, 1.2); + c10::complex e = ::exp(x); + float expected_real = ::exp(x.real()) * ::cos(x.imag()); + float expected_imag = ::exp(x.real()) * ::sin(x.imag()); + C10_ASSERT_NEAR(e.real(), expected_real, tol); + C10_ASSERT_NEAR(e.imag(), expected_imag, tol); } { - c10::complex x(0.1, 1.2); - c10::complex e = std::exp(x); - float expected_real = std::exp(x.real()) * std::cos(x.imag()); - float expected_imag = std::exp(x.real()) * std::sin(x.imag()); - C10_ASSERT_NEAR(e.real(), expected_real, tol); - C10_ASSERT_NEAR(e.imag(), expected_imag, tol); + c10::complex x(0.1, 1.2); + c10::complex e = std::exp(x); + float expected_real = std::exp(x.real()) * std::cos(x.imag()); + float expected_imag = std::exp(x.real()) * std::sin(x.imag()); + C10_ASSERT_NEAR(e.real(), expected_real, tol); + C10_ASSERT_NEAR(e.imag(), expected_imag, tol); } { - c10::complex x(0.1, 1.2); - c10::complex e = ::exp(x); - float expected_real = ::exp(x.real()) * ::cos(x.imag()); - float expected_imag = ::exp(x.real()) * ::sin(x.imag()); - C10_ASSERT_NEAR(e.real(), expected_real, tol); - C10_ASSERT_NEAR(e.imag(), expected_imag, tol); + c10::complex x(0.1, 1.2); + c10::complex e = ::exp(x); + float expected_real = ::exp(x.real()) * ::cos(x.imag()); + float expected_imag = ::exp(x.real()) * ::sin(x.imag()); + C10_ASSERT_NEAR(e.real(), expected_real, tol); + C10_ASSERT_NEAR(e.imag(), expected_imag, tol); } } C10_DEFINE_TEST(TestLog, Definition) { // log(x) = log(r) + i*theta { - c10::complex x(1.2, 3.4); - c10::complex l = std::log(x); - float expected_real = std::log(std::abs(x)); - float expected_imag = std::arg(x); - C10_ASSERT_NEAR(l.real(), expected_real, tol); - C10_ASSERT_NEAR(l.imag(), expected_imag, tol); + c10::complex x(1.2, 3.4); + c10::complex l = std::log(x); + float expected_real = std::log(std::abs(x)); + float expected_imag = std::arg(x); + C10_ASSERT_NEAR(l.real(), expected_real, tol); + C10_ASSERT_NEAR(l.imag(), expected_imag, tol); } { - c10::complex x(1.2, 3.4); - c10::complex l = ::log(x); - float expected_real = ::log(std::abs(x)); - float expected_imag = std::arg(x); - C10_ASSERT_NEAR(l.real(), expected_real, tol); - C10_ASSERT_NEAR(l.imag(), expected_imag, tol); + c10::complex x(1.2, 3.4); + c10::complex l = ::log(x); + float expected_real = ::log(std::abs(x)); + float expected_imag = std::arg(x); + C10_ASSERT_NEAR(l.real(), expected_real, tol); + C10_ASSERT_NEAR(l.imag(), expected_imag, tol); } { - c10::complex x(1.2, 3.4); - c10::complex l = std::log(x); - float expected_real = std::log(std::abs(x)); - float expected_imag = std::arg(x); - C10_ASSERT_NEAR(l.real(), expected_real, tol); - C10_ASSERT_NEAR(l.imag(), expected_imag, tol); + c10::complex x(1.2, 3.4); + c10::complex l = std::log(x); + float expected_real = std::log(std::abs(x)); + float expected_imag = std::arg(x); + C10_ASSERT_NEAR(l.real(), expected_real, tol); + C10_ASSERT_NEAR(l.imag(), expected_imag, tol); } { - 
c10::complex x(1.2, 3.4); - c10::complex l = ::log(x); - float expected_real = ::log(std::abs(x)); - float expected_imag = std::arg(x); - C10_ASSERT_NEAR(l.real(), expected_real, tol); - C10_ASSERT_NEAR(l.imag(), expected_imag, tol); + c10::complex x(1.2, 3.4); + c10::complex l = ::log(x); + float expected_real = ::log(std::abs(x)); + float expected_imag = std::arg(x); + C10_ASSERT_NEAR(l.real(), expected_real, tol); + C10_ASSERT_NEAR(l.imag(), expected_imag, tol); } } C10_DEFINE_TEST(TestLog10, Rev) { // log10(10^x) = x { - c10::complex x(0.1, 1.2); - c10::complex l = std::log10(std::pow(float(10), x)); - C10_ASSERT_NEAR(l.real(), float(0.1), tol); - C10_ASSERT_NEAR(l.imag(), float(1.2), tol); + c10::complex x(0.1, 1.2); + c10::complex l = std::log10(std::pow(float(10), x)); + C10_ASSERT_NEAR(l.real(), float(0.1), tol); + C10_ASSERT_NEAR(l.imag(), float(1.2), tol); } { - c10::complex x(0.1, 1.2); - c10::complex l = ::log10(::pow(float(10), x)); - C10_ASSERT_NEAR(l.real(), float(0.1), tol); - C10_ASSERT_NEAR(l.imag(), float(1.2), tol); + c10::complex x(0.1, 1.2); + c10::complex l = ::log10(::pow(float(10), x)); + C10_ASSERT_NEAR(l.real(), float(0.1), tol); + C10_ASSERT_NEAR(l.imag(), float(1.2), tol); } { - c10::complex x(0.1, 1.2); - c10::complex l = std::log10(std::pow(double(10), x)); - C10_ASSERT_NEAR(l.real(), double(0.1), tol); - C10_ASSERT_NEAR(l.imag(), double(1.2), tol); + c10::complex x(0.1, 1.2); + c10::complex l = std::log10(std::pow(double(10), x)); + C10_ASSERT_NEAR(l.real(), double(0.1), tol); + C10_ASSERT_NEAR(l.imag(), double(1.2), tol); } { - c10::complex x(0.1, 1.2); - c10::complex l = ::log10(::pow(double(10), x)); - C10_ASSERT_NEAR(l.real(), double(0.1), tol); - C10_ASSERT_NEAR(l.imag(), double(1.2), tol); + c10::complex x(0.1, 1.2); + c10::complex l = ::log10(::pow(double(10), x)); + C10_ASSERT_NEAR(l.real(), double(0.1), tol); + C10_ASSERT_NEAR(l.imag(), double(1.2), tol); } } C10_DEFINE_TEST(TestLog2, Rev) { // log2(2^x) = x { - c10::complex x(0.1, 1.2); - c10::complex l = std::log2(std::pow(float(2), x)); - C10_ASSERT_NEAR(l.real(), float(0.1), tol); - C10_ASSERT_NEAR(l.imag(), float(1.2), tol); + c10::complex x(0.1, 1.2); + c10::complex l = std::log2(std::pow(float(2), x)); + C10_ASSERT_NEAR(l.real(), float(0.1), tol); + C10_ASSERT_NEAR(l.imag(), float(1.2), tol); } { - c10::complex x(0.1, 1.2); - c10::complex l = ::log2(std::pow(float(2), x)); - C10_ASSERT_NEAR(l.real(), float(0.1), tol); - C10_ASSERT_NEAR(l.imag(), float(1.2), tol); + c10::complex x(0.1, 1.2); + c10::complex l = ::log2(std::pow(float(2), x)); + C10_ASSERT_NEAR(l.real(), float(0.1), tol); + C10_ASSERT_NEAR(l.imag(), float(1.2), tol); } { - c10::complex x(0.1, 1.2); - c10::complex l = std::log2(std::pow(double(2), x)); - C10_ASSERT_NEAR(l.real(), double(0.1), tol); - C10_ASSERT_NEAR(l.imag(), double(1.2), tol); + c10::complex x(0.1, 1.2); + c10::complex l = std::log2(std::pow(double(2), x)); + C10_ASSERT_NEAR(l.real(), double(0.1), tol); + C10_ASSERT_NEAR(l.imag(), double(1.2), tol); } { - c10::complex x(0.1, 1.2); - c10::complex l = ::log2(std::pow(double(2), x)); - C10_ASSERT_NEAR(l.real(), double(0.1), tol); - C10_ASSERT_NEAR(l.imag(), double(1.2), tol); + c10::complex x(0.1, 1.2); + c10::complex l = ::log2(std::pow(double(2), x)); + C10_ASSERT_NEAR(l.real(), double(0.1), tol); + C10_ASSERT_NEAR(l.imag(), double(1.2), tol); } } @@ -170,64 +171,64 @@ C10_DEFINE_TEST(TestLog2, Rev) { C10_DEFINE_TEST(TestPowSqrt, Equal) { // x^0.5 = sqrt(x) { - c10::complex x(0.1, 1.2); - c10::complex y = 
std::pow(x, float(0.5)); - c10::complex z = std::sqrt(x); - C10_ASSERT_NEAR(y.real(), z.real(), tol); - C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + c10::complex x(0.1, 1.2); + c10::complex y = std::pow(x, float(0.5)); + c10::complex z = std::sqrt(x); + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); } { - c10::complex x(0.1, 1.2); - c10::complex y = ::pow(x, float(0.5)); - c10::complex z = ::sqrt(x); - C10_ASSERT_NEAR(y.real(), z.real(), tol); - C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + c10::complex x(0.1, 1.2); + c10::complex y = ::pow(x, float(0.5)); + c10::complex z = ::sqrt(x); + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); } { - c10::complex x(0.1, 1.2); - c10::complex y = std::pow(x, double(0.5)); - c10::complex z = std::sqrt(x); - C10_ASSERT_NEAR(y.real(), z.real(), tol); - C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + c10::complex x(0.1, 1.2); + c10::complex y = std::pow(x, double(0.5)); + c10::complex z = std::sqrt(x); + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); } { - c10::complex x(0.1, 1.2); - c10::complex y = ::pow(x, double(0.5)); - c10::complex z = ::sqrt(x); - C10_ASSERT_NEAR(y.real(), z.real(), tol); - C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + c10::complex x(0.1, 1.2); + c10::complex y = ::pow(x, double(0.5)); + c10::complex z = ::sqrt(x); + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); } } C10_DEFINE_TEST(TestPow, Square) { // x^2 = x * x { - c10::complex x(0.1, 1.2); - c10::complex y = std::pow(x, float(2)); - c10::complex z = x * x; - C10_ASSERT_NEAR(y.real(), z.real(), tol); - C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + c10::complex x(0.1, 1.2); + c10::complex y = std::pow(x, float(2)); + c10::complex z = x * x; + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); } { - c10::complex x(0.1, 1.2); - c10::complex y = ::pow(x, float(2)); - c10::complex z = x * x; - C10_ASSERT_NEAR(y.real(), z.real(), tol); - C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + c10::complex x(0.1, 1.2); + c10::complex y = ::pow(x, float(2)); + c10::complex z = x * x; + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); } { - c10::complex x(0.1, 1.2); - c10::complex y = std::pow(x, double(2)); - c10::complex z = x * x; - C10_ASSERT_NEAR(y.real(), z.real(), tol); - C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + c10::complex x(0.1, 1.2); + c10::complex y = std::pow(x, double(2)); + c10::complex z = x * x; + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); } { - c10::complex x(0.1, 1.2); - c10::complex y = ::pow(x, double(2)); - c10::complex z = x * x; - C10_ASSERT_NEAR(y.real(), z.real(), tol); - C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + c10::complex x(0.1, 1.2); + c10::complex y = ::pow(x, double(2)); + c10::complex z = x * x; + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); } } @@ -237,132 +238,132 @@ C10_DEFINE_TEST(TestSinCosSinhCosh, Identity) { // sin(x + i * y) = sin(x) * cosh(y) + i * cos(x) * sinh(y) // cos(x + i * y) = cos(x) * cosh(y) - i * sin(x) * sinh(y) { - c10::complex x(0.1, 1.2); - c10::complex y = std::sin(x); - float expected_real = std::sin(x.real()) * std::cosh(x.imag()); - float expected_imag = std::cos(x.real()) * std::sinh(x.imag()); - C10_ASSERT_NEAR(y.real(), expected_real, tol); - C10_ASSERT_NEAR(y.imag(), expected_imag, tol); + c10::complex x(0.1, 1.2); + c10::complex y 
= std::sin(x); + float expected_real = std::sin(x.real()) * std::cosh(x.imag()); + float expected_imag = std::cos(x.real()) * std::sinh(x.imag()); + C10_ASSERT_NEAR(y.real(), expected_real, tol); + C10_ASSERT_NEAR(y.imag(), expected_imag, tol); } { - c10::complex x(0.1, 1.2); - c10::complex y = ::sin(x); - float expected_real = ::sin(x.real()) * ::cosh(x.imag()); - float expected_imag = ::cos(x.real()) * ::sinh(x.imag()); - C10_ASSERT_NEAR(y.real(), expected_real, tol); - C10_ASSERT_NEAR(y.imag(), expected_imag, tol); + c10::complex x(0.1, 1.2); + c10::complex y = ::sin(x); + float expected_real = ::sin(x.real()) * ::cosh(x.imag()); + float expected_imag = ::cos(x.real()) * ::sinh(x.imag()); + C10_ASSERT_NEAR(y.real(), expected_real, tol); + C10_ASSERT_NEAR(y.imag(), expected_imag, tol); } { - c10::complex x(0.1, 1.2); - c10::complex y = std::cos(x); - float expected_real = std::cos(x.real()) * std::cosh(x.imag()); - float expected_imag = - std::sin(x.real()) * std::sinh(x.imag()); - C10_ASSERT_NEAR(y.real(), expected_real, tol); - C10_ASSERT_NEAR(y.imag(), expected_imag, tol); + c10::complex x(0.1, 1.2); + c10::complex y = std::cos(x); + float expected_real = std::cos(x.real()) * std::cosh(x.imag()); + float expected_imag = -std::sin(x.real()) * std::sinh(x.imag()); + C10_ASSERT_NEAR(y.real(), expected_real, tol); + C10_ASSERT_NEAR(y.imag(), expected_imag, tol); } { - c10::complex x(0.1, 1.2); - c10::complex y = ::cos(x); - float expected_real = ::cos(x.real()) * ::cosh(x.imag()); - float expected_imag = - ::sin(x.real()) * ::sinh(x.imag()); - C10_ASSERT_NEAR(y.real(), expected_real, tol); - C10_ASSERT_NEAR(y.imag(), expected_imag, tol); + c10::complex x(0.1, 1.2); + c10::complex y = ::cos(x); + float expected_real = ::cos(x.real()) * ::cosh(x.imag()); + float expected_imag = -::sin(x.real()) * ::sinh(x.imag()); + C10_ASSERT_NEAR(y.real(), expected_real, tol); + C10_ASSERT_NEAR(y.imag(), expected_imag, tol); } { - c10::complex x(0.1, 1.2); - c10::complex y = std::sin(x); - float expected_real = std::sin(x.real()) * std::cosh(x.imag()); - float expected_imag = std::cos(x.real()) * std::sinh(x.imag()); - C10_ASSERT_NEAR(y.real(), expected_real, tol); - C10_ASSERT_NEAR(y.imag(), expected_imag, tol); + c10::complex x(0.1, 1.2); + c10::complex y = std::sin(x); + float expected_real = std::sin(x.real()) * std::cosh(x.imag()); + float expected_imag = std::cos(x.real()) * std::sinh(x.imag()); + C10_ASSERT_NEAR(y.real(), expected_real, tol); + C10_ASSERT_NEAR(y.imag(), expected_imag, tol); } { - c10::complex x(0.1, 1.2); - c10::complex y = ::sin(x); - float expected_real = ::sin(x.real()) * ::cosh(x.imag()); - float expected_imag = ::cos(x.real()) * ::sinh(x.imag()); - C10_ASSERT_NEAR(y.real(), expected_real, tol); - C10_ASSERT_NEAR(y.imag(), expected_imag, tol); + c10::complex x(0.1, 1.2); + c10::complex y = ::sin(x); + float expected_real = ::sin(x.real()) * ::cosh(x.imag()); + float expected_imag = ::cos(x.real()) * ::sinh(x.imag()); + C10_ASSERT_NEAR(y.real(), expected_real, tol); + C10_ASSERT_NEAR(y.imag(), expected_imag, tol); } { - c10::complex x(0.1, 1.2); - c10::complex y = std::cos(x); - float expected_real = std::cos(x.real()) * std::cosh(x.imag()); - float expected_imag = - std::sin(x.real()) * std::sinh(x.imag()); - C10_ASSERT_NEAR(y.real(), expected_real, tol); - C10_ASSERT_NEAR(y.imag(), expected_imag, tol); + c10::complex x(0.1, 1.2); + c10::complex y = std::cos(x); + float expected_real = std::cos(x.real()) * std::cosh(x.imag()); + float expected_imag = -std::sin(x.real()) * 
std::sinh(x.imag()); + C10_ASSERT_NEAR(y.real(), expected_real, tol); + C10_ASSERT_NEAR(y.imag(), expected_imag, tol); } { - c10::complex x(0.1, 1.2); - c10::complex y = ::cos(x); - float expected_real = ::cos(x.real()) * ::cosh(x.imag()); - float expected_imag = - ::sin(x.real()) * ::sinh(x.imag()); - C10_ASSERT_NEAR(y.real(), expected_real, tol); - C10_ASSERT_NEAR(y.imag(), expected_imag, tol); + c10::complex x(0.1, 1.2); + c10::complex y = ::cos(x); + float expected_real = ::cos(x.real()) * ::cosh(x.imag()); + float expected_imag = -::sin(x.real()) * ::sinh(x.imag()); + C10_ASSERT_NEAR(y.real(), expected_real, tol); + C10_ASSERT_NEAR(y.imag(), expected_imag, tol); } } C10_DEFINE_TEST(TestTan, Identity) { // tan(x) = sin(x) / cos(x) { - c10::complex x(0.1, 1.2); - c10::complex y = std::tan(x); - c10::complex z = std::sin(x) / std::cos(x); - C10_ASSERT_NEAR(y.real(), z.real(), tol); - C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + c10::complex x(0.1, 1.2); + c10::complex y = std::tan(x); + c10::complex z = std::sin(x) / std::cos(x); + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); } { - c10::complex x(0.1, 1.2); - c10::complex y = ::tan(x); - c10::complex z = ::sin(x) / ::cos(x); - C10_ASSERT_NEAR(y.real(), z.real(), tol); - C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + c10::complex x(0.1, 1.2); + c10::complex y = ::tan(x); + c10::complex z = ::sin(x) / ::cos(x); + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); } { - c10::complex x(0.1, 1.2); - c10::complex y = std::tan(x); - c10::complex z = std::sin(x) / std::cos(x); - C10_ASSERT_NEAR(y.real(), z.real(), tol); - C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + c10::complex x(0.1, 1.2); + c10::complex y = std::tan(x); + c10::complex z = std::sin(x) / std::cos(x); + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); } { - c10::complex x(0.1, 1.2); - c10::complex y = ::tan(x); - c10::complex z = ::sin(x) / ::cos(x); - C10_ASSERT_NEAR(y.real(), z.real(), tol); - C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + c10::complex x(0.1, 1.2); + c10::complex y = ::tan(x); + c10::complex z = ::sin(x) / ::cos(x); + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); } } C10_DEFINE_TEST(TestTanh, Identity) { // tanh(x) = sinh(x) / cosh(x) { - c10::complex x(0.1, 1.2); - c10::complex y = std::tanh(x); - c10::complex z = std::sinh(x) / std::cosh(x); - C10_ASSERT_NEAR(y.real(), z.real(), tol); - C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + c10::complex x(0.1, 1.2); + c10::complex y = std::tanh(x); + c10::complex z = std::sinh(x) / std::cosh(x); + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); } { - c10::complex x(0.1, 1.2); - c10::complex y = ::tanh(x); - c10::complex z = ::sinh(x) / ::cosh(x); - C10_ASSERT_NEAR(y.real(), z.real(), tol); - C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + c10::complex x(0.1, 1.2); + c10::complex y = ::tanh(x); + c10::complex z = ::sinh(x) / ::cosh(x); + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); } { - c10::complex x(0.1, 1.2); - c10::complex y = std::tanh(x); - c10::complex z = std::sinh(x) / std::cosh(x); - C10_ASSERT_NEAR(y.real(), z.real(), tol); - C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + c10::complex x(0.1, 1.2); + c10::complex y = std::tanh(x); + c10::complex z = std::sinh(x) / std::cosh(x); + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); } { - c10::complex x(0.1, 1.2); 
- c10::complex y = ::tanh(x); - c10::complex z = ::sinh(x) / ::cosh(x); - C10_ASSERT_NEAR(y.real(), z.real(), tol); - C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + c10::complex x(0.1, 1.2); + c10::complex y = ::tanh(x); + c10::complex z = ::sinh(x) / ::cosh(x); + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); } } @@ -373,64 +374,64 @@ C10_DEFINE_TEST(TestRevTrigonometric, Rev) { // acos(cos(x)) = x // atan(tan(x)) = x { - c10::complex x(0.5, 0.6); - c10::complex s = std::sin(x); - c10::complex ss = std::asin(s); - c10::complex c = std::cos(x); - c10::complex cc = std::acos(c); - c10::complex t = std::tan(x); - c10::complex tt = std::atan(t); - C10_ASSERT_NEAR(x.real(), ss.real(), tol); - C10_ASSERT_NEAR(x.imag(), ss.imag(), tol); - C10_ASSERT_NEAR(x.real(), cc.real(), tol); - C10_ASSERT_NEAR(x.imag(), cc.imag(), tol); - C10_ASSERT_NEAR(x.real(), tt.real(), tol); - C10_ASSERT_NEAR(x.imag(), tt.imag(), tol); + c10::complex x(0.5, 0.6); + c10::complex s = std::sin(x); + c10::complex ss = std::asin(s); + c10::complex c = std::cos(x); + c10::complex cc = std::acos(c); + c10::complex t = std::tan(x); + c10::complex tt = std::atan(t); + C10_ASSERT_NEAR(x.real(), ss.real(), tol); + C10_ASSERT_NEAR(x.imag(), ss.imag(), tol); + C10_ASSERT_NEAR(x.real(), cc.real(), tol); + C10_ASSERT_NEAR(x.imag(), cc.imag(), tol); + C10_ASSERT_NEAR(x.real(), tt.real(), tol); + C10_ASSERT_NEAR(x.imag(), tt.imag(), tol); } { - c10::complex x(0.5, 0.6); - c10::complex s = ::sin(x); - c10::complex ss = ::asin(s); - c10::complex c = ::cos(x); - c10::complex cc = ::acos(c); - c10::complex t = ::tan(x); - c10::complex tt = ::atan(t); - C10_ASSERT_NEAR(x.real(), ss.real(), tol); - C10_ASSERT_NEAR(x.imag(), ss.imag(), tol); - C10_ASSERT_NEAR(x.real(), cc.real(), tol); - C10_ASSERT_NEAR(x.imag(), cc.imag(), tol); - C10_ASSERT_NEAR(x.real(), tt.real(), tol); - C10_ASSERT_NEAR(x.imag(), tt.imag(), tol); + c10::complex x(0.5, 0.6); + c10::complex s = ::sin(x); + c10::complex ss = ::asin(s); + c10::complex c = ::cos(x); + c10::complex cc = ::acos(c); + c10::complex t = ::tan(x); + c10::complex tt = ::atan(t); + C10_ASSERT_NEAR(x.real(), ss.real(), tol); + C10_ASSERT_NEAR(x.imag(), ss.imag(), tol); + C10_ASSERT_NEAR(x.real(), cc.real(), tol); + C10_ASSERT_NEAR(x.imag(), cc.imag(), tol); + C10_ASSERT_NEAR(x.real(), tt.real(), tol); + C10_ASSERT_NEAR(x.imag(), tt.imag(), tol); } { - c10::complex x(0.5, 0.6); - c10::complex s = std::sin(x); - c10::complex ss = std::asin(s); - c10::complex c = std::cos(x); - c10::complex cc = std::acos(c); - c10::complex t = std::tan(x); - c10::complex tt = std::atan(t); - C10_ASSERT_NEAR(x.real(), ss.real(), tol); - C10_ASSERT_NEAR(x.imag(), ss.imag(), tol); - C10_ASSERT_NEAR(x.real(), cc.real(), tol); - C10_ASSERT_NEAR(x.imag(), cc.imag(), tol); - C10_ASSERT_NEAR(x.real(), tt.real(), tol); - C10_ASSERT_NEAR(x.imag(), tt.imag(), tol); + c10::complex x(0.5, 0.6); + c10::complex s = std::sin(x); + c10::complex ss = std::asin(s); + c10::complex c = std::cos(x); + c10::complex cc = std::acos(c); + c10::complex t = std::tan(x); + c10::complex tt = std::atan(t); + C10_ASSERT_NEAR(x.real(), ss.real(), tol); + C10_ASSERT_NEAR(x.imag(), ss.imag(), tol); + C10_ASSERT_NEAR(x.real(), cc.real(), tol); + C10_ASSERT_NEAR(x.imag(), cc.imag(), tol); + C10_ASSERT_NEAR(x.real(), tt.real(), tol); + C10_ASSERT_NEAR(x.imag(), tt.imag(), tol); } { - c10::complex x(0.5, 0.6); - c10::complex s = ::sin(x); - c10::complex ss = ::asin(s); - c10::complex c = ::cos(x); - c10::complex cc = 
::acos(c); - c10::complex t = ::tan(x); - c10::complex tt = ::atan(t); - C10_ASSERT_NEAR(x.real(), ss.real(), tol); - C10_ASSERT_NEAR(x.imag(), ss.imag(), tol); - C10_ASSERT_NEAR(x.real(), cc.real(), tol); - C10_ASSERT_NEAR(x.imag(), cc.imag(), tol); - C10_ASSERT_NEAR(x.real(), tt.real(), tol); - C10_ASSERT_NEAR(x.imag(), tt.imag(), tol); + c10::complex x(0.5, 0.6); + c10::complex s = ::sin(x); + c10::complex ss = ::asin(s); + c10::complex c = ::cos(x); + c10::complex cc = ::acos(c); + c10::complex t = ::tan(x); + c10::complex tt = ::atan(t); + C10_ASSERT_NEAR(x.real(), ss.real(), tol); + C10_ASSERT_NEAR(x.imag(), ss.imag(), tol); + C10_ASSERT_NEAR(x.real(), cc.real(), tol); + C10_ASSERT_NEAR(x.imag(), cc.imag(), tol); + C10_ASSERT_NEAR(x.real(), tt.real(), tol); + C10_ASSERT_NEAR(x.imag(), tt.imag(), tol); } } @@ -441,63 +442,63 @@ C10_DEFINE_TEST(TestRevHyperbolic, Rev) { // acosh(cosh(x)) = x // atanh(tanh(x)) = x { - c10::complex x(0.5, 0.6); - c10::complex s = std::sinh(x); - c10::complex ss = std::asinh(s); - c10::complex c = std::cosh(x); - c10::complex cc = std::acosh(c); - c10::complex t = std::tanh(x); - c10::complex tt = std::atanh(t); - C10_ASSERT_NEAR(x.real(), ss.real(), tol); - C10_ASSERT_NEAR(x.imag(), ss.imag(), tol); - C10_ASSERT_NEAR(x.real(), cc.real(), tol); - C10_ASSERT_NEAR(x.imag(), cc.imag(), tol); - C10_ASSERT_NEAR(x.real(), tt.real(), tol); - C10_ASSERT_NEAR(x.imag(), tt.imag(), tol); + c10::complex x(0.5, 0.6); + c10::complex s = std::sinh(x); + c10::complex ss = std::asinh(s); + c10::complex c = std::cosh(x); + c10::complex cc = std::acosh(c); + c10::complex t = std::tanh(x); + c10::complex tt = std::atanh(t); + C10_ASSERT_NEAR(x.real(), ss.real(), tol); + C10_ASSERT_NEAR(x.imag(), ss.imag(), tol); + C10_ASSERT_NEAR(x.real(), cc.real(), tol); + C10_ASSERT_NEAR(x.imag(), cc.imag(), tol); + C10_ASSERT_NEAR(x.real(), tt.real(), tol); + C10_ASSERT_NEAR(x.imag(), tt.imag(), tol); } { - c10::complex x(0.5, 0.6); - c10::complex s = ::sinh(x); - c10::complex ss = ::asinh(s); - c10::complex c = ::cosh(x); - c10::complex cc = ::acosh(c); - c10::complex t = ::tanh(x); - c10::complex tt = ::atanh(t); - C10_ASSERT_NEAR(x.real(), ss.real(), tol); - C10_ASSERT_NEAR(x.imag(), ss.imag(), tol); - C10_ASSERT_NEAR(x.real(), cc.real(), tol); - C10_ASSERT_NEAR(x.imag(), cc.imag(), tol); - C10_ASSERT_NEAR(x.real(), tt.real(), tol); - C10_ASSERT_NEAR(x.imag(), tt.imag(), tol); + c10::complex x(0.5, 0.6); + c10::complex s = ::sinh(x); + c10::complex ss = ::asinh(s); + c10::complex c = ::cosh(x); + c10::complex cc = ::acosh(c); + c10::complex t = ::tanh(x); + c10::complex tt = ::atanh(t); + C10_ASSERT_NEAR(x.real(), ss.real(), tol); + C10_ASSERT_NEAR(x.imag(), ss.imag(), tol); + C10_ASSERT_NEAR(x.real(), cc.real(), tol); + C10_ASSERT_NEAR(x.imag(), cc.imag(), tol); + C10_ASSERT_NEAR(x.real(), tt.real(), tol); + C10_ASSERT_NEAR(x.imag(), tt.imag(), tol); } { - c10::complex x(0.5, 0.6); - c10::complex s = std::sinh(x); - c10::complex ss = std::asinh(s); - c10::complex c = std::cosh(x); - c10::complex cc = std::acosh(c); - c10::complex t = std::tanh(x); - c10::complex tt = std::atanh(t); - C10_ASSERT_NEAR(x.real(), ss.real(), tol); - C10_ASSERT_NEAR(x.imag(), ss.imag(), tol); - C10_ASSERT_NEAR(x.real(), cc.real(), tol); - C10_ASSERT_NEAR(x.imag(), cc.imag(), tol); - C10_ASSERT_NEAR(x.real(), tt.real(), tol); - C10_ASSERT_NEAR(x.imag(), tt.imag(), tol); + c10::complex x(0.5, 0.6); + c10::complex s = std::sinh(x); + c10::complex ss = std::asinh(s); + c10::complex c = std::cosh(x); + 
c10::complex cc = std::acosh(c); + c10::complex t = std::tanh(x); + c10::complex tt = std::atanh(t); + C10_ASSERT_NEAR(x.real(), ss.real(), tol); + C10_ASSERT_NEAR(x.imag(), ss.imag(), tol); + C10_ASSERT_NEAR(x.real(), cc.real(), tol); + C10_ASSERT_NEAR(x.imag(), cc.imag(), tol); + C10_ASSERT_NEAR(x.real(), tt.real(), tol); + C10_ASSERT_NEAR(x.imag(), tt.imag(), tol); } { - c10::complex x(0.5, 0.6); - c10::complex s = ::sinh(x); - c10::complex ss = ::asinh(s); - c10::complex c = ::cosh(x); - c10::complex cc = ::acosh(c); - c10::complex t = ::tanh(x); - c10::complex tt = ::atanh(t); - C10_ASSERT_NEAR(x.real(), ss.real(), tol); - C10_ASSERT_NEAR(x.imag(), ss.imag(), tol); - C10_ASSERT_NEAR(x.real(), cc.real(), tol); - C10_ASSERT_NEAR(x.imag(), cc.imag(), tol); - C10_ASSERT_NEAR(x.real(), tt.real(), tol); - C10_ASSERT_NEAR(x.imag(), tt.imag(), tol); + c10::complex x(0.5, 0.6); + c10::complex s = ::sinh(x); + c10::complex ss = ::asinh(s); + c10::complex c = ::cosh(x); + c10::complex cc = ::acosh(c); + c10::complex t = ::tanh(x); + c10::complex tt = ::atanh(t); + C10_ASSERT_NEAR(x.real(), ss.real(), tol); + C10_ASSERT_NEAR(x.imag(), ss.imag(), tol); + C10_ASSERT_NEAR(x.real(), cc.real(), tol); + C10_ASSERT_NEAR(x.imag(), cc.imag(), tol); + C10_ASSERT_NEAR(x.real(), tt.real(), tol); + C10_ASSERT_NEAR(x.imag(), tt.imag(), tol); } } diff --git a/c10/test/util/complex_test_common.h b/c10/test/util/complex_test_common.h index 027143206b0..f7ab797e776 100644 --- a/c10/test/util/complex_test_common.h +++ b/c10/test/util/complex_test_common.h @@ -1,10 +1,10 @@ -#include -#include -#include -#include #include +#include #include #include +#include +#include +#include #include #if (defined(__CUDACC__) || defined(__HIPCC__)) @@ -34,71 +34,72 @@ MAYBE_GLOBAL void test_pod() { TEST(TestMemory, ReinterpretCast) { { - std::complex z(1, 2); - c10::complex zz = *reinterpret_cast*>(&z); - ASSERT_EQ(zz.real(), float(1)); - ASSERT_EQ(zz.imag(), float(2)); + std::complex z(1, 2); + c10::complex zz = *reinterpret_cast*>(&z); + ASSERT_EQ(zz.real(), float(1)); + ASSERT_EQ(zz.imag(), float(2)); } { - c10::complex z(3, 4); - std::complex zz = *reinterpret_cast*>(&z); - ASSERT_EQ(zz.real(), float(3)); - ASSERT_EQ(zz.imag(), float(4)); + c10::complex z(3, 4); + std::complex zz = *reinterpret_cast*>(&z); + ASSERT_EQ(zz.real(), float(3)); + ASSERT_EQ(zz.imag(), float(4)); } { - std::complex z(1, 2); - c10::complex zz = *reinterpret_cast*>(&z); - ASSERT_EQ(zz.real(), double(1)); - ASSERT_EQ(zz.imag(), double(2)); + std::complex z(1, 2); + c10::complex zz = *reinterpret_cast*>(&z); + ASSERT_EQ(zz.real(), double(1)); + ASSERT_EQ(zz.imag(), double(2)); } { - c10::complex z(3, 4); - std::complex zz = *reinterpret_cast*>(&z); - ASSERT_EQ(zz.real(), double(3)); - ASSERT_EQ(zz.imag(), double(4)); + c10::complex z(3, 4); + std::complex zz = *reinterpret_cast*>(&z); + ASSERT_EQ(zz.real(), double(3)); + ASSERT_EQ(zz.imag(), double(4)); } } #if defined(__CUDACC__) || defined(__HIPCC__) TEST(TestMemory, ThrustReinterpretCast) { { - thrust::complex z(1, 2); - c10::complex zz = *reinterpret_cast*>(&z); - ASSERT_EQ(zz.real(), float(1)); - ASSERT_EQ(zz.imag(), float(2)); + thrust::complex z(1, 2); + c10::complex zz = *reinterpret_cast*>(&z); + ASSERT_EQ(zz.real(), float(1)); + ASSERT_EQ(zz.imag(), float(2)); } { - c10::complex z(3, 4); - thrust::complex zz = *reinterpret_cast*>(&z); - ASSERT_EQ(zz.real(), float(3)); - ASSERT_EQ(zz.imag(), float(4)); + c10::complex z(3, 4); + thrust::complex zz = *reinterpret_cast*>(&z); + 
ASSERT_EQ(zz.real(), float(3)); + ASSERT_EQ(zz.imag(), float(4)); } { - thrust::complex z(1, 2); - c10::complex zz = *reinterpret_cast*>(&z); - ASSERT_EQ(zz.real(), double(1)); - ASSERT_EQ(zz.imag(), double(2)); + thrust::complex z(1, 2); + c10::complex zz = *reinterpret_cast*>(&z); + ASSERT_EQ(zz.real(), double(1)); + ASSERT_EQ(zz.imag(), double(2)); } { - c10::complex z(3, 4); - thrust::complex zz = *reinterpret_cast*>(&z); - ASSERT_EQ(zz.real(), double(3)); - ASSERT_EQ(zz.imag(), double(4)); + c10::complex z(3, 4); + thrust::complex zz = + *reinterpret_cast*>(&z); + ASSERT_EQ(zz.real(), double(3)); + ASSERT_EQ(zz.imag(), double(4)); } } #endif -} // memory +} // namespace memory namespace constructors { -template +template C10_HOST_DEVICE void test_construct_from_scalar() { constexpr scalar_t num1 = scalar_t(1.23); constexpr scalar_t num2 = scalar_t(4.56); @@ -111,29 +112,48 @@ C10_HOST_DEVICE void test_construct_from_scalar() { static_assert(c10::complex().imag() == zero, ""); } -template +template C10_HOST_DEVICE void test_construct_from_other() { constexpr other_t num1 = other_t(1.23); constexpr other_t num2 = other_t(4.56); constexpr scalar_t num3 = scalar_t(num1); constexpr scalar_t num4 = scalar_t(num2); - static_assert(c10::complex(c10::complex(num1, num2)).real() == num3, ""); - static_assert(c10::complex(c10::complex(num1, num2)).imag() == num4, ""); + static_assert( + c10::complex(c10::complex(num1, num2)).real() == num3, + ""); + static_assert( + c10::complex(c10::complex(num1, num2)).imag() == num4, + ""); } MAYBE_GLOBAL void test_convert_constructors() { test_construct_from_scalar(); test_construct_from_scalar(); - static_assert(std::is_convertible, c10::complex>::value, ""); - static_assert(!std::is_convertible, c10::complex>::value, ""); - static_assert(std::is_convertible, c10::complex>::value, ""); - static_assert(std::is_convertible, c10::complex>::value, ""); + static_assert( + std::is_convertible, c10::complex>::value, ""); + static_assert( + !std::is_convertible, c10::complex>::value, + ""); + static_assert( + std::is_convertible, c10::complex>::value, + ""); + static_assert( + std::is_convertible, c10::complex>::value, + ""); - static_assert(std::is_constructible, c10::complex>::value, ""); - static_assert(std::is_constructible, c10::complex>::value, ""); - static_assert(std::is_constructible, c10::complex>::value, ""); - static_assert(std::is_constructible, c10::complex>::value, ""); + static_assert( + std::is_constructible, c10::complex>::value, + ""); + static_assert( + std::is_constructible, c10::complex>::value, + ""); + static_assert( + std::is_constructible, c10::complex>::value, + ""); + static_assert( + std::is_constructible, c10::complex>::value, + ""); test_construct_from_other(); test_construct_from_other(); @@ -141,12 +161,16 @@ MAYBE_GLOBAL void test_convert_constructors() { test_construct_from_other(); } -template +template C10_HOST_DEVICE void test_construct_from_std() { constexpr scalar_t num1 = scalar_t(1.23); constexpr scalar_t num2 = scalar_t(4.56); - static_assert(c10::complex(std::complex(num1, num2)).real() == num1, ""); - static_assert(c10::complex(std::complex(num1, num2)).imag() == num2, ""); + static_assert( + c10::complex(std::complex(num1, num2)).real() == num1, + ""); + static_assert( + c10::complex(std::complex(num1, num2)).imag() == num2, + ""); } MAYBE_GLOBAL void test_std_conversion() { @@ -155,12 +179,16 @@ MAYBE_GLOBAL void test_std_conversion() { } #if defined(__CUDACC__) || defined(__HIPCC__) -template +template void 
test_construct_from_thrust() { constexpr scalar_t num1 = scalar_t(1.23); constexpr scalar_t num2 = scalar_t(4.56); - ASSERT_EQ(c10::complex(thrust::complex(num1, num2)).real(), num1); - ASSERT_EQ(c10::complex(thrust::complex(num1, num2)).imag(), num2); + ASSERT_EQ( + c10::complex(thrust::complex(num1, num2)).real(), + num1); + ASSERT_EQ( + c10::complex(thrust::complex(num1, num2)).imag(), + num2); } TEST(TestConstructors, FromThrust) { @@ -170,7 +198,11 @@ TEST(TestConstructors, FromThrust) { #endif TEST(TestConstructors, UnorderedMap) { - std::unordered_map, c10::complex, c10::hash>> m; + std::unordered_map< + c10::complex, + c10::complex, + c10::hash>> + m; auto key1 = c10::complex(2.5, 3); auto key2 = c10::complex(2, 0); auto val1 = c10::complex(2, -3.2); @@ -181,11 +213,11 @@ TEST(TestConstructors, UnorderedMap) { ASSERT_EQ(m[key2], val2); } -} // constructors +} // namespace constructors namespace assignment { -template +template constexpr c10::complex one() { c10::complex result(3, 4); result = scalar_t(1); @@ -232,7 +264,8 @@ MAYBE_GLOBAL void test_assign_std() { } #if defined(__CUDACC__) || defined(__HIPCC__) -C10_HOST_DEVICE std::tuple, c10::complex> one_two_thrust() { +C10_HOST_DEVICE std::tuple, c10::complex> +one_two_thrust() { thrust::complex src(1, 2); c10::complex ret0; c10::complex ret1; @@ -258,7 +291,8 @@ MAYBE_GLOBAL void test_complex_literals() { static_assert(std::is_same>::value, ""); static_assert((0.5_if).real() == float(), ""); static_assert((0.5_if).imag() == float(0.5), ""); - static_assert(std::is_same>::value, ""); + static_assert( + std::is_same>::value, ""); static_assert((0.5_id).real() == float(), ""); static_assert((0.5_id).imag() == float(0.5), ""); @@ -274,14 +308,14 @@ MAYBE_GLOBAL void test_complex_literals() { namespace real_imag { -template +template constexpr c10::complex zero_one() { c10::complex result; result.imag(scalar_t(1)); return result; } -template +template constexpr c10::complex one_zero() { c10::complex result; result.real(scalar_t(1)); @@ -304,35 +338,35 @@ MAYBE_GLOBAL void test_real_imag_modify() { namespace arithmetic_assign { -template +template constexpr c10::complex p(scalar_t value) { c10::complex result(scalar_t(2), scalar_t(2)); result += value; return result; } -template +template constexpr c10::complex m(scalar_t value) { c10::complex result(scalar_t(2), scalar_t(2)); result -= value; return result; } -template +template constexpr c10::complex t(scalar_t value) { c10::complex result(scalar_t(2), scalar_t(2)); result *= value; return result; } -template +template constexpr c10::complex d(scalar_t value) { c10::complex result(scalar_t(2), scalar_t(2)); result /= value; return result; } -template +template C10_HOST_DEVICE void test_arithmetic_assign_scalar() { constexpr c10::complex x = p(scalar_t(1)); static_assert(x.real() == scalar_t(3), ""); @@ -348,35 +382,47 @@ C10_HOST_DEVICE void test_arithmetic_assign_scalar() { static_assert(t.imag() == scalar_t(1), ""); } -template -constexpr c10::complex p(scalar_t real, scalar_t imag, c10::complex rhs) { +template +constexpr c10::complex p( + scalar_t real, + scalar_t imag, + c10::complex rhs) { c10::complex result(real, imag); result += rhs; return result; } -template -constexpr c10::complex m(scalar_t real, scalar_t imag, c10::complex rhs) { +template +constexpr c10::complex m( + scalar_t real, + scalar_t imag, + c10::complex rhs) { c10::complex result(real, imag); result -= rhs; return result; } -template -constexpr c10::complex t(scalar_t real, scalar_t imag, c10::complex rhs) { 
+template +constexpr c10::complex t( + scalar_t real, + scalar_t imag, + c10::complex rhs) { c10::complex result(real, imag); result *= rhs; return result; } -template -constexpr c10::complex d(scalar_t real, scalar_t imag, c10::complex rhs) { +template +constexpr c10::complex d( + scalar_t real, + scalar_t imag, + c10::complex rhs) { c10::complex result(real, imag); result /= rhs; return result; } -template +template C10_HOST_DEVICE void test_arithmetic_assign_complex() { using namespace c10::complex_literals; constexpr c10::complex x2 = p(scalar_t(2), scalar_t(2), 1.0_if); @@ -429,26 +475,64 @@ MAYBE_GLOBAL void test_arithmetic_assign() { namespace arithmetic { -template +template C10_HOST_DEVICE void test_arithmetic_() { - static_assert(c10::complex(1, 2) == +c10::complex(1, 2), ""); - static_assert(c10::complex(-1, -2) == -c10::complex(1, 2), ""); + static_assert( + c10::complex(1, 2) == +c10::complex(1, 2), ""); + static_assert( + c10::complex(-1, -2) == -c10::complex(1, 2), ""); - static_assert(c10::complex(1, 2) + c10::complex(3, 4) == c10::complex(4, 6), ""); - static_assert(c10::complex(1, 2) + scalar_t(3) == c10::complex(4, 2), ""); - static_assert(scalar_t(3) + c10::complex(1, 2) == c10::complex(4, 2), ""); + static_assert( + c10::complex(1, 2) + c10::complex(3, 4) == + c10::complex(4, 6), + ""); + static_assert( + c10::complex(1, 2) + scalar_t(3) == + c10::complex(4, 2), + ""); + static_assert( + scalar_t(3) + c10::complex(1, 2) == + c10::complex(4, 2), + ""); - static_assert(c10::complex(1, 2) - c10::complex(3, 4) == c10::complex(-2, -2), ""); - static_assert(c10::complex(1, 2) - scalar_t(3) == c10::complex(-2, 2), ""); - static_assert(scalar_t(3) - c10::complex(1, 2) == c10::complex(2, -2), ""); + static_assert( + c10::complex(1, 2) - c10::complex(3, 4) == + c10::complex(-2, -2), + ""); + static_assert( + c10::complex(1, 2) - scalar_t(3) == + c10::complex(-2, 2), + ""); + static_assert( + scalar_t(3) - c10::complex(1, 2) == + c10::complex(2, -2), + ""); - static_assert(c10::complex(1, 2) * c10::complex(3, 4) == c10::complex(-5, 10), ""); - static_assert(c10::complex(1, 2) * scalar_t(3) == c10::complex(3, 6), ""); - static_assert(scalar_t(3) * c10::complex(1, 2) == c10::complex(3, 6), ""); + static_assert( + c10::complex(1, 2) * c10::complex(3, 4) == + c10::complex(-5, 10), + ""); + static_assert( + c10::complex(1, 2) * scalar_t(3) == + c10::complex(3, 6), + ""); + static_assert( + scalar_t(3) * c10::complex(1, 2) == + c10::complex(3, 6), + ""); - static_assert(c10::complex(-5, 10) / c10::complex(3, 4) == c10::complex(1, 2), ""); - static_assert(c10::complex(5, 10) / scalar_t(5) == c10::complex(1, 2), ""); - static_assert(scalar_t(25) / c10::complex(3, 4) == c10::complex(3, -4), ""); + static_assert( + c10::complex(-5, 10) / c10::complex(3, 4) == + c10::complex(1, 2), + ""); + static_assert( + c10::complex(5, 10) / scalar_t(5) == + c10::complex(1, 2), + ""); + static_assert( + scalar_t(25) / c10::complex(3, 4) == + c10::complex(3, -4), + ""); } MAYBE_GLOBAL void test_arithmetic() { @@ -456,7 +540,7 @@ MAYBE_GLOBAL void test_arithmetic() { test_arithmetic_(); } -template +template void test_binary_ops_for_int_type_(T real, T img, int_t num) { c10::complex c(real, img); ASSERT_EQ(c + num, c10::complex(real + num, img)); @@ -466,10 +550,12 @@ void test_binary_ops_for_int_type_(T real, T img, int_t num) { ASSERT_EQ(c * num, c10::complex(real * num, img * num)); ASSERT_EQ(num * c, c10::complex(num * real, num * img)); ASSERT_EQ(c / num, c10::complex(real / num, img / num)); - 
ASSERT_EQ(num / c, c10::complex(num * real / std::norm(c), -num * img / std::norm(c))); + ASSERT_EQ( + num / c, + c10::complex(num * real / std::norm(c), -num * img / std::norm(c))); } -template +template void test_binary_ops_for_all_int_types_(T real, T img, int8_t i) { test_binary_ops_for_int_type_(real, img, i); test_binary_ops_for_int_type_(real, img, i); @@ -486,12 +572,14 @@ TEST(TestArithmeticIntScalar, All) { namespace equality { -template +template C10_HOST_DEVICE void test_equality_() { - static_assert(c10::complex(1, 2) == c10::complex(1, 2), ""); + static_assert( + c10::complex(1, 2) == c10::complex(1, 2), ""); static_assert(c10::complex(1, 0) == scalar_t(1), ""); static_assert(scalar_t(1) == c10::complex(1, 0), ""); - static_assert(c10::complex(1, 2) != c10::complex(3, 4), ""); + static_assert( + c10::complex(1, 2) != c10::complex(3, 4), ""); static_assert(c10::complex(1, 2) != scalar_t(1), ""); static_assert(scalar_t(1) != c10::complex(1, 2), ""); } @@ -505,7 +593,7 @@ MAYBE_GLOBAL void test_equality() { namespace io { -template +template void test_io_() { std::stringstream ss; c10::complex a(1, 2); @@ -525,14 +613,16 @@ TEST(TestIO, All) { namespace test_std { -template +template C10_HOST_DEVICE void test_callable_() { static_assert(std::real(c10::complex(1, 2)) == scalar_t(1), ""); static_assert(std::imag(c10::complex(1, 2)) == scalar_t(2), ""); std::abs(c10::complex(1, 2)); std::arg(c10::complex(1, 2)); static_assert(std::norm(c10::complex(3, 4)) == scalar_t(25), ""); - static_assert(std::conj(c10::complex(3, 4)) == c10::complex(3, -4), ""); + static_assert( + std::conj(c10::complex(3, 4)) == c10::complex(3, -4), + ""); c10::polar(float(1), float(PI / 2)); c10::polar(double(1), double(PI / 2)); } @@ -542,11 +632,15 @@ MAYBE_GLOBAL void test_callable() { test_callable_(); } -template +template void test_values_() { ASSERT_EQ(std::abs(c10::complex(3, 4)), scalar_t(5)); ASSERT_LT(std::abs(std::arg(c10::complex(0, 1)) - PI / 2), 1e-6); - ASSERT_LT(std::abs(c10::polar(scalar_t(1), scalar_t(PI / 2)) - c10::complex(0, 1)), 1e-6); + ASSERT_LT( + std::abs( + c10::polar(scalar_t(1), scalar_t(PI / 2)) - + c10::complex(0, 1)), + 1e-6); } TEST(TestStd, BasicFunctions) { @@ -554,8 +648,11 @@ TEST(TestStd, BasicFunctions) { test_values_(); // CSQRT edge cases: checks for overflows which are likely to occur // if square root is computed using polar form - ASSERT_LT(std::abs(std::sqrt(c10::complex(-1e20, -4988429.2)).real()), 3e-4); - ASSERT_LT(std::abs(std::sqrt(c10::complex(-1e60, -4988429.2)).real()), 3e-4); + ASSERT_LT( + std::abs(std::sqrt(c10::complex(-1e20, -4988429.2)).real()), 3e-4); + ASSERT_LT( + std::abs(std::sqrt(c10::complex(-1e60, -4988429.2)).real()), + 3e-4); } } // namespace test_std diff --git a/c10/test/util/either_test.cpp b/c10/test/util/either_test.cpp index e4f963aaa79..a2bb42fc55c 100644 --- a/c10/test/util/either_test.cpp +++ b/c10/test/util/either_test.cpp @@ -1,1124 +1,1086 @@ -// Originally taken from https://raw.githubusercontent.com/cryfs/cryfs/14ad22570ddacef22d5ff139cdff68a54fc8234d/test/cpp-utils/either_test.cpp +// Originally taken from +// https://raw.githubusercontent.com/cryfs/cryfs/14ad22570ddacef22d5ff139cdff68a54fc8234d/test/cpp-utils/either_test.cpp -#include #include -#include +#include #include -#include +#include #include +#include -using std::string; -using std::vector; -using std::pair; -using std::tuple; -using std::ostringstream; using c10::either; using c10::make_left; using c10::make_right; +using std::ostringstream; +using std::pair; 
+using std::string; +using std::tuple; +using std::vector; namespace { class MovableOnly final { -public: - explicit MovableOnly(int value): _value(value) {} - MovableOnly(const MovableOnly&) = delete; - MovableOnly& operator=(const MovableOnly&) = delete; + public: + explicit MovableOnly(int value) : _value(value) {} + MovableOnly(const MovableOnly&) = delete; + MovableOnly& operator=(const MovableOnly&) = delete; - MovableOnly(MovableOnly&& rhs): _value(rhs._value) { - rhs._value = 0; - } + MovableOnly(MovableOnly&& rhs) : _value(rhs._value) { + rhs._value = 0; + } - MovableOnly& operator=(MovableOnly&& rhs) { - _value = rhs._value; - rhs._value = 0; - return *this; - } + MovableOnly& operator=(MovableOnly&& rhs) { + _value = rhs._value; + rhs._value = 0; + return *this; + } - int value() const { - return _value; - } + int value() const { + return _value; + } -private: - int _value; + private: + int _value; }; bool operator==(const MovableOnly& lhs, const MovableOnly& rhs) { return lhs.value() == rhs.value(); } -template -void test_with_matrix(std::vector)>> setups, std::vector> expectations) { - for (const auto& setup: setups) { - for (const auto& expectation: expectations) { +template +void test_with_matrix( + std::vector)>> setups, + std::vector> expectations) { + for (const auto& setup : setups) { + for (const auto& expectation : expectations) { setup(expectation); } } } -template -std::vector&)>> EXPECT_IS_LEFT(const Left& expected) { +template +std::vector&)>> EXPECT_IS_LEFT( + const Left& expected) { return { - [&] (either& obj) { - EXPECT_TRUE(obj.is_left()); - }, [&] (either& obj) { - EXPECT_FALSE(obj.is_right()); - }, [&] (either& obj) { - EXPECT_EQ(expected, obj.left()); - }, [&] (either& obj) { - EXPECT_EQ(expected, std::move(obj).left()); - }, [&] (either& obj) { - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_ANY_THROW(obj.right()); - }, [&] (either& obj) { - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_ANY_THROW(std::move(obj).right()); - } - }; + [&](either& obj) { EXPECT_TRUE(obj.is_left()); }, + [&](either& obj) { EXPECT_FALSE(obj.is_right()); }, + [&](either& obj) { EXPECT_EQ(expected, obj.left()); }, + [&](either& obj) { + EXPECT_EQ(expected, std::move(obj).left()); + }, + [&](either& obj) { + // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) + EXPECT_ANY_THROW(obj.right()); + }, + [&](either& obj) { + // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) + EXPECT_ANY_THROW(std::move(obj).right()); + }}; } -template -std::vector&)>> EXPECT_IS_RIGHT(const Right& expected) { +template +std::vector&)>> EXPECT_IS_RIGHT( + const Right& expected) { return { - [&] (either& obj) { - EXPECT_FALSE(obj.is_left()); - }, [&] (either& obj) { - EXPECT_TRUE(obj.is_right()); - }, [&] (either& obj) { - EXPECT_EQ(expected, obj.right()); - }, [&] (either& obj) { + [&](either& obj) { EXPECT_FALSE(obj.is_left()); }, + [&](either& obj) { EXPECT_TRUE(obj.is_right()); }, + [&](either& obj) { EXPECT_EQ(expected, obj.right()); }, + [&](either& obj) { EXPECT_EQ(expected, std::move(obj).right()); - }, [&] (either& obj) { + }, + [&](either& obj) { // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) EXPECT_ANY_THROW(obj.left()); - }, [&] (either& obj) { + }, + [&](either& obj) { // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) EXPECT_ANY_THROW(std::move(obj).left()); - } - }; + }}; } -template +template std::vector> EXPECT_IS(const Value& v) { - return { - [&] (Value& obj) { - return obj 
== v; - } - }; + return {[&](Value& obj) { return obj == v; }}; } -template +template struct StoreWith1ByteFlag { T val; char flag; }; -template +template void TestSpaceUsage() { - EXPECT_EQ(std::max(sizeof(StoreWith1ByteFlag), sizeof(StoreWith1ByteFlag)), sizeof(either)); -} + EXPECT_EQ( + std::max( + sizeof(StoreWith1ByteFlag), sizeof(StoreWith1ByteFlag)), + sizeof(either)); } +} // namespace // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, SpaceUsage) { - TestSpaceUsage(); - TestSpaceUsage(); - TestSpaceUsage(); - TestSpaceUsage(); - TestSpaceUsage>(); + TestSpaceUsage(); + TestSpaceUsage(); + TestSpaceUsage(); + TestSpaceUsage(); + TestSpaceUsage>(); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenLeft) { - test_with_matrix({ - [] (std::function&)> test) { - either a(4); - test(a); - }, [] (std::function&)> test) { - either a = 4; - test(a); + test_with_matrix( + { + [](std::function&)> test) { + either a(4); + test(a); + }, + [](std::function&)> test) { + either a = 4; + test(a); + }, }, - }, - EXPECT_IS_LEFT(4) - ); + EXPECT_IS_LEFT(4)); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenRight) { - test_with_matrix({ - [] (std::function&)> test) { - either a("4"); - test(a); - }, [] (std::function&)> test) { - either a = string("4"); - test(a); - } - }, - EXPECT_IS_RIGHT("4") - ); + test_with_matrix( + {[](std::function&)> test) { + either a("4"); + test(a); + }, + [](std::function&)> test) { + either a = string("4"); + test(a); + }}, + EXPECT_IS_RIGHT("4")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenMakeLeft) { - test_with_matrix({ - [] (std::function&)> test) { - either a = make_left(4); - test(a); - }, [] (std::function&)> test) { - auto a = make_left(4); - test(a); + test_with_matrix( + { + [](std::function&)> test) { + either a = make_left(4); + test(a); + }, + [](std::function&)> test) { + auto a = make_left(4); + test(a); + }, }, - }, - EXPECT_IS_LEFT(4) - ); + EXPECT_IS_LEFT(4)); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenMakeLeftWithSameType) { - test_with_matrix({ - [] (std::function&)> test) { - either a = make_left(4); - test(a); - }, [] (std::function&)> test) { - auto a = make_left(4); - test(a); + test_with_matrix( + { + [](std::function&)> test) { + either a = make_left(4); + test(a); + }, + [](std::function&)> test) { + auto a = make_left(4); + test(a); + }, }, - }, - EXPECT_IS_LEFT(4) - ); + EXPECT_IS_LEFT(4)); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenMakeRight) { - test_with_matrix({ - [] (std::function&)> test) { - either a = make_right("4"); - test(a); - }, [] (std::function&)> test) { - auto a = make_right("4"); - test(a); - } - }, - EXPECT_IS_RIGHT("4") - ); + test_with_matrix( + {[](std::function&)> test) { + either a = make_right("4"); + test(a); + }, + [](std::function&)> test) { + auto a = make_right("4"); + test(a); + }}, + EXPECT_IS_RIGHT("4")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenMakeRightWithSameType) { - test_with_matrix({ - [] (std::function&)> test) { - either a = make_right("4"); - test(a); - }, [] (std::function&)> test) { - auto a = make_right("4"); - test(a); - } - }, - EXPECT_IS_RIGHT("4") - ); + test_with_matrix( + {[](std::function&)> test) { + either a = make_right("4"); + test(a); + }, + [](std::function&)> 
test) { + auto a = make_right("4"); + test(a); + }}, + EXPECT_IS_RIGHT("4")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenMovableOnlyMakeLeft) { - test_with_matrix({ - [] (std::function&)> test) { - either a = make_left(3); - test(a); - }, [] (std::function&)> test) { - auto a = make_left(3); - test(a); + test_with_matrix( + { + [](std::function&)> test) { + either a = make_left(3); + test(a); + }, + [](std::function&)> test) { + auto a = make_left(3); + test(a); + }, }, - }, - EXPECT_IS_LEFT(MovableOnly(3)) - ); + EXPECT_IS_LEFT(MovableOnly(3))); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenMovableOnlyMakeRight) { - test_with_matrix({ - [] (std::function&)> test) { - either a = make_right(3); - test(a); - }, [] (std::function&)> test) { - auto a = make_right(3); - test(a); - } - }, - EXPECT_IS_RIGHT(MovableOnly(3)) - ); + test_with_matrix( + {[](std::function&)> test) { + either a = make_right(3); + test(a); + }, + [](std::function&)> test) { + auto a = make_right(3); + test(a); + }}, + EXPECT_IS_RIGHT(MovableOnly(3))); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenMultiParamMakeLeft) { - test_with_matrix({ - [] (std::function, string>&)> test) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - either, string> a = make_left, string>(5, 6); - test(a); - }, [] (std::function, string>&)> test) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto a = make_left, string>(5, 6); - test(a); + test_with_matrix( + { + [](std::function, string>&)> test) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + either, string> a = + make_left, string>(5, 6); + test(a); + }, + [](std::function, string>&)> test) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto a = make_left, string>(5, 6); + test(a); + }, }, - }, - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - EXPECT_IS_LEFT, string>(pair(5, 6)) - ); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + EXPECT_IS_LEFT, string>(pair(5, 6))); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenMultiParamMakeRight) { - test_with_matrix({ - [] (std::function>&)> test) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - either> a = make_right>(5, 6); - test(a); - }, [] (std::function>&)> test) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto a = make_right>(5, 6); - test(a); - } - }, - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - EXPECT_IS_RIGHT>(pair(5, 6)) - ); + test_with_matrix( + {[](std::function>&)> test) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + either> a = make_right>(5, 6); + test(a); + }, + [](std::function>&)> test) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto a = make_right>(5, 6); + test(a); + }}, + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + EXPECT_IS_RIGHT>(pair(5, 6))); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenLeftCopyConstructedFromValue_thenNewIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { + test_with_matrix( + {[](std::function&)> test) { string a = "4"; either b(a); test(b); - } - }, - EXPECT_IS_LEFT("4") - ); + }}, + EXPECT_IS_LEFT("4")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenLeftCopyConstructedFromValue_thenOldIsCorrect) { - test_with_matrix({ - [] 
(std::function test) { + test_with_matrix( + {[](std::function test) { string a = "4"; either b(a); test(a); - } - }, - EXPECT_IS("4") - ); + }}, + EXPECT_IS("4")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenRightCopyConstructedFromValue_thenNewIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { + test_with_matrix( + {[](std::function&)> test) { string a = "4"; either b(a); test(b); - } - }, - EXPECT_IS_RIGHT("4") - ); + }}, + EXPECT_IS_RIGHT("4")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenRightCopyConstructedFromValue_thenOldIsCorrect) { - test_with_matrix({ - [] (std::function test) { + test_with_matrix( + {[](std::function test) { string a = "4"; either b(a); test(a); - } - }, - EXPECT_IS("4") - ); + }}, + EXPECT_IS("4")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenLeftMoveConstructedFromValue_thenNewIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { + test_with_matrix( + {[](std::function&)> test) { MovableOnly a(3); either b(std::move(a)); test(b); - } - }, - EXPECT_IS_LEFT(MovableOnly(3)) - ); + }}, + EXPECT_IS_LEFT(MovableOnly(3))); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenLeftMoveConstructedFromValue_thenOldIsCorrect) { - test_with_matrix({ - [] (std::function test) { + test_with_matrix( + {[](std::function test) { MovableOnly a(3); either b(std::move(a)); - test(a); // NOLINT(bugprone-use-after-move) - } - }, - EXPECT_IS(MovableOnly(0)) // 0 is moved-from value + test(a); // NOLINT(bugprone-use-after-move) + }}, + EXPECT_IS(MovableOnly(0)) // 0 is moved-from value ); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenRightMoveConstructedFromValue_thenNewIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { + test_with_matrix( + {[](std::function&)> test) { MovableOnly a(3); either b(std::move(a)); test(b); - } - }, - EXPECT_IS_RIGHT(MovableOnly(3)) - ); + }}, + EXPECT_IS_RIGHT(MovableOnly(3))); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenRightMoveConstructedFromValue_thenOldIsCorrect) { - test_with_matrix({ - [] (std::function test) { + test_with_matrix( + {[](std::function test) { MovableOnly a(3); either b(std::move(a)); - test(a); // NOLINT(bugprone-use-after-move) - } - }, - EXPECT_IS(MovableOnly(0)) // 0 is moved-from value + test(a); // NOLINT(bugprone-use-after-move) + }}, + EXPECT_IS(MovableOnly(0)) // 0 is moved-from value ); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenLeftCopyAssignedFromValue_thenNewIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { - string a = "4"; - either b(2); - b = a; - test(b); - }, [] (std::function&)> test) { - string a = "4"; - either b("2"); - b = a; - test(b); - } - }, - EXPECT_IS_LEFT("4") - ); + test_with_matrix( + {[](std::function&)> test) { + string a = "4"; + either b(2); + b = a; + test(b); + }, + [](std::function&)> test) { + string a = "4"; + either b("2"); + b = a; + test(b); + }}, + EXPECT_IS_LEFT("4")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenLeftCopyAssignedFromValue_thenOldIsCorrect) { - test_with_matrix({ - [] (std::function test) { - string a = "4"; - either b(2); - b = a; - test(a); - }, [] (std::function test) { - string a = "4"; - either b("2"); - b = a; - test(a); - } - }, 
- EXPECT_IS("4") - ); + test_with_matrix( + {[](std::function test) { + string a = "4"; + either b(2); + b = a; + test(a); + }, + [](std::function test) { + string a = "4"; + either b("2"); + b = a; + test(a); + }}, + EXPECT_IS("4")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenRightCopyAssignedFromValue_thenNewIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { - string a = "4"; - either b(2); - b = a; - test(b); - }, [] (std::function&)> test) { - string a = "4"; - either b("2"); - b = a; - test(b); - } - }, - EXPECT_IS_RIGHT("4") - ); + test_with_matrix( + {[](std::function&)> test) { + string a = "4"; + either b(2); + b = a; + test(b); + }, + [](std::function&)> test) { + string a = "4"; + either b("2"); + b = a; + test(b); + }}, + EXPECT_IS_RIGHT("4")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenRightCopyAssignedFromValue_thenOldIsCorrect) { - test_with_matrix({ - [] (std::function test) { - string a = "4"; - either b(2); - b = a; - test(a); - }, [] (std::function test) { - string a = "4"; - either b("2"); - b = a; - test(a); - } - }, - EXPECT_IS("4") - ); + test_with_matrix( + {[](std::function test) { + string a = "4"; + either b(2); + b = a; + test(a); + }, + [](std::function test) { + string a = "4"; + either b("2"); + b = a; + test(a); + }}, + EXPECT_IS("4")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenLeftMoveAssignedFromValue_thenNewIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { - MovableOnly a(3); - either b(2); - b = std::move(a); - test(b); - }, [] (std::function&)> test) { - MovableOnly a(3); - either b(MovableOnly(2)); - b = std::move(a); - test(b); - } - }, - EXPECT_IS_LEFT(MovableOnly(3)) - ); + test_with_matrix( + {[](std::function&)> test) { + MovableOnly a(3); + either b(2); + b = std::move(a); + test(b); + }, + [](std::function&)> test) { + MovableOnly a(3); + either b(MovableOnly(2)); + b = std::move(a); + test(b); + }}, + EXPECT_IS_LEFT(MovableOnly(3))); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenLeftMoveAssignedFromValue_thenOldIsCorrect) { - test_with_matrix({ - [] (std::function test) { - MovableOnly a(3); - either b("2"); - b = std::move(a); - test(a); // NOLINT(bugprone-use-after-move) - }, [] (std::function test) { - MovableOnly a(3); - either b(MovableOnly(0)); - b = std::move(a); - test(a); // NOLINT(bugprone-use-after-move) - } - }, - EXPECT_IS(MovableOnly(0)) - ); + test_with_matrix( + {[](std::function test) { + MovableOnly a(3); + either b("2"); + b = std::move(a); + test(a); // NOLINT(bugprone-use-after-move) + }, + [](std::function test) { + MovableOnly a(3); + either b(MovableOnly(0)); + b = std::move(a); + test(a); // NOLINT(bugprone-use-after-move) + }}, + EXPECT_IS(MovableOnly(0))); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenRightMoveAssignedFromValue_thenNewIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { - MovableOnly a(3); - either b("2"); - b = std::move(a); - test(b); - }, [] (std::function&)> test) { - MovableOnly a(3); - either b(MovableOnly(2)); - b = std::move(a); - test(b); - } - }, - EXPECT_IS_RIGHT(MovableOnly(3)) - ); + test_with_matrix( + {[](std::function&)> test) { + MovableOnly a(3); + either b("2"); + b = std::move(a); + test(b); + }, + [](std::function&)> test) { + MovableOnly a(3); + either b(MovableOnly(2)); + b = std::move(a); + 
test(b); + }}, + EXPECT_IS_RIGHT(MovableOnly(3))); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenRightMoveAssignedFromValue_thenOldIsCorrect) { - test_with_matrix({ - [] (std::function test) { - MovableOnly a(3); - either b("2"); - b = std::move(a); - test(a); // NOLINT(bugprone-use-after-move) - }, [] (std::function test) { - MovableOnly a(3); - either b(MovableOnly(2)); - b = std::move(a); - test(a); // NOLINT(bugprone-use-after-move) - } - }, - EXPECT_IS(MovableOnly(0)) // 0 is moved-from value + test_with_matrix( + {[](std::function test) { + MovableOnly a(3); + either b("2"); + b = std::move(a); + test(a); // NOLINT(bugprone-use-after-move) + }, + [](std::function test) { + MovableOnly a(3); + either b(MovableOnly(2)); + b = std::move(a); + test(a); // NOLINT(bugprone-use-after-move) + }}, + EXPECT_IS(MovableOnly(0)) // 0 is moved-from value ); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenLeftCopyConstructed_thenNewIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { + test_with_matrix( + {[](std::function&)> test) { either a("4"); // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) either b(a); test(b); - } - }, - EXPECT_IS_LEFT("4") - ); + }}, + EXPECT_IS_LEFT("4")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenLeftCopyConstructed_thenOldIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { + test_with_matrix( + {[](std::function&)> test) { either a("4"); // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) either b(a); test(a); - } - }, - EXPECT_IS_LEFT("4") - ); + }}, + EXPECT_IS_LEFT("4")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenLeftCopyConstructed_withSameType_thenNewIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { + test_with_matrix( + {[](std::function&)> test) { either a = make_left("4"); // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) either b(a); test(b); - } - }, - EXPECT_IS_LEFT("4") - ); + }}, + EXPECT_IS_LEFT("4")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenLeftCopyConstructed_withSameType_thenOldIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { + test_with_matrix( + {[](std::function&)> test) { either a = make_left("4"); // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) either b(a); test(a); - } - }, - EXPECT_IS_LEFT("4") - ); + }}, + EXPECT_IS_LEFT("4")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenRightCopyConstructed_thenNewIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { + test_with_matrix( + {[](std::function&)> test) { either a("4"); // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) either b(a); test(b); - } - }, - EXPECT_IS_RIGHT("4") - ); + }}, + EXPECT_IS_RIGHT("4")); } - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenRightCopyConstructed_thenOldIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { + test_with_matrix( + {[](std::function&)> test) { either a("4"); // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) either b(a); test(a); - } - }, - EXPECT_IS_RIGHT("4") - ); + }}, + EXPECT_IS_RIGHT("4")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenRightCopyConstructed_withSameType_thenNewIsCorrect) { - test_with_matrix({ 
- [] (std::function&)> test) { + test_with_matrix( + {[](std::function&)> test) { either a = make_right("4"); // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) either b(a); test(b); - } - }, - EXPECT_IS_RIGHT("4") - ); + }}, + EXPECT_IS_RIGHT("4")); } - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenRightCopyConstructed_withSameType_thenOldIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { + test_with_matrix( + {[](std::function&)> test) { either a = make_right("4"); // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) either b(a); test(a); - } - }, - EXPECT_IS_RIGHT("4") - ); + }}, + EXPECT_IS_RIGHT("4")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenLeftMoveConstructed_thenNewIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { + test_with_matrix( + {[](std::function&)> test) { either a(MovableOnly(3)); either b(std::move(a)); test(b); - } - }, - EXPECT_IS_LEFT(MovableOnly(3)) - ); + }}, + EXPECT_IS_LEFT(MovableOnly(3))); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenLeftMoveConstructed_thenOldIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { + test_with_matrix( + {[](std::function&)> test) { either a(MovableOnly(3)); either b(std::move(a)); - test(a); // NOLINT(bugprone-use-after-move) - } - }, - EXPECT_IS_LEFT(MovableOnly(0)) // 0 is moved-from value + test(a); // NOLINT(bugprone-use-after-move) + }}, + EXPECT_IS_LEFT(MovableOnly(0)) // 0 is moved-from value ); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenLeftMoveConstructed_withSameType_thenNewIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { - either a = make_left(MovableOnly(3)); + test_with_matrix( + {[](std::function&)> test) { + either a = + make_left(MovableOnly(3)); either b(std::move(a)); test(b); - } - }, - EXPECT_IS_LEFT(MovableOnly(3)) - ); + }}, + EXPECT_IS_LEFT(MovableOnly(3))); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenLeftMoveConstructed_withSameType_thenOldIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { - either a = make_left(MovableOnly(3)); + test_with_matrix( + {[](std::function&)> test) { + either a = + make_left(MovableOnly(3)); either b(std::move(a)); - test(a); // NOLINT(bugprone-use-after-move) - } - }, - EXPECT_IS_LEFT(MovableOnly(0)) // 0 is moved-from value + test(a); // NOLINT(bugprone-use-after-move) + }}, + EXPECT_IS_LEFT( + MovableOnly(0)) // 0 is moved-from value ); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenRightMoveConstructed_thenNewIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { + test_with_matrix( + {[](std::function&)> test) { either a(MovableOnly(3)); either b(std::move(a)); test(b); - } - }, - EXPECT_IS_RIGHT(MovableOnly(3)) - ); + }}, + EXPECT_IS_RIGHT(MovableOnly(3))); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenRightMoveConstructed_thenOldIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { + test_with_matrix( + {[](std::function&)> test) { either a(MovableOnly(3)); either b(std::move(a)); - test(a); // NOLINT(bugprone-use-after-move) - } - }, - EXPECT_IS_RIGHT(MovableOnly(0)) // 0 is moved-from value + test(a); // NOLINT(bugprone-use-after-move) + }}, + EXPECT_IS_RIGHT(MovableOnly(0)) // 0 is moved-from value ); } // 
NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenRightMoveConstructed_withSameType_thenNewIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { - either a = make_right(MovableOnly(3)); + test_with_matrix( + {[](std::function&)> test) { + either a = + make_right(MovableOnly(3)); either b(std::move(a)); test(b); - } - }, - EXPECT_IS_RIGHT(MovableOnly(3)) - ); + }}, + EXPECT_IS_RIGHT(MovableOnly(3))); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenRightMoveConstructed_withSameType_thenOldIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { - either a = make_right(MovableOnly(3)); + test_with_matrix( + {[](std::function&)> test) { + either a = + make_right(MovableOnly(3)); either b(std::move(a)); - test(a); // NOLINT(bugprone-use-after-move) - } - }, - EXPECT_IS_RIGHT(MovableOnly(0)) // 0 is moved-from value + test(a); // NOLINT(bugprone-use-after-move) + }}, + EXPECT_IS_RIGHT( + MovableOnly(0)) // 0 is moved-from value ); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenLeftCopyAssigned_thenNewIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { - either a("4"); - either b(2); - b = a; - test(b); - }, [] (std::function&)> test) { - either a("4"); - either b("2"); - b = a; - test(b); - } - }, - EXPECT_IS_LEFT("4") - ); + test_with_matrix( + {[](std::function&)> test) { + either a("4"); + either b(2); + b = a; + test(b); + }, + [](std::function&)> test) { + either a("4"); + either b("2"); + b = a; + test(b); + }}, + EXPECT_IS_LEFT("4")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenLeftCopyAssigned_thenOldIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { - either a("4"); - either b(2); - b = a; - test(a); - }, [] (std::function&)> test) { - either a("4"); - either b("2"); - b = a; - test(a); - } - }, - EXPECT_IS_LEFT("4") - ); + test_with_matrix( + {[](std::function&)> test) { + either a("4"); + either b(2); + b = a; + test(a); + }, + [](std::function&)> test) { + either a("4"); + either b("2"); + b = a; + test(a); + }}, + EXPECT_IS_LEFT("4")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenLeftCopyAssigned_withSameType_thenNewIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { - either a = make_left("4"); - either b = make_right("2"); - b = a; - test(b); - }, [] (std::function&)> test) { - either a = make_left("4"); - either b = make_left("2"); - b = a; - test(b); - } - }, - EXPECT_IS_LEFT("4") - ); + test_with_matrix( + {[](std::function&)> test) { + either a = make_left("4"); + either b = make_right("2"); + b = a; + test(b); + }, + [](std::function&)> test) { + either a = make_left("4"); + either b = make_left("2"); + b = a; + test(b); + }}, + EXPECT_IS_LEFT("4")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenLeftCopyAssigned_withSameType_thenOldIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { - either a = make_left("4"); - either b = make_right("2"); - b = a; - test(a); - }, [] (std::function&)> test) { - either a = make_left("4"); - either b = make_left("2"); - b = a; - test(a); - } - }, - EXPECT_IS_LEFT("4") - ); + test_with_matrix( + {[](std::function&)> test) { + either a = make_left("4"); + either b = make_right("2"); + b = a; + test(a); + }, + [](std::function&)> test) { + either a = make_left("4"); + either b = make_left("2"); + b = 
a; + test(a); + }}, + EXPECT_IS_LEFT("4")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenRightCopyAssigned_thenNewIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { - either a("4"); - either b(2); - b = a; - test(b); - }, [] (std::function&)> test) { - either a("4"); - either b("2"); - b = a; - test(b); - } - }, - EXPECT_IS_RIGHT("4") - ); + test_with_matrix( + {[](std::function&)> test) { + either a("4"); + either b(2); + b = a; + test(b); + }, + [](std::function&)> test) { + either a("4"); + either b("2"); + b = a; + test(b); + }}, + EXPECT_IS_RIGHT("4")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenRightCopyAssigned_thenOldIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { - either a("4"); - either b(2); - b = a; - test(a); - }, [] (std::function&)> test) { - either a("4"); - either b("2"); - b = a; - test(a); - } - }, - EXPECT_IS_RIGHT("4") - ); + test_with_matrix( + {[](std::function&)> test) { + either a("4"); + either b(2); + b = a; + test(a); + }, + [](std::function&)> test) { + either a("4"); + either b("2"); + b = a; + test(a); + }}, + EXPECT_IS_RIGHT("4")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenRightCopyAssigned_withSameType_thenNewIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { - either a = make_right("4"); - either b = make_left("2"); - b = a; - test(b); - }, [] (std::function&)> test) { - either a = make_right("4"); - either b = make_right("2"); - b = a; - test(b); - } - }, - EXPECT_IS_RIGHT("4") - ); + test_with_matrix( + {[](std::function&)> test) { + either a = make_right("4"); + either b = make_left("2"); + b = a; + test(b); + }, + [](std::function&)> test) { + either a = make_right("4"); + either b = make_right("2"); + b = a; + test(b); + }}, + EXPECT_IS_RIGHT("4")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenRightCopyAssigned_withSameType_thenOldIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { - either a = make_right("4"); - either b = make_left("2"); - b = a; - test(a); - }, [] (std::function&)> test) { - either a = make_right("4"); - either b = make_right("2"); - b = a; - test(a); - } - }, - EXPECT_IS_RIGHT("4") - ); + test_with_matrix( + {[](std::function&)> test) { + either a = make_right("4"); + either b = make_left("2"); + b = a; + test(a); + }, + [](std::function&)> test) { + either a = make_right("4"); + either b = make_right("2"); + b = a; + test(a); + }}, + EXPECT_IS_RIGHT("4")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenLeftMoveAssigned_thenNewIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { - either a(MovableOnly(3)); - either b(2); - b = std::move(a); - test(b); - }, [] (std::function&)> test) { - either a(MovableOnly(3)); - either b(MovableOnly(2)); - b = std::move(a); - test(b); - } - }, - EXPECT_IS_LEFT(MovableOnly(3)) - ); + test_with_matrix( + {[](std::function&)> test) { + either a(MovableOnly(3)); + either b(2); + b = std::move(a); + test(b); + }, + [](std::function&)> test) { + either a(MovableOnly(3)); + either b(MovableOnly(2)); + b = std::move(a); + test(b); + }}, + EXPECT_IS_LEFT(MovableOnly(3))); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenLeftMoveAssigned_thenOldIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { - either a(MovableOnly(3)); - either 
b(2); - b = std::move(a); - test(a); // NOLINT(bugprone-use-after-move) - }, [] (std::function&)> test) { - either a(MovableOnly(3)); - either b(MovableOnly(2)); - b = std::move(a); - test(a); // NOLINT(bugprone-use-after-move) - } - }, - EXPECT_IS_LEFT(MovableOnly(0)) // 0 is moved-from value + test_with_matrix( + {[](std::function&)> test) { + either a(MovableOnly(3)); + either b(2); + b = std::move(a); + test(a); // NOLINT(bugprone-use-after-move) + }, + [](std::function&)> test) { + either a(MovableOnly(3)); + either b(MovableOnly(2)); + b = std::move(a); + test(a); // NOLINT(bugprone-use-after-move) + }}, + EXPECT_IS_LEFT( + MovableOnly(0)) // 0 is moved-from value ); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenLeftMoveAssigned_withSameType_thenNewIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { - either a = make_left(3); - either b = make_right(2); - b = std::move(a); - test(b); - }, [] (std::function&)> test) { - either a = make_left(3); - either b = make_left(2); - b = std::move(a); - test(b); - } - }, - EXPECT_IS_LEFT(MovableOnly(3)) - ); + test_with_matrix( + {[](std::function&)> test) { + either a = + make_left(3); + either b = + make_right(2); + b = std::move(a); + test(b); + }, + [](std::function&)> test) { + either a = + make_left(3); + either b = + make_left(2); + b = std::move(a); + test(b); + }}, + EXPECT_IS_LEFT(MovableOnly(3))); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenLeftMoveAssigned_withSameType_thenOldIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { - either a = make_left(3); - either b = make_right(2); - b = std::move(a); - test(a); // NOLINT(bugprone-use-after-move) - }, [] (std::function&)> test) { - either a = make_left(3); - either b = make_left(2); - b = std::move(a); - test(a); // NOLINT(bugprone-use-after-move) - } - }, - EXPECT_IS_LEFT(MovableOnly(0)) // 0 is moved-from value + test_with_matrix( + {[](std::function&)> test) { + either a = + make_left(3); + either b = + make_right(2); + b = std::move(a); + test(a); // NOLINT(bugprone-use-after-move) + }, + [](std::function&)> test) { + either a = + make_left(3); + either b = + make_left(2); + b = std::move(a); + test(a); // NOLINT(bugprone-use-after-move) + }}, + EXPECT_IS_LEFT( + MovableOnly(0)) // 0 is moved-from value ); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenRightMoveAssigned_thenNewIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { - either a(MovableOnly(3)); - either b("2"); - b = std::move(a); - test(b); - }, [] (std::function&)> test) { - either a(MovableOnly(3)); - either b(MovableOnly(2)); - b = std::move(a); - test(b); - } - }, - EXPECT_IS_RIGHT(MovableOnly(3)) - ); + test_with_matrix( + {[](std::function&)> test) { + either a(MovableOnly(3)); + either b("2"); + b = std::move(a); + test(b); + }, + [](std::function&)> test) { + either a(MovableOnly(3)); + either b(MovableOnly(2)); + b = std::move(a); + test(b); + }}, + EXPECT_IS_RIGHT(MovableOnly(3))); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenRightMoveAssigned_thenOldIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { - either a(MovableOnly(3)); - either b("2"); - b = std::move(a); - test(a); // NOLINT(bugprone-use-after-move) - }, [] (std::function&)> test) { - either a(MovableOnly(3)); - either b(MovableOnly(2)); - b = std::move(a); - test(a); // NOLINT(bugprone-use-after-move) - } 
- }, - EXPECT_IS_RIGHT(MovableOnly(0)) // 0 is moved-from value + test_with_matrix( + {[](std::function&)> test) { + either a(MovableOnly(3)); + either b("2"); + b = std::move(a); + test(a); // NOLINT(bugprone-use-after-move) + }, + [](std::function&)> test) { + either a(MovableOnly(3)); + either b(MovableOnly(2)); + b = std::move(a); + test(a); // NOLINT(bugprone-use-after-move) + }}, + EXPECT_IS_RIGHT( + MovableOnly(0)) // 0 is moved-from value ); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenRightMoveAssigned_withSameType_thenNewIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { - either a = make_right(3); - either b = make_left(2); - b = std::move(a); - test(b); - }, [] (std::function&)> test) { - either a = make_right(3); - either b = make_right(2); - b = std::move(a); - test(b); - } - }, - EXPECT_IS_RIGHT(MovableOnly(3)) - ); + test_with_matrix( + {[](std::function&)> test) { + either a = + make_right(3); + either b = + make_left(2); + b = std::move(a); + test(b); + }, + [](std::function&)> test) { + either a = + make_right(3); + either b = + make_right(2); + b = std::move(a); + test(b); + }}, + EXPECT_IS_RIGHT(MovableOnly(3))); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenRightMoveAssigned_withSameType_thenOldIsCorrect) { - test_with_matrix({ - [] (std::function&)> test) { - either a = make_right(3); - either b = make_left(2); - b = std::move(a); - test(a); // NOLINT(bugprone-use-after-move) - }, [] (std::function&)> test) { - either a = make_right(3); - either b = make_right(2); - b = std::move(a); - test(a); // NOLINT(bugprone-use-after-move) - } - }, - EXPECT_IS_RIGHT(MovableOnly(0)) // 0 is moved-from value + test_with_matrix( + {[](std::function&)> test) { + either a = + make_right(3); + either b = + make_left(2); + b = std::move(a); + test(a); // NOLINT(bugprone-use-after-move) + }, + [](std::function&)> test) { + either a = + make_right(3); + either b = + make_right(2); + b = std::move(a); + test(a); // NOLINT(bugprone-use-after-move) + }}, + EXPECT_IS_RIGHT( + MovableOnly(0)) // 0 is moved-from value ); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenLeft_whenModified_thenValueIsChanged) { - test_with_matrix({ - [] (std::function&)> test) { - either a(4); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - a.left() = 5; - test(a); - }, [] (std::function&)> test) { - either a(4); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - a.left() = 5; - test(a); - } - }, - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - EXPECT_IS_LEFT(5) - ); + test_with_matrix( + {[](std::function&)> test) { + either a(4); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + a.left() = 5; + test(a); + }, + [](std::function&)> test) { + either a(4); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + a.left() = 5; + test(a); + }}, + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + EXPECT_IS_LEFT(5)); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, givenRight_whenModified_thenValueIsChanged) { - test_with_matrix({ - [] (std::function&)> test) { - either a("4"); - a.right() = "5"; - test(a); - }, [] (std::function&)> test) { - either a("4"); - a.right() = "5"; - test(a); - } - }, - EXPECT_IS_RIGHT("5") - ); + test_with_matrix( + {[](std::function&)> test) { + either a("4"); + a.right() = "5"; + test(a); + }, + [](std::function&)> test) { + either 
a("4"); + a.right() = "5"; + test(a); + }}, + EXPECT_IS_RIGHT("5")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, canEmplaceConstructLeft) { - test_with_matrix({ - [] (std::function, tuple>&)> test) { + test_with_matrix( + {[](std::function, tuple>&)> + test) { either, tuple> a(2, 3); test(a); - } - }, - EXPECT_IS_LEFT, tuple>(tuple(2, 3)) - ); + }}, + EXPECT_IS_LEFT, tuple>( + tuple(2, 3))); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest, canEmplaceConstructRight) { - test_with_matrix({ - [] (std::function, tuple>&)> test) { + test_with_matrix( + {[](std::function, tuple>&)> + test) { either, tuple> a(2, "3", 4); test(a); - } - }, - EXPECT_IS_RIGHT, tuple>(tuple(2, "3", 4)) - ); + }}, + EXPECT_IS_RIGHT, tuple>( + tuple(2, "3", 4))); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) @@ -1195,10 +1157,9 @@ TEST(EitherTest, givenLeftAndRightWithSameType_thenAreUnequal) { EXPECT_TRUE(b != a); } - namespace { class DestructorCallback { -public: + public: MOCK_CONST_METHOD0(call, void()); void EXPECT_CALLED(int times = 1) { @@ -1206,149 +1167,182 @@ public: } }; class ClassWithDestructorCallback { -public: - ClassWithDestructorCallback(const DestructorCallback *destructorCallback) : _destructorCallback(destructorCallback) {} + public: + ClassWithDestructorCallback(const DestructorCallback* destructorCallback) + : _destructorCallback(destructorCallback) {} // NOLINTNEXTLINE(modernize-use-equals-default) - ClassWithDestructorCallback(const ClassWithDestructorCallback &rhs): _destructorCallback(rhs._destructorCallback) {} + ClassWithDestructorCallback(const ClassWithDestructorCallback& rhs) + : _destructorCallback(rhs._destructorCallback) {} ~ClassWithDestructorCallback() { _destructorCallback->call(); } -private: - const DestructorCallback *_destructorCallback; + private: + const DestructorCallback* _destructorCallback; // NOLINTNEXTLINE(modernize-use-equals-delete) - ClassWithDestructorCallback &operator=(const ClassWithDestructorCallback &rhs) = delete; + ClassWithDestructorCallback& operator=( + const ClassWithDestructorCallback& rhs) = delete; }; class OnlyMoveableClassWithDestructorCallback { -public: - OnlyMoveableClassWithDestructorCallback(const DestructorCallback *destructorCallback) : _destructorCallback(destructorCallback) { } - OnlyMoveableClassWithDestructorCallback(OnlyMoveableClassWithDestructorCallback &&source): _destructorCallback(source._destructorCallback) {} + public: + OnlyMoveableClassWithDestructorCallback( + const DestructorCallback* destructorCallback) + : _destructorCallback(destructorCallback) {} + OnlyMoveableClassWithDestructorCallback( + OnlyMoveableClassWithDestructorCallback&& source) + : _destructorCallback(source._destructorCallback) {} ~OnlyMoveableClassWithDestructorCallback() { _destructorCallback->call(); } -private: + private: C10_DISABLE_COPY_AND_ASSIGN(OnlyMoveableClassWithDestructorCallback); - const DestructorCallback *_destructorCallback; + const DestructorCallback* _destructorCallback; }; -} +} // namespace // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest_Destructor, LeftDestructorIsCalled) { - DestructorCallback destructorCallback; - destructorCallback.EXPECT_CALLED(2); //Once for the temp object, once when the either class destructs + DestructorCallback destructorCallback; + destructorCallback.EXPECT_CALLED( + 2); // Once for the temp object, once when the either class destructs - ClassWithDestructorCallback 
temp(&destructorCallback); - either var = temp; + ClassWithDestructorCallback temp(&destructorCallback); + either var = temp; } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest_Destructor, RightDestructorIsCalled) { - DestructorCallback destructorCallback; - destructorCallback.EXPECT_CALLED(2); //Once for the temp object, once when the either class destructs + DestructorCallback destructorCallback; + destructorCallback.EXPECT_CALLED( + 2); // Once for the temp object, once when the either class destructs - ClassWithDestructorCallback temp(&destructorCallback); - either var = temp; + ClassWithDestructorCallback temp(&destructorCallback); + either var = temp; } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest_Destructor, LeftDestructorIsCalledAfterCopying) { - DestructorCallback destructorCallback; - destructorCallback.EXPECT_CALLED(3); //Once for the temp object, once for var1 and once for var2 + DestructorCallback destructorCallback; + destructorCallback.EXPECT_CALLED( + 3); // Once for the temp object, once for var1 and once for var2 - ClassWithDestructorCallback temp(&destructorCallback); - either var1 = temp; - // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) - either var2 = var1; + ClassWithDestructorCallback temp(&destructorCallback); + either var1 = temp; + // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) + either var2 = var1; } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest_Destructor, RightDestructorIsCalledAfterCopying) { - DestructorCallback destructorCallback; - destructorCallback.EXPECT_CALLED(3); //Once for the temp object, once for var1 and once for var2 + DestructorCallback destructorCallback; + destructorCallback.EXPECT_CALLED( + 3); // Once for the temp object, once for var1 and once for var2 - ClassWithDestructorCallback temp(&destructorCallback); - either var1 = temp; - // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) - either var2 = var1; + ClassWithDestructorCallback temp(&destructorCallback); + either var1 = temp; + // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) + either var2 = var1; } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest_Destructor, LeftDestructorIsCalledAfterMoving) { - DestructorCallback destructorCallback; - destructorCallback.EXPECT_CALLED(3); //Once for the temp object, once for var1 and once for var2 + DestructorCallback destructorCallback; + destructorCallback.EXPECT_CALLED( + 3); // Once for the temp object, once for var1 and once for var2 - OnlyMoveableClassWithDestructorCallback temp(&destructorCallback); - either var1 = std::move(temp); - either var2 = std::move(var1); + OnlyMoveableClassWithDestructorCallback temp(&destructorCallback); + either var1 = + std::move(temp); + either var2 = + std::move(var1); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest_Destructor, RightDestructorIsCalledAfterMoving) { - DestructorCallback destructorCallback; - destructorCallback.EXPECT_CALLED(3); //Once for the temp object, once for var1 and once for var2 + DestructorCallback destructorCallback; + destructorCallback.EXPECT_CALLED( + 3); // Once for the temp object, once for var1 and once for var2 - OnlyMoveableClassWithDestructorCallback temp(&destructorCallback); - either var1 = std::move(temp); - either var2 = std::move(var1); + OnlyMoveableClassWithDestructorCallback temp(&destructorCallback); + either var1 = + 
std::move(temp); + either var2 = + std::move(var1); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest_Destructor, LeftDestructorIsCalledAfterAssignment) { - DestructorCallback destructorCallback1; - DestructorCallback destructorCallback2; - destructorCallback1.EXPECT_CALLED(2); //Once for the temp1 object, once at the assignment - destructorCallback2.EXPECT_CALLED(3); //Once for the temp2 object, once in destructor of var2, once in destructor of var1 + DestructorCallback destructorCallback1; + DestructorCallback destructorCallback2; + destructorCallback1.EXPECT_CALLED( + 2); // Once for the temp1 object, once at the assignment + destructorCallback2.EXPECT_CALLED( + 3); // Once for the temp2 object, once in destructor of var2, once in + // destructor of var1 - ClassWithDestructorCallback temp1(&destructorCallback1); - either var1 = temp1; - ClassWithDestructorCallback temp2(&destructorCallback2); - either var2 = temp2; - var1 = var2; + ClassWithDestructorCallback temp1(&destructorCallback1); + either var1 = temp1; + ClassWithDestructorCallback temp2(&destructorCallback2); + either var2 = temp2; + var1 = var2; } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest_Destructor, RightDestructorIsCalledAfterAssignment) { - DestructorCallback destructorCallback1; - DestructorCallback destructorCallback2; - destructorCallback1.EXPECT_CALLED(2); //Once for the temp1 object, once at the assignment - destructorCallback2.EXPECT_CALLED(3); //Once for the temp2 object, once in destructor of var2, once in destructor of var1 + DestructorCallback destructorCallback1; + DestructorCallback destructorCallback2; + destructorCallback1.EXPECT_CALLED( + 2); // Once for the temp1 object, once at the assignment + destructorCallback2.EXPECT_CALLED( + 3); // Once for the temp2 object, once in destructor of var2, once in + // destructor of var1 - ClassWithDestructorCallback temp1(&destructorCallback1); - either var1 = temp1; - ClassWithDestructorCallback temp2(&destructorCallback2); - either var2 = temp2; - var1 = var2; + ClassWithDestructorCallback temp1(&destructorCallback1); + either var1 = temp1; + ClassWithDestructorCallback temp2(&destructorCallback2); + either var2 = temp2; + var1 = var2; } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(EitherTest_Destructor, LeftDestructorIsCalledAfterMoveAssignment) { - DestructorCallback destructorCallback1; - DestructorCallback destructorCallback2; - destructorCallback1.EXPECT_CALLED(2); //Once for the temp1 object, once at the assignment - destructorCallback2.EXPECT_CALLED(3); //Once for the temp2 object, once in destructor of var2, once in destructor of var1 + DestructorCallback destructorCallback1; + DestructorCallback destructorCallback2; + destructorCallback1.EXPECT_CALLED( + 2); // Once for the temp1 object, once at the assignment + destructorCallback2.EXPECT_CALLED( + 3); // Once for the temp2 object, once in destructor of var2, once in + // destructor of var1 - OnlyMoveableClassWithDestructorCallback temp1(&destructorCallback1); - either var1 = std::move(temp1); - OnlyMoveableClassWithDestructorCallback temp2(&destructorCallback2); - either var2 = std::move(temp2); - var1 = std::move(var2); + OnlyMoveableClassWithDestructorCallback temp1(&destructorCallback1); + either var1 = + std::move(temp1); + OnlyMoveableClassWithDestructorCallback temp2(&destructorCallback2); + either var2 = + std::move(temp2); + var1 = std::move(var2); } // 
NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(EitherTest_Destructor, RightDestructorIsCalledAfterMoveAssignment) {
- DestructorCallback destructorCallback1;
- DestructorCallback destructorCallback2;
- destructorCallback1.EXPECT_CALLED(2); //Once for the temp1 object, once at the assignment
- destructorCallback2.EXPECT_CALLED(3); //Once for the temp2 object, once in destructor of var2, once in destructor of var1
+ DestructorCallback destructorCallback1;
+ DestructorCallback destructorCallback2;
+ destructorCallback1.EXPECT_CALLED(
+ 2); // Once for the temp1 object, once at the assignment
+ destructorCallback2.EXPECT_CALLED(
+ 3); // Once for the temp2 object, once in destructor of var2, once in
+ // destructor of var1
- OnlyMoveableClassWithDestructorCallback temp1(&destructorCallback1);
- either var1 = std::move(temp1);
- OnlyMoveableClassWithDestructorCallback temp2(&destructorCallback2);
- either var2 = std::move(temp2);
- var1 = std::move(var2);
+ OnlyMoveableClassWithDestructorCallback temp1(&destructorCallback1);
+ either var1 =
+ std::move(temp1);
+ OnlyMoveableClassWithDestructorCallback temp2(&destructorCallback2);
+ either var2 =
+ std::move(temp2);
+ var1 = std::move(var2);
}
diff --git a/c10/test/util/exception_test.cpp b/c10/test/util/exception_test.cpp
index 04b3c1af9d2..a12b303c565 100644
--- a/c10/test/util/exception_test.cpp
+++ b/c10/test/util/exception_test.cpp
@@ -9,7 +9,7 @@ bool throw_func() {
throw std::runtime_error("I'm throwing...");
}
-template
+template
inline void expectThrowsEq(Functor&& functor, const char* expectedMessage) {
try {
std::forward(functor)();
@@ -18,7 +18,7 @@ inline void expectThrowsEq(Functor&& functor, const char* expectedMessage) {
return;
}
ADD_FAILURE() << "Expected to throw exception with message \""
- << expectedMessage << "\" but didn't throw";
+ << expectedMessage << "\" but didn't throw";
}
} // namespace
@@ -43,30 +43,32 @@ TEST(WarningTest, JustPrintWarning) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(ExceptionTest, ErrorFormatting) {
- expectThrowsEq([]() {
- TORCH_CHECK(false, "This is invalid");
- }, "This is invalid");
+ expectThrowsEq(
+ []() { TORCH_CHECK(false, "This is invalid"); }, "This is invalid");
- expectThrowsEq([]() {
- try {
- TORCH_CHECK(false, "This is invalid");
- } catch (Error& e) {
- TORCH_RETHROW(e, "While checking X");
- }
- }, "This is invalid (While checking X)");
+ expectThrowsEq(
+ []() {
+ try {
+ TORCH_CHECK(false, "This is invalid");
+ } catch (Error& e) {
+ TORCH_RETHROW(e, "While checking X");
+ }
+ },
+ "This is invalid (While checking X)");
- expectThrowsEq([]() {
- try {
- try {
- TORCH_CHECK(false, "This is invalid");
- } catch (Error& e) {
- TORCH_RETHROW(e, "While checking X");
- }
- } catch (Error& e) {
- TORCH_RETHROW(e, "While checking Y");
- }
- },
-R"msg(This is invalid
+ expectThrowsEq(
+ []() {
+ try {
+ try {
+ TORCH_CHECK(false, "This is invalid");
+ } catch (Error& e) {
+ TORCH_RETHROW(e, "While checking X");
+ }
+ } catch (Error& e) {
+ TORCH_RETHROW(e, "While checking Y");
+ }
+ },
+ R"msg(This is invalid
While checking X
While checking Y)msg");
}
@@ -94,5 +96,6 @@ TEST(ExceptionTest, DontCallArgumentFunctionsTwiceOnFailure) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_ANY_THROW(failInternalAssert());
- EXPECT_EQ(assertionArgumentCounter, 2) << "TORCH_INTERNAL_ASSERT called argument twice";
+ EXPECT_EQ(assertionArgumentCounter, 2)
+ << "TORCH_INTERNAL_ASSERT called argument twice";
}
diff --git 
a/c10/test/util/intrusive_ptr_test.cpp b/c10/test/util/intrusive_ptr_test.cpp index a47fdd40d11..a17582f8d7c 100644 --- a/c10/test/util/intrusive_ptr_test.cpp +++ b/c10/test/util/intrusive_ptr_test.cpp @@ -67,7 +67,8 @@ class ChildDestructableMock final : public DestructableMock { class NullType1 final { // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) static SomeClass singleton_; -public: + + public: static constexpr SomeClass* singleton() { return &singleton_; } @@ -77,7 +78,8 @@ SomeClass NullType1::singleton_; class NullType2 final { // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) static SomeClass singleton_; -public: + + public: static constexpr SomeClass* singleton() { return &singleton_; } @@ -934,7 +936,8 @@ TEST(IntrusivePtrTest, givenPtr_whenDestructed_thenDestructsObject) { bool resourcesReleased = false; bool wasDestructed = false; { - auto obj = make_intrusive(&resourcesReleased, &wasDestructed); + auto obj = + make_intrusive(&resourcesReleased, &wasDestructed); EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); } @@ -948,7 +951,8 @@ TEST( givenPtr_whenMoveConstructed_thenDestructsObjectAfterSecondDestructed) { bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_intrusive(&resourcesReleased, &wasDestructed); + auto obj = + make_intrusive(&resourcesReleased, &wasDestructed); { auto obj2 = std::move(obj); EXPECT_FALSE(resourcesReleased); @@ -964,7 +968,8 @@ TEST( givenPtr_whenMoveConstructedToBaseClass_thenDestructsObjectAfterSecondDestructed) { bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_intrusive(&resourcesReleased, &wasDestructed); + auto obj = + make_intrusive(&resourcesReleased, &wasDestructed); { intrusive_ptr obj2 = std::move(obj); EXPECT_FALSE(resourcesReleased); @@ -981,7 +986,8 @@ TEST(IntrusivePtrTest, givenPtr_whenMoveAssigned_thenDestructsOldObject) { bool wasDestructed = false; auto obj = make_intrusive(&dummy, &dummy); { - auto obj2 = make_intrusive(&resourcesReleased, &wasDestructed); + auto obj2 = + make_intrusive(&resourcesReleased, &wasDestructed); EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); obj2 = std::move(obj); @@ -999,7 +1005,8 @@ TEST( bool wasDestructed = false; auto obj = make_intrusive(&dummy, &dummy); { - auto obj2 = make_intrusive(&resourcesReleased, &wasDestructed); + auto obj2 = + make_intrusive(&resourcesReleased, &wasDestructed); EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); obj2 = std::move(obj); @@ -1017,7 +1024,8 @@ TEST( bool wasDestructed = false; auto obj = make_intrusive(&dummy, &dummy); { - auto obj2 = make_intrusive(&resourcesReleased, &wasDestructed); + auto obj2 = + make_intrusive(&resourcesReleased, &wasDestructed); { auto copy = obj2; EXPECT_FALSE(resourcesReleased); @@ -1040,8 +1048,8 @@ TEST( bool wasDestructed = false; auto obj = make_intrusive(&dummy, &dummy); { - auto obj2 = - make_intrusive(&resourcesReleased, &wasDestructed); + auto obj2 = make_intrusive( + &resourcesReleased, &wasDestructed); { intrusive_ptr copy = obj2; EXPECT_FALSE(resourcesReleased); @@ -1064,7 +1072,8 @@ TEST( bool wasDestructed = false; auto obj = make_intrusive(&dummy, &dummy); { - auto obj2 = make_intrusive(&resourcesReleased, &wasDestructed); + auto obj2 = + make_intrusive(&resourcesReleased, &wasDestructed); { intrusive_ptr copy = obj2; EXPECT_FALSE(resourcesReleased); @@ -1085,7 +1094,8 @@ TEST( bool dummy = false; bool resourcesReleased = false; bool wasDestructed = false; - auto obj = 
make_intrusive(&resourcesReleased, &wasDestructed); + auto obj = + make_intrusive(&resourcesReleased, &wasDestructed); { auto obj2 = make_intrusive(&dummy, &dummy); obj2 = std::move(obj); @@ -1103,7 +1113,8 @@ TEST( bool dummy = false; bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_intrusive(&resourcesReleased, &wasDestructed); + auto obj = + make_intrusive(&resourcesReleased, &wasDestructed); { auto obj2 = make_intrusive(&dummy, &dummy); obj2 = std::move(obj); @@ -1121,7 +1132,8 @@ TEST( bool resourcesReleased = false; bool wasDestructed = false; { - auto obj = make_intrusive(&resourcesReleased, &wasDestructed); + auto obj = + make_intrusive(&resourcesReleased, &wasDestructed); { // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) intrusive_ptr copy = obj; @@ -1142,7 +1154,8 @@ TEST( bool resourcesReleased = false; bool wasDestructed = false; { - auto obj = make_intrusive(&resourcesReleased, &wasDestructed); + auto obj = make_intrusive( + &resourcesReleased, &wasDestructed); { intrusive_ptr copy = obj; EXPECT_FALSE(resourcesReleased); @@ -1162,7 +1175,8 @@ TEST( bool resourcesReleased = false; bool wasDestructed = false; { - auto obj = make_intrusive(&resourcesReleased, &wasDestructed); + auto obj = + make_intrusive(&resourcesReleased, &wasDestructed); intrusive_ptr copy = obj; obj.reset(); EXPECT_FALSE(resourcesReleased); @@ -1179,7 +1193,8 @@ TEST( bool resourcesReleased = false; bool wasDestructed = false; { - auto obj = make_intrusive(&resourcesReleased, &wasDestructed); + auto obj = make_intrusive( + &resourcesReleased, &wasDestructed); intrusive_ptr copy = obj; obj.reset(); EXPECT_FALSE(resourcesReleased); @@ -1197,7 +1212,8 @@ TEST( bool wasDestructed = false; bool dummy = false; { - auto obj = make_intrusive(&resourcesReleased, &wasDestructed); + auto obj = + make_intrusive(&resourcesReleased, &wasDestructed); { intrusive_ptr copy = make_intrusive(&dummy, &dummy); @@ -1220,7 +1236,8 @@ TEST( bool wasDestructed = false; bool dummy = false; { - auto obj = make_intrusive(&resourcesReleased, &wasDestructed); + auto obj = make_intrusive( + &resourcesReleased, &wasDestructed); { intrusive_ptr copy = make_intrusive(&dummy, &dummy); @@ -1245,7 +1262,8 @@ TEST( { auto copy = make_intrusive(&dummy, &dummy); { - auto obj = make_intrusive(&resourcesReleased, &wasDestructed); + auto obj = + make_intrusive(&resourcesReleased, &wasDestructed); copy = obj; EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); @@ -1267,8 +1285,8 @@ TEST( { auto copy = make_intrusive(&dummy, &dummy); { - auto obj = - make_intrusive(&resourcesReleased, &wasDestructed); + auto obj = make_intrusive( + &resourcesReleased, &wasDestructed); copy = obj; EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); @@ -1287,7 +1305,8 @@ TEST(IntrusivePtrTest, givenPtr_whenCopyAssigned_thenDestructsOldObject) { bool wasDestructed = false; auto obj = make_intrusive(&dummy, &dummy); { - auto obj2 = make_intrusive(&resourcesReleased, &wasDestructed); + auto obj2 = + make_intrusive(&resourcesReleased, &wasDestructed); EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); obj2 = obj; @@ -1305,7 +1324,8 @@ TEST( bool wasDestructed = false; auto obj = make_intrusive(&dummy, &dummy); { - auto obj2 = make_intrusive(&resourcesReleased, &wasDestructed); + auto obj2 = + make_intrusive(&resourcesReleased, &wasDestructed); EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); obj2 = obj; @@ -1323,7 +1343,8 @@ TEST( bool wasDestructed = false; auto obj = 
make_intrusive(&dummy, &dummy); { - auto obj2 = make_intrusive(&resourcesReleased, &wasDestructed); + auto obj2 = + make_intrusive(&resourcesReleased, &wasDestructed); { auto copy = obj2; EXPECT_FALSE(resourcesReleased); @@ -1346,8 +1367,8 @@ TEST( bool wasDestructed = false; auto obj = make_intrusive(&dummy, &dummy); { - auto obj2 = - make_intrusive(&resourcesReleased, &wasDestructed); + auto obj2 = make_intrusive( + &resourcesReleased, &wasDestructed); { intrusive_ptr copy = obj2; EXPECT_FALSE(resourcesReleased); @@ -1370,7 +1391,8 @@ TEST( bool wasDestructed = false; auto obj = make_intrusive(&dummy, &dummy); { - auto obj2 = make_intrusive(&resourcesReleased, &wasDestructed); + auto obj2 = + make_intrusive(&resourcesReleased, &wasDestructed); { intrusive_ptr copy = obj2; EXPECT_FALSE(resourcesReleased); @@ -1388,7 +1410,8 @@ TEST( TEST(IntrusivePtrTest, givenPtr_whenCallingReset_thenDestructs) { bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_intrusive(&resourcesReleased, &wasDestructed); + auto obj = + make_intrusive(&resourcesReleased, &wasDestructed); EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); obj.reset(); @@ -1402,7 +1425,8 @@ TEST( givenPtrWithCopy_whenCallingReset_thenDestructsAfterCopyDestructed) { bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_intrusive(&resourcesReleased, &wasDestructed); + auto obj = + make_intrusive(&resourcesReleased, &wasDestructed); { auto copy = obj; obj.reset(); @@ -1420,7 +1444,8 @@ TEST( givenPtrWithCopy_whenCallingResetOnCopy_thenDestructsAfterOriginalDestructed) { bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_intrusive(&resourcesReleased, &wasDestructed); + auto obj = + make_intrusive(&resourcesReleased, &wasDestructed); { auto copy = obj; copy.reset(); @@ -1438,7 +1463,8 @@ TEST( givenPtrWithMoved_whenCallingReset_thenDestructsAfterMovedDestructed) { bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_intrusive(&resourcesReleased, &wasDestructed); + auto obj = + make_intrusive(&resourcesReleased, &wasDestructed); { auto moved = std::move(obj); // NOLINTNEXTLINE(bugprone-use-after-move) @@ -1457,7 +1483,8 @@ TEST( givenPtrWithMoved_whenCallingResetOnMoved_thenDestructsImmediately) { bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_intrusive(&resourcesReleased, &wasDestructed); + auto obj = + make_intrusive(&resourcesReleased, &wasDestructed); { auto moved = std::move(obj); moved.reset(); @@ -1800,9 +1827,7 @@ TEST(IntrusivePtrTest, givenPtr_whenNonOwningReclaimed_thenDoesntCrash) { } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -TEST( - IntrusivePtrTest, - givenPtr_whenNonOwningReclaimed_thenIsDestructedAtEnd) { +TEST(IntrusivePtrTest, givenPtr_whenNonOwningReclaimed_thenIsDestructedAtEnd) { bool resourcesReleased = false; bool wasDestructed = false; { @@ -1840,7 +1865,9 @@ weak_intrusive_ptr make_weak_only(Args&&... 
args) { auto intrusive = make_intrusive(std::forward(args)...); return weak_intrusive_ptr(intrusive); } -template > +template < + class T, + class NullType = c10::detail::intrusive_target_default_null_type> weak_intrusive_ptr make_invalid_weak() { return weak_intrusive_ptr(intrusive_ptr()); } @@ -1903,9 +1930,7 @@ TEST( } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -TEST( - WeakIntrusivePtrTest, - vector_insert_weak_intrusive) { +TEST(WeakIntrusivePtrTest, vector_insert_weak_intrusive) { std::vector> priorWorks; std::vector> wips; wips.push_back(make_intrusive()); @@ -2139,8 +2164,10 @@ TEST( TEST( WeakIntrusivePtrTest, givenNullPtr_whenMoveAssigningToDifferentNullptr_thenHasNewNullptr) { - weak_intrusive_ptr obj1 = make_invalid_weak(); - weak_intrusive_ptr obj2 = make_invalid_weak(); + weak_intrusive_ptr obj1 = + make_invalid_weak(); + weak_intrusive_ptr obj2 = + make_invalid_weak(); obj2 = std::move(obj1); EXPECT_NE(NullType1::singleton(), NullType2::singleton()); // NOLINTNEXTLINE(bugprone-use-after-move) @@ -2362,8 +2389,10 @@ TEST( TEST( WeakIntrusivePtrTest, givenNullPtr_whenCopyAssigningToDifferentNullptr_thenHasNewNullptr) { - weak_intrusive_ptr obj1 = make_invalid_weak(); - weak_intrusive_ptr obj2 = make_invalid_weak(); + weak_intrusive_ptr obj1 = + make_invalid_weak(); + weak_intrusive_ptr obj2 = + make_invalid_weak(); obj2 = obj1; EXPECT_NE(NullType1::singleton(), NullType2::singleton()); EXPECT_TRUE(obj1.expired()); @@ -2468,7 +2497,8 @@ TEST( TEST( WeakIntrusivePtrTest, givenNullPtr_whenMoveConstructingToDifferentNullptr_thenHasNewNullptr) { - weak_intrusive_ptr obj1 = make_invalid_weak(); + weak_intrusive_ptr obj1 = + make_invalid_weak(); weak_intrusive_ptr obj2 = std::move(obj1); EXPECT_NE(NullType1::singleton(), NullType2::singleton()); // NOLINTNEXTLINE(bugprone-use-after-move) @@ -2575,7 +2605,8 @@ TEST( TEST( WeakIntrusivePtrTest, givenNullPtr_whenCopyConstructingToDifferentNullptr_thenHasNewNullptr) { - weak_intrusive_ptr obj1 = make_invalid_weak(); + weak_intrusive_ptr obj1 = + make_invalid_weak(); weak_intrusive_ptr obj2 = obj1; EXPECT_NE(NullType1::singleton(), NullType2::singleton()); EXPECT_TRUE(obj1.expired()); @@ -3175,7 +3206,8 @@ TEST( givenPtr_whenLastStrongPointerResets_thenReleasesResources) { bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_weak_intrusive(&resourcesReleased, &wasDestructed); + auto obj = + make_weak_intrusive(&resourcesReleased, &wasDestructed); EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); obj.ptr.reset(); @@ -3192,7 +3224,8 @@ TEST( givenPtr_whenDestructedButStillHasStrongPointers_thenDoesntReleaseResources) { bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_weak_intrusive(&resourcesReleased, &wasDestructed); + auto obj = + make_weak_intrusive(&resourcesReleased, &wasDestructed); EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); obj.weak.reset(); @@ -3208,7 +3241,8 @@ TEST(WeakIntrusivePtrTest, givenPtr_whenDestructed_thenDestructsObject) { bool resourcesReleased = false; bool wasDestructed = false; { - auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + auto obj = + make_weak_only(&resourcesReleased, &wasDestructed); EXPECT_TRUE(resourcesReleased); EXPECT_FALSE(wasDestructed); } @@ -3222,7 +3256,8 @@ TEST( givenPtr_whenMoveConstructed_thenDestructsObjectAfterSecondDestructed) { bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + auto obj = + 
make_weak_only(&resourcesReleased, &wasDestructed); { auto obj2 = std::move(obj); EXPECT_TRUE(resourcesReleased); @@ -3238,7 +3273,8 @@ TEST( givenPtr_whenMoveConstructedToBaseClass_thenDestructsObjectAfterSecondDestructed) { bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + auto obj = + make_weak_only(&resourcesReleased, &wasDestructed); { weak_intrusive_ptr obj2 = std::move(obj); EXPECT_TRUE(resourcesReleased); @@ -3255,7 +3291,8 @@ TEST(WeakIntrusivePtrTest, givenPtr_whenMoveAssigned_thenDestructsOldObject) { bool wasDestructed = false; auto obj = make_weak_only(&dummy, &dummy); { - auto obj2 = make_weak_only(&resourcesReleased, &wasDestructed); + auto obj2 = + make_weak_only(&resourcesReleased, &wasDestructed); EXPECT_TRUE(resourcesReleased); EXPECT_FALSE(wasDestructed); obj2 = std::move(obj); @@ -3273,7 +3310,8 @@ TEST( bool wasDestructed = false; auto obj = make_weak_only(&dummy, &dummy); { - auto obj2 = make_weak_only(&resourcesReleased, &wasDestructed); + auto obj2 = + make_weak_only(&resourcesReleased, &wasDestructed); EXPECT_TRUE(resourcesReleased); EXPECT_FALSE(wasDestructed); obj2 = std::move(obj); @@ -3291,7 +3329,8 @@ TEST( bool wasDestructed = false; auto obj = make_weak_only(&dummy, &dummy); { - auto obj2 = make_weak_only(&resourcesReleased, &wasDestructed); + auto obj2 = + make_weak_only(&resourcesReleased, &wasDestructed); { auto copy = obj2; EXPECT_TRUE(resourcesReleased); @@ -3314,8 +3353,8 @@ TEST( bool wasDestructed = false; auto obj = make_weak_only(&dummy, &dummy); { - auto obj2 = - make_weak_only(&resourcesReleased, &wasDestructed); + auto obj2 = make_weak_only( + &resourcesReleased, &wasDestructed); { weak_intrusive_ptr copy = obj2; EXPECT_TRUE(resourcesReleased); @@ -3338,7 +3377,8 @@ TEST( bool wasDestructed = false; auto obj = make_weak_only(&dummy, &dummy); { - auto obj2 = make_weak_only(&resourcesReleased, &wasDestructed); + auto obj2 = + make_weak_only(&resourcesReleased, &wasDestructed); { weak_intrusive_ptr copy = obj2; EXPECT_TRUE(resourcesReleased); @@ -3359,7 +3399,8 @@ TEST( bool dummy = false; bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + auto obj = + make_weak_only(&resourcesReleased, &wasDestructed); { auto obj2 = make_weak_only(&dummy, &dummy); obj2 = std::move(obj); @@ -3377,7 +3418,8 @@ TEST( bool dummy = false; bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + auto obj = + make_weak_only(&resourcesReleased, &wasDestructed); { auto obj2 = make_weak_only(&dummy, &dummy); obj2 = std::move(obj); @@ -3395,7 +3437,8 @@ TEST( bool resourcesReleased = false; bool wasDestructed = false; { - auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + auto obj = + make_weak_only(&resourcesReleased, &wasDestructed); { // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) weak_intrusive_ptr copy = obj; @@ -3416,7 +3459,8 @@ TEST( bool resourcesReleased = false; bool wasDestructed = false; { - auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + auto obj = make_weak_only( + &resourcesReleased, &wasDestructed); { weak_intrusive_ptr copy = obj; EXPECT_TRUE(resourcesReleased); @@ -3436,7 +3480,8 @@ TEST( bool resourcesReleased = false; bool wasDestructed = false; { - auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + auto obj = + make_weak_only(&resourcesReleased, &wasDestructed); 
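// Aside (illustrative sketch, not part of the patch): the weak_intrusive_ptr
// hunks in this file all rely on the same two-phase teardown provided by
// intrusive_ptr_target. The snippet below shows that behavior with a
// stand-in type; `Payload` and `weak_only_demo` are hypothetical names and
// the header path assumes the usual c10 layout.
#include <c10/util/intrusive_ptr.h>

struct Payload : c10::intrusive_ptr_target {};

void weak_only_demo() {
  c10::intrusive_ptr<Payload> strong = c10::make_intrusive<Payload>();
  c10::weak_intrusive_ptr<Payload> weak(strong);
  // Dropping the last strong reference releases the payload's resources,
  // which is why the tests above see resourcesReleased == true as soon as
  // only a weak reference is left ...
  strong.reset();
  // ... while the object itself is destructed only when the last
  // weak_intrusive_ptr goes away, so wasDestructed stays false until then.
  bool gone = weak.expired(); // true: no strong reference remains
  (void)gone;
}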
weak_intrusive_ptr copy = obj; obj.reset(); EXPECT_TRUE(resourcesReleased); @@ -3453,7 +3498,8 @@ TEST( bool resourcesReleased = false; bool wasDestructed = false; { - auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + auto obj = make_weak_only( + &resourcesReleased, &wasDestructed); weak_intrusive_ptr copy = obj; obj.reset(); EXPECT_TRUE(resourcesReleased); @@ -3471,7 +3517,8 @@ TEST( bool wasDestructed = false; bool dummy = false; { - auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + auto obj = + make_weak_only(&resourcesReleased, &wasDestructed); { weak_intrusive_ptr copy = make_weak_only(&dummy, &dummy); @@ -3494,7 +3541,8 @@ TEST( bool wasDestructed = false; bool dummy = false; { - auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + auto obj = make_weak_only( + &resourcesReleased, &wasDestructed); { weak_intrusive_ptr copy = make_weak_only(&dummy, &dummy); @@ -3519,7 +3567,8 @@ TEST( { auto copy = make_weak_only(&dummy, &dummy); { - auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + auto obj = + make_weak_only(&resourcesReleased, &wasDestructed); copy = obj; EXPECT_TRUE(resourcesReleased); EXPECT_FALSE(wasDestructed); @@ -3541,8 +3590,8 @@ TEST( { auto copy = make_weak_only(&dummy, &dummy); { - auto obj = - make_weak_only(&resourcesReleased, &wasDestructed); + auto obj = make_weak_only( + &resourcesReleased, &wasDestructed); copy = obj; EXPECT_TRUE(resourcesReleased); EXPECT_FALSE(wasDestructed); @@ -3561,7 +3610,8 @@ TEST(WeakIntrusivePtrTest, givenPtr_whenCopyAssigned_thenDestructsOldObject) { bool wasDestructed = false; auto obj = make_weak_only(&dummy, &dummy); { - auto obj2 = make_weak_only(&resourcesReleased, &wasDestructed); + auto obj2 = + make_weak_only(&resourcesReleased, &wasDestructed); EXPECT_TRUE(resourcesReleased); EXPECT_FALSE(wasDestructed); obj2 = obj; @@ -3579,7 +3629,8 @@ TEST( bool wasDestructed = false; auto obj = make_weak_only(&dummy, &dummy); { - auto obj2 = make_weak_only(&resourcesReleased, &wasDestructed); + auto obj2 = + make_weak_only(&resourcesReleased, &wasDestructed); EXPECT_TRUE(resourcesReleased); EXPECT_FALSE(wasDestructed); obj2 = obj; @@ -3597,7 +3648,8 @@ TEST( bool wasDestructed = false; auto obj = make_weak_only(&dummy, &dummy); { - auto obj2 = make_weak_only(&resourcesReleased, &wasDestructed); + auto obj2 = + make_weak_only(&resourcesReleased, &wasDestructed); { auto copy = obj2; EXPECT_TRUE(resourcesReleased); @@ -3620,8 +3672,8 @@ TEST( bool wasDestructed = false; auto obj = make_weak_only(&dummy, &dummy); { - auto obj2 = - make_weak_only(&resourcesReleased, &wasDestructed); + auto obj2 = make_weak_only( + &resourcesReleased, &wasDestructed); { weak_intrusive_ptr copy = obj2; EXPECT_TRUE(resourcesReleased); @@ -3644,7 +3696,8 @@ TEST( bool wasDestructed = false; auto obj = make_weak_only(&dummy, &dummy); { - auto obj2 = make_weak_only(&resourcesReleased, &wasDestructed); + auto obj2 = + make_weak_only(&resourcesReleased, &wasDestructed); { weak_intrusive_ptr copy = obj2; EXPECT_TRUE(resourcesReleased); @@ -3662,7 +3715,8 @@ TEST( TEST(WeakIntrusivePtrTest, givenPtr_whenCallingReset_thenDestructs) { bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + auto obj = + make_weak_only(&resourcesReleased, &wasDestructed); EXPECT_TRUE(resourcesReleased); EXPECT_FALSE(wasDestructed); obj.reset(); @@ -3676,7 +3730,8 @@ TEST( givenPtrWithCopy_whenCallingReset_thenDestructsAfterCopyDestructed) { bool resourcesReleased = 
false; bool wasDestructed = false; - auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + auto obj = + make_weak_only(&resourcesReleased, &wasDestructed); { auto copy = obj; obj.reset(); @@ -3694,7 +3749,8 @@ TEST( givenPtrWithCopy_whenCallingResetOnCopy_thenDestructsAfterOriginalDestructed) { bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + auto obj = + make_weak_only(&resourcesReleased, &wasDestructed); { auto copy = obj; copy.reset(); @@ -3712,7 +3768,8 @@ TEST( givenPtrWithMoved_whenCallingReset_thenDestructsAfterMovedDestructed) { bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + auto obj = + make_weak_only(&resourcesReleased, &wasDestructed); { auto moved = std::move(obj); // NOLINTNEXTLINE(bugprone-use-after-move) @@ -3731,7 +3788,8 @@ TEST( givenPtrWithMoved_whenCallingResetOnMoved_thenDestructsImmediately) { bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + auto obj = + make_weak_only(&resourcesReleased, &wasDestructed); { auto moved = std::move(obj); moved.reset(); diff --git a/c10/test/util/irange_test.cpp b/c10/test/util/irange_test.cpp index 3677760df9e..274aedd2607 100644 --- a/c10/test/util/irange_test.cpp +++ b/c10/test/util/irange_test.cpp @@ -8,58 +8,58 @@ using namespace ::testing; // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(irange_test, range_test) { - std::vector test_vec; - for(const auto i : c10::irange(4, 11)){ - test_vec.push_back(i); - } - const std::vector correct = {{4,5,6,7,8,9,10}}; - ASSERT_EQ(test_vec, correct); + std::vector test_vec; + for (const auto i : c10::irange(4, 11)) { + test_vec.push_back(i); + } + const std::vector correct = {{4, 5, 6, 7, 8, 9, 10}}; + ASSERT_EQ(test_vec, correct); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(irange_test, end_test) { - std::vector test_vec; - for(const auto i : c10::irange(5)){ - test_vec.push_back(i); - } - const std::vector correct = {{0, 1, 2, 3, 4}}; - ASSERT_EQ(test_vec, correct); + std::vector test_vec; + for (const auto i : c10::irange(5)) { + test_vec.push_back(i); + } + const std::vector correct = {{0, 1, 2, 3, 4}}; + ASSERT_EQ(test_vec, correct); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(irange_test, neg_range_test) { - std::vector test_vec; - for(const auto i : c10::irange(-2, 3)){ - test_vec.push_back(i); - } - const std::vector correct = {{-2,-1,0,1,2}}; - ASSERT_EQ(test_vec, correct); + std::vector test_vec; + for (const auto i : c10::irange(-2, 3)) { + test_vec.push_back(i); + } + const std::vector correct = {{-2, -1, 0, 1, 2}}; + ASSERT_EQ(test_vec, correct); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -TEST(irange, empty_reverse_range_two_inputs){ - std::vector test_vec; - for(const auto i : c10::irange(3, -3)){ - test_vec.push_back(i); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - if(i>20){ //Cap the number of elements we add if something goes wrong - break; - } +TEST(irange, empty_reverse_range_two_inputs) { + std::vector test_vec; + for (const auto i : c10::irange(3, -3)) { + test_vec.push_back(i); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + if (i > 20) { // Cap the number of elements we add if something goes wrong + break; } - const std::vector correct = {}; - ASSERT_EQ(test_vec, correct); + } + 
const std::vector correct = {}; + ASSERT_EQ(test_vec, correct); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -TEST(irange, empty_reverse_range_one_input){ - std::vector test_vec; - for(const auto i : c10::irange(-3)){ - test_vec.push_back(i); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - if(i>20){ //Cap the number of elements we add if something goes wrong - break; - } +TEST(irange, empty_reverse_range_one_input) { + std::vector test_vec; + for (const auto i : c10::irange(-3)) { + test_vec.push_back(i); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + if (i > 20) { // Cap the number of elements we add if something goes wrong + break; } - const std::vector correct = {}; - ASSERT_EQ(test_vec, correct); + } + const std::vector correct = {}; + ASSERT_EQ(test_vec, correct); } diff --git a/c10/test/util/logging_test.cpp b/c10/test/util/logging_test.cpp index 5822eabbece..a66788012e0 100644 --- a/c10/test/util/logging_test.cpp +++ b/c10/test/util/logging_test.cpp @@ -1,8 +1,8 @@ #include -#include #include #include +#include namespace c10_test { @@ -54,15 +54,15 @@ TEST(LoggingTest, TestEnforceEquals) { namespace { struct EnforceEqWithCaller { - void test(const char *x) { + void test(const char* x) { CAFFE_ENFORCE_EQ_WITH_CALLER(1, 1, "variable: ", x, " is a variable"); } }; -} +} // namespace // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(LoggingTest, TestEnforceMessageVariables) { - const char *const x = "hello"; + const char* const x = "hello"; CAFFE_ENFORCE_EQ(1, 1, "variable: ", x, " is a variable"); EnforceEqWithCaller e; @@ -70,7 +70,9 @@ TEST(LoggingTest, TestEnforceMessageVariables) { } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -TEST(LoggingTest, EnforceEqualsObjectWithReferenceToTemporaryWithoutUseOutOfScope) { +TEST( + LoggingTest, + EnforceEqualsObjectWithReferenceToTemporaryWithoutUseOutOfScope) { std::vector x = {1, 2, 3, 4}; // This case is a little tricky. We have a temporary // std::initializer_list to which our temporary ArrayRef @@ -102,7 +104,7 @@ std::ostream& operator<<(std::ostream& out, const Noncopyable& nc) { out << "Noncopyable(" << nc.x << ")"; return out; } -} +} // namespace // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(LoggingTest, DoesntCopyComparedObjects) { @@ -132,7 +134,8 @@ TEST(LoggingTest, EnforceShowcase) { WRAP_AND_PRINT(CAFFE_ENFORCE_EQ( one * two + three, three * two, "It's a pretty complicated expression")); - WRAP_AND_PRINT(CAFFE_ENFORCE_THAT(std::equal_to(), ==, one * two + three, three * two)); + WRAP_AND_PRINT(CAFFE_ENFORCE_THAT( + std::equal_to(), ==, one * two + three, three * two)); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) diff --git a/c10/test/util/optional_test.cpp b/c10/test/util/optional_test.cpp index 73983e7d32e..6595bf46ae5 100644 --- a/c10/test/util/optional_test.cpp +++ b/c10/test/util/optional_test.cpp @@ -17,32 +17,29 @@ class OptionalTest : public ::testing::Test { template T getSampleValue(); -template<> +template <> bool getSampleValue() { return true; } -template<> +template <> uint64_t getSampleValue() { // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) return 42; } -template<> +template <> std::string getSampleValue() { return "hello"; } - using OptionalTypes = ::testing::Types< - // 32-bit scalar optimization. - bool, - // Trivially destructible but not 32-bit scalar. - uint64_t, - // Non-trivial destructor. 
- std::string - >; - + // 32-bit scalar optimization. + bool, + // Trivially destructible but not 32-bit scalar. + uint64_t, + // Non-trivial destructor. + std::string>; TYPED_TEST_CASE(OptionalTest, OptionalTypes); @@ -71,7 +68,8 @@ TYPED_TEST(OptionalTest, Initialized) { moveAssign = std::move(moveFrom2); // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - std::array opts = {&opt, ©, ©Assign, &move, &moveAssign}; + std::array opts = { + &opt, ©, ©Assign, &move, &moveAssign}; for (auto* popt : opts) { auto& opt = *popt; EXPECT_TRUE((bool)opt); diff --git a/c10/test/util/ordered_preserving_dict_test.cpp b/c10/test/util/ordered_preserving_dict_test.cpp index be4dbc64db8..11f0646399b 100644 --- a/c10/test/util/ordered_preserving_dict_test.cpp +++ b/c10/test/util/ordered_preserving_dict_test.cpp @@ -1,6 +1,6 @@ -#include -#include #include +#include +#include #include #include @@ -11,7 +11,8 @@ namespace { #define ASSERT_EQUAL_PRIM(t1, t2) ASSERT_TRUE(t1 == t2); -using dict_int_int = ska_ordered::order_preserving_flat_hash_map; +using dict_int_int = + ska_ordered::order_preserving_flat_hash_map; dict_int_int test_dict(dict_int_int& dict) { // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) @@ -20,7 +21,7 @@ dict_int_int test_dict(dict_int_int& dict) { } int64_t i = 0; - for (auto entry: dict) { + for (auto entry : dict) { TORCH_INTERNAL_ASSERT(entry.first == i && entry.second == i + 1); ++i; } @@ -28,7 +29,7 @@ dict_int_int test_dict(dict_int_int& dict) { // erase a few entries by themselves // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) std::unordered_set erase_set = {0, 2, 9, 71}; - for (auto erase: erase_set) { + for (auto erase : erase_set) { dict.erase(erase); } @@ -55,7 +56,7 @@ dict_int_int test_dict(dict_int_int& dict) { } i = 0; - for (auto entry: dict) { + for (auto entry : dict) { TORCH_INTERNAL_ASSERT(order[i] == entry.first); TORCH_INTERNAL_ASSERT(dict[order[i]] == entry.second); TORCH_INTERNAL_ASSERT(entry.second == order[i] + 1); @@ -89,11 +90,11 @@ TEST(OrderedPreservingDictTest, InsertExistingDoesntAffectOrder) { TORCH_INTERNAL_ASSERT(dict.begin()->first == 1); } - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(OrderedPreservingDictTest, testRefType) { std::shared_ptr t; - using dict_references = ska_ordered::order_preserving_flat_hash_map>; + using dict_references = ska_ordered:: + order_preserving_flat_hash_map>; dict_references dict; @@ -108,7 +109,6 @@ TEST(OrderedPreservingDictTest, testRefType) { TORCH_INTERNAL_ASSERT(ptr.use_count() == 1); } - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(OrderedPreservingDictTest, DictCollisions) { struct BadHash { @@ -171,254 +171,262 @@ TEST(OrderedPreservingDictTest, DictCollisions) { } } - -// Tests taken from https://github.com/Tessil/ordered-map/blob/master/tests/ordered_map_tests.cpp +// Tests taken from +// https://github.com/Tessil/ordered-map/blob/master/tests/ordered_map_tests.cpp // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(OrderedPreservingDictTest, test_range_insert) { - // insert x values in vector, range insert x-15 values from vector to map, check values - const int nb_values = 1000; - std::vector> values; - for(int i = 0; i < nb_values; i++) { - // NOLINTNEXTLINE(modernize-use-emplace,performance-inefficient-vector-operation) - values.push_back(std::make_pair(i, i+1)); - } + // insert x values in vector, range insert x-15 values from vector to map, + // check values + const int nb_values = 1000; + 
std::vector> values; + for (int i = 0; i < nb_values; i++) { + // NOLINTNEXTLINE(modernize-use-emplace,performance-inefficient-vector-operation) + values.push_back(std::make_pair(i, i + 1)); + } - dict_int_int map = {{-1, 0}, {-2, 0}}; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - map.insert(values.begin() + 10, values.end() - 5); + dict_int_int map = {{-1, 0}, {-2, 0}}; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + map.insert(values.begin() + 10, values.end() - 5); - TORCH_INTERNAL_ASSERT(map.size(), 987); + TORCH_INTERNAL_ASSERT(map.size(), 987); - ASSERT_EQUAL_PRIM(map.at(-1), 0); + ASSERT_EQUAL_PRIM(map.at(-1), 0); - ASSERT_EQUAL_PRIM(map.at(-2), 0); + ASSERT_EQUAL_PRIM(map.at(-2), 0); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - for(int i = 10, j = 2; i < nb_values - 5; i++, j++) { - ASSERT_EQUAL_PRIM(map.at(i), i+1); - } + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + for (int i = 10, j = 2; i < nb_values - 5; i++, j++) { + ASSERT_EQUAL_PRIM(map.at(i), i + 1); + } } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(OrderedPreservingDictTest, test_range_erase_all) { - // insert x values, delete all - const std::size_t nb_values = 1000; - dict_int_int map; - for (size_t i = 0; i < nb_values; ++i) { - map[i] = i + 1; - } - auto it = map.erase(map.begin(), map.end()); - ASSERT_TRUE(it == map.end()); - ASSERT_TRUE(map.empty()); + // insert x values, delete all + const std::size_t nb_values = 1000; + dict_int_int map; + for (size_t i = 0; i < nb_values; ++i) { + map[i] = i + 1; + } + auto it = map.erase(map.begin(), map.end()); + ASSERT_TRUE(it == map.end()); + ASSERT_TRUE(map.empty()); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(OrderedPreservingDictTest, test_range_erase) { - // insert x values, delete all with iterators except 10 first and 780 last values - using HMap = ska_ordered::order_preserving_flat_hash_map; + // insert x values, delete all with iterators except 10 first and 780 last + // values + using HMap = + ska_ordered::order_preserving_flat_hash_map; - const std::size_t nb_values = 1000; - HMap map; - for (size_t i = 0; i < nb_values; ++i) { - map[c10::guts::to_string(i)] = i; - auto begin = map.begin(); - for (size_t j = 0; j <= i; ++j, begin++) { - TORCH_INTERNAL_ASSERT(begin->second == j); - } + const std::size_t nb_values = 1000; + HMap map; + for (size_t i = 0; i < nb_values; ++i) { + map[c10::guts::to_string(i)] = i; + auto begin = map.begin(); + for (size_t j = 0; j <= i; ++j, begin++) { + TORCH_INTERNAL_ASSERT(begin->second == j); } + } + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto it_first = std::next(map.begin(), 10); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto it_last = std::next(map.begin(), 220); + + auto it = map.erase(it_first, it_last); + ASSERT_EQUAL_PRIM(std::distance(it, map.end()), 780); + ASSERT_EQUAL_PRIM(map.size(), 790); + ASSERT_EQUAL_PRIM(std::distance(map.begin(), map.end()), 790); + + for (auto& val : map) { + ASSERT_EQUAL_PRIM(map.count(val.first), 1); + } + + // Check order + it = map.begin(); + for (std::size_t i = 0; i < nb_values; i++) { // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto it_first = std::next(map.begin(), 10); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto it_last = std::next(map.begin(), 220); - - auto it = map.erase(it_first, it_last); - ASSERT_EQUAL_PRIM(std::distance(it, map.end()), 780); - ASSERT_EQUAL_PRIM(map.size(), 790); - 
ASSERT_EQUAL_PRIM(std::distance(map.begin(), map.end()), 790); - - for(auto& val: map) { - ASSERT_EQUAL_PRIM(map.count(val.first), 1); - } - - // Check order - it = map.begin(); - for(std::size_t i = 0; i < nb_values; i++) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - if(i >= 10 && i < 220) { - continue; - } - auto exp_it = std::pair(c10::guts::to_string(i), i); - TORCH_INTERNAL_ASSERT(*it == exp_it); - ++it; + if (i >= 10 && i < 220) { + continue; } + auto exp_it = + std::pair(c10::guts::to_string(i), i); + TORCH_INTERNAL_ASSERT(*it == exp_it); + ++it; + } } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(OrderedPreservingDictTest, test_move_constructor_empty) { - ska_ordered::order_preserving_flat_hash_map map(0); - ska_ordered::order_preserving_flat_hash_map map_move(std::move(map)); + ska_ordered::order_preserving_flat_hash_map map(0); + ska_ordered::order_preserving_flat_hash_map map_move( + std::move(map)); - // NOLINTNEXTLINE(bugprone-use-after-move) - TORCH_INTERNAL_ASSERT(map.empty()); - TORCH_INTERNAL_ASSERT(map_move.empty()); + // NOLINTNEXTLINE(bugprone-use-after-move) + TORCH_INTERNAL_ASSERT(map.empty()); + TORCH_INTERNAL_ASSERT(map_move.empty()); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.Move) - TORCH_INTERNAL_ASSERT(map.find("") == map.end()); - TORCH_INTERNAL_ASSERT(map_move.find("") == map_move.end()); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.Move) + TORCH_INTERNAL_ASSERT(map.find("") == map.end()); + TORCH_INTERNAL_ASSERT(map_move.find("") == map_move.end()); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(OrderedPreservingDictTest, test_move_operator_empty) { - ska_ordered::order_preserving_flat_hash_map map(0); - ska_ordered::order_preserving_flat_hash_map map_move; - map_move = (std::move(map)); + ska_ordered::order_preserving_flat_hash_map map(0); + ska_ordered::order_preserving_flat_hash_map map_move; + map_move = (std::move(map)); - // NOLINTNEXTLINE(bugprone-use-after-move) - TORCH_INTERNAL_ASSERT(map.empty()); - TORCH_INTERNAL_ASSERT(map_move.empty()); + // NOLINTNEXTLINE(bugprone-use-after-move) + TORCH_INTERNAL_ASSERT(map.empty()); + TORCH_INTERNAL_ASSERT(map_move.empty()); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.Move) - TORCH_INTERNAL_ASSERT(map.find("") == map.end()); - TORCH_INTERNAL_ASSERT(map_move.find("") == map_move.end()); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.Move) + TORCH_INTERNAL_ASSERT(map.find("") == map.end()); + TORCH_INTERNAL_ASSERT(map_move.find("") == map_move.end()); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(OrderedPreservingDictTest, test_reassign_moved_object_move_constructor) { - using HMap = ska_ordered::order_preserving_flat_hash_map; + using HMap = + ska_ordered::order_preserving_flat_hash_map; - HMap map = {{"Key1", "Value1"}, {"Key2", "Value2"}, {"Key3", "Value3"}}; - HMap map_move(std::move(map)); + HMap map = {{"Key1", "Value1"}, {"Key2", "Value2"}, {"Key3", "Value3"}}; + HMap map_move(std::move(map)); - ASSERT_EQUAL_PRIM(map_move.size(), 3); - // NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move) - ASSERT_EQUAL_PRIM(map.size(), 0); + ASSERT_EQUAL_PRIM(map_move.size(), 3); + // NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move) + ASSERT_EQUAL_PRIM(map.size(), 0); - map = {{"Key4", "Value4"}, {"Key5", "Value5"}}; - TORCH_INTERNAL_ASSERT(map == (HMap({{"Key4", "Value4"}, {"Key5", "Value5"}}))); + map = {{"Key4", "Value4"}, {"Key5", "Value5"}}; + TORCH_INTERNAL_ASSERT( + map 
== (HMap({{"Key4", "Value4"}, {"Key5", "Value5"}}))); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(OrderedPreservingDictTest, test_reassign_moved_object_move_operator) { - using HMap = ska_ordered::order_preserving_flat_hash_map; + using HMap = + ska_ordered::order_preserving_flat_hash_map; - HMap map = {{"Key1", "Value1"}, {"Key2", "Value2"}, {"Key3", "Value3"}}; - HMap map_move = std::move(map); + HMap map = {{"Key1", "Value1"}, {"Key2", "Value2"}, {"Key3", "Value3"}}; + HMap map_move = std::move(map); - ASSERT_EQUAL_PRIM(map_move.size(), 3); - // NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move) - ASSERT_EQUAL_PRIM(map.size(), 0); + ASSERT_EQUAL_PRIM(map_move.size(), 3); + // NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move) + ASSERT_EQUAL_PRIM(map.size(), 0); - map = {{"Key4", "Value4"}, {"Key5", "Value5"}}; - TORCH_INTERNAL_ASSERT(map == (HMap({{"Key4", "Value4"}, {"Key5", "Value5"}}))); + map = {{"Key4", "Value4"}, {"Key5", "Value5"}}; + TORCH_INTERNAL_ASSERT( + map == (HMap({{"Key4", "Value4"}, {"Key5", "Value5"}}))); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(OrderedPreservingDictTest, test_copy_constructor_and_operator) { - using HMap = ska_ordered::order_preserving_flat_hash_map; + using HMap = + ska_ordered::order_preserving_flat_hash_map; + const std::size_t nb_values = 100; + HMap map; + for (size_t i = 0; i < nb_values; ++i) { + map[c10::guts::to_string(i)] = c10::guts::to_string(i); + } - const std::size_t nb_values = 100; - HMap map; - for (size_t i = 0; i < nb_values; ++i) { - map[c10::guts::to_string(i)] = c10::guts::to_string(i); - } + HMap map_copy = map; + HMap map_copy2(map); + HMap map_copy3; + map_copy3[c10::guts::to_string(0)] = c10::guts::to_string(0); + map_copy3 = map; - HMap map_copy = map; - HMap map_copy2(map); - HMap map_copy3; - map_copy3[c10::guts::to_string(0)] = c10::guts::to_string(0); + TORCH_INTERNAL_ASSERT(map == map_copy); + map.clear(); - map_copy3 = map; - - TORCH_INTERNAL_ASSERT(map == map_copy); - map.clear(); - - TORCH_INTERNAL_ASSERT(map_copy == map_copy2); - TORCH_INTERNAL_ASSERT(map_copy == map_copy3); + TORCH_INTERNAL_ASSERT(map_copy == map_copy2); + TORCH_INTERNAL_ASSERT(map_copy == map_copy3); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(OrderedPreservingDictTest, test_copy_constructor_empty) { - ska_ordered::order_preserving_flat_hash_map map(0); - ska_ordered::order_preserving_flat_hash_map map_copy(map); + ska_ordered::order_preserving_flat_hash_map map(0); + ska_ordered::order_preserving_flat_hash_map map_copy(map); - TORCH_INTERNAL_ASSERT(map.empty()); - TORCH_INTERNAL_ASSERT(map_copy.empty()); + TORCH_INTERNAL_ASSERT(map.empty()); + TORCH_INTERNAL_ASSERT(map_copy.empty()); - TORCH_INTERNAL_ASSERT(map.find("") == map.end()); - TORCH_INTERNAL_ASSERT(map_copy.find("") == map_copy.end()); + TORCH_INTERNAL_ASSERT(map.find("") == map.end()); + TORCH_INTERNAL_ASSERT(map_copy.find("") == map_copy.end()); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(OrderedPreservingDictTest, test_copy_operator_empty) { - ska_ordered::order_preserving_flat_hash_map map(0); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ska_ordered::order_preserving_flat_hash_map map_copy(16); - map_copy = map; + ska_ordered::order_preserving_flat_hash_map map(0); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + ska_ordered::order_preserving_flat_hash_map map_copy(16); + map_copy = 
map; - TORCH_INTERNAL_ASSERT(map.empty()); - TORCH_INTERNAL_ASSERT(map_copy.empty()); + TORCH_INTERNAL_ASSERT(map.empty()); + TORCH_INTERNAL_ASSERT(map_copy.empty()); - TORCH_INTERNAL_ASSERT(map.find("") == map.end()); - TORCH_INTERNAL_ASSERT(map_copy.find("") == map_copy.end()); + TORCH_INTERNAL_ASSERT(map.find("") == map.end()); + TORCH_INTERNAL_ASSERT(map_copy.find("") == map_copy.end()); } - /** * at */ // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(OrderedPreservingDictTest, test_at) { - // insert x values, use at for known and unknown values. - const ska_ordered::order_preserving_flat_hash_map map = {{0, 10}, {-2, 20}}; + // insert x values, use at for known and unknown values. + const ska_ordered::order_preserving_flat_hash_map + map = {{0, 10}, {-2, 20}}; - ASSERT_EQUAL_PRIM(map.at(0), 10); - ASSERT_EQUAL_PRIM(map.at(-2), 20); - bool thrown = false; - try { - map.at(1); - } catch (...) { - thrown = true; - } - ASSERT_TRUE(thrown); + ASSERT_EQUAL_PRIM(map.at(0), 10); + ASSERT_EQUAL_PRIM(map.at(-2), 20); + bool thrown = false; + try { + map.at(1); + } catch (...) { + thrown = true; + } + ASSERT_TRUE(thrown); } - /** * equal_range */ // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(OrderedPreservingDictTest, test_equal_range) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ska_ordered::order_preserving_flat_hash_map map = {{0, 10}, {-2, 20}}; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + ska_ordered::order_preserving_flat_hash_map map = + {{0, 10}, {-2, 20}}; - auto it_pair = map.equal_range(0); - ASSERT_EQUAL_PRIM(std::distance(it_pair.first, it_pair.second), 1); - ASSERT_EQUAL_PRIM(it_pair.first->second, 10); + auto it_pair = map.equal_range(0); + ASSERT_EQUAL_PRIM(std::distance(it_pair.first, it_pair.second), 1); + ASSERT_EQUAL_PRIM(it_pair.first->second, 10); - it_pair = map.equal_range(1); - TORCH_INTERNAL_ASSERT(it_pair.first == it_pair.second); - TORCH_INTERNAL_ASSERT(it_pair.first == map.end()); + it_pair = map.equal_range(1); + TORCH_INTERNAL_ASSERT(it_pair.first == it_pair.second); + TORCH_INTERNAL_ASSERT(it_pair.first == map.end()); } - /** * operator[] */ // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(OrderedPreservingDictTest, test_access_operator) { - // insert x values, use at for known and unknown values. - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ska_ordered::order_preserving_flat_hash_map map = {{0, 10}, {-2, 20}}; + // insert x values, use at for known and unknown values. 
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + ska_ordered::order_preserving_flat_hash_map map = + {{0, 10}, {-2, 20}}; - ASSERT_EQUAL_PRIM(map[0], 10); - ASSERT_EQUAL_PRIM(map[-2], 20); - ASSERT_EQUAL_PRIM(map[2], std::int64_t()); + ASSERT_EQUAL_PRIM(map[0], 10); + ASSERT_EQUAL_PRIM(map[-2], 20); + ASSERT_EQUAL_PRIM(map[2], std::int64_t()); - ASSERT_EQUAL_PRIM(map.size(), 3); + ASSERT_EQUAL_PRIM(map.size(), 3); } /** @@ -426,45 +434,72 @@ TEST(OrderedPreservingDictTest, test_access_operator) { */ // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(OrderedPreservingDictTest, test_swap) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ska_ordered::order_preserving_flat_hash_map map = {{1, 10}, {8, 80}, {3, 30}}; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ska_ordered::order_preserving_flat_hash_map map2 = {{4, 40}, {5, 50}}; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + ska_ordered::order_preserving_flat_hash_map map = + {{1, 10}, {8, 80}, {3, 30}}; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + ska_ordered::order_preserving_flat_hash_map map2 = + {{4, 40}, {5, 50}}; - using std::swap; - swap(map, map2); + using std::swap; + swap(map, map2); - TORCH_INTERNAL_ASSERT(map == (ska_ordered::order_preserving_flat_hash_map{{4, 40}, {5, 50}})); - TORCH_INTERNAL_ASSERT(map2 == (ska_ordered::order_preserving_flat_hash_map{{1, 10}, {8, 80}, {3, 30}})); + TORCH_INTERNAL_ASSERT( + map == + (ska_ordered::order_preserving_flat_hash_map{ + {4, 40}, {5, 50}})); + TORCH_INTERNAL_ASSERT( + map2 == + (ska_ordered::order_preserving_flat_hash_map{ + {1, 10}, {8, 80}, {3, 30}})); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - map.insert({6, 60}); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - map2.insert({4, 40}); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + map.insert({6, 60}); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + map2.insert({4, 40}); - TORCH_INTERNAL_ASSERT(map == (ska_ordered::order_preserving_flat_hash_map{{4, 40}, {5, 50}, {6, 60}})); - TORCH_INTERNAL_ASSERT(map2 == (ska_ordered::order_preserving_flat_hash_map{{1, 10}, {8, 80}, {3, 30}, {4, 40}})); + TORCH_INTERNAL_ASSERT( + map == + (ska_ordered::order_preserving_flat_hash_map{ + {4, 40}, {5, 50}, {6, 60}})); + TORCH_INTERNAL_ASSERT( + map2 == + (ska_ordered::order_preserving_flat_hash_map{ + {1, 10}, {8, 80}, {3, 30}, {4, 40}})); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(OrderedPreservingDictTest, test_swap_empty) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ska_ordered::order_preserving_flat_hash_map map = {{1, 10}, {8, 80}, {3, 30}}; - ska_ordered::order_preserving_flat_hash_map map2; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + ska_ordered::order_preserving_flat_hash_map map = + {{1, 10}, {8, 80}, {3, 30}}; + ska_ordered::order_preserving_flat_hash_map map2; - using std::swap; - swap(map, map2); + using std::swap; + swap(map, map2); - TORCH_INTERNAL_ASSERT(map == (ska_ordered::order_preserving_flat_hash_map{})); - TORCH_INTERNAL_ASSERT(map2 == (ska_ordered::order_preserving_flat_hash_map{{1, 10}, {8, 80}, {3, 30}})); + TORCH_INTERNAL_ASSERT( + map == + (ska_ordered:: + order_preserving_flat_hash_map{})); + TORCH_INTERNAL_ASSERT( + map2 == + (ska_ordered::order_preserving_flat_hash_map{ + {1, 10}, {8, 80}, {3, 30}})); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - map.insert({6, 60}); - // 
NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - map2.insert({4, 40}); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + map.insert({6, 60}); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + map2.insert({4, 40}); - TORCH_INTERNAL_ASSERT(map == (ska_ordered::order_preserving_flat_hash_map{{6, 60}})); - TORCH_INTERNAL_ASSERT(map2 == (ska_ordered::order_preserving_flat_hash_map{{1, 10}, {8, 80}, {3, 30}, {4, 40}})); + TORCH_INTERNAL_ASSERT( + map == + (ska_ordered::order_preserving_flat_hash_map{ + {6, 60}})); + TORCH_INTERNAL_ASSERT( + map2 == + (ska_ordered::order_preserving_flat_hash_map{ + {1, 10}, {8, 80}, {3, 30}, {4, 40}})); } -} +} // namespace diff --git a/c10/test/util/string_view_test.cpp b/c10/test/util/string_view_test.cpp index 907461fc5f1..c5d05d808a8 100644 --- a/c10/test/util/string_view_test.cpp +++ b/c10/test/util/string_view_test.cpp @@ -7,9 +7,9 @@ using c10::string_view; namespace { namespace testutils { constexpr bool string_equal(const char* lhs, const char* rhs, size_t size) { - return (size == 0) - ? true - : (*lhs != *rhs) ? false : string_equal(lhs + 1, rhs + 1, size - 1); + return (size == 0) ? true + : (*lhs != *rhs) ? false + : string_equal(lhs + 1, rhs + 1, size - 1); } static_assert(string_equal("hi", "hi", 2), ""); static_assert(string_equal("", "", 0), ""); @@ -124,7 +124,8 @@ TEST(StringViewTest, testCopyAssignment) { static_assert(5 == (string_view() = "hello").size(), ""); static_assert( // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - string_equal("hello", (string_view() = "hello").data(), 5), ""); + string_equal("hello", (string_view() = "hello").data(), 5), + ""); } #endif const string_view hello = assign("hello"); @@ -233,7 +234,7 @@ static_assert('o' == string_view("hello").back(), ""); namespace test_data { // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) static_assert(string_equal("hello", string_view("hello").data(), 5), ""); -} +} // namespace test_data namespace test_size_length { static_assert(0 == string_view("").size(), ""); diff --git a/c10/test/util/tempfile_test.cpp b/c10/test/util/tempfile_test.cpp index 478f6279a1f..3b568460dad 100644 --- a/c10/test/util/tempfile_test.cpp +++ b/c10/test/util/tempfile_test.cpp @@ -1,7 +1,7 @@ #include #include -#include #include +#include #if !defined(_WIN32) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) diff --git a/c10/test/util/typeid_test.cpp b/c10/test/util/typeid_test.cpp index 80399835894..dab9fc67943 100644 --- a/c10/test/util/typeid_test.cpp +++ b/c10/test/util/typeid_test.cpp @@ -8,7 +8,7 @@ namespace { class TypeMetaTestFoo {}; class TypeMetaTestBar {}; -} +} // namespace CAFFE_KNOWN_TYPE(TypeMetaTestFoo); CAFFE_KNOWN_TYPE(TypeMetaTestBar); @@ -73,7 +73,6 @@ TEST(TypeMetaTest, TypeMeta) { EXPECT_NE(bar_meta.name().find("TypeMetaTestBar"), c10::string_view::npos); } - class ClassAllowAssignment { public: // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) @@ -92,7 +91,7 @@ class ClassNoAssignment { ClassNoAssignment& operator=(const ClassNoAssignment& src) = delete; int x; }; -} +} // namespace CAFFE_KNOWN_TYPE(ClassAllowAssignment); CAFFE_KNOWN_TYPE(ClassNoAssignment); @@ -135,5 +134,5 @@ TEST(TypeMetaTest, Float16IsNotUint16) { EXPECT_NE(TypeMeta::Id(), TypeMeta::Id()); } -} // namespace -} // namespace caffe2 +} // namespace +} // namespace caffe2 diff --git a/c10/util/Array.h b/c10/util/Array.h index cc3630cc157..f33a36e8e84 100644 --- a/c10/util/Array.h +++ b/c10/util/Array.h @@ -1,18 +1,20 @@ /** -* This file is 
based on the std::array implementation of libstdc++ at -* https://gcc.gnu.org/onlinedocs/gcc-7.1.0/libstdc++/api/a01056_source.html -* -* Changes: -* - isolate, i.e. remove dependencies on internal libstdc++ stuff -* - use c++17 behavior even in c++11 or c++14 -* - remove std::swappable special case because that doesn't work with MSVC -* - constexpr more things -* - add some features like prepend/tail -* -* If using std::array at runtime, feel free to either keep using std::array or use this one - it doesn't really matter. -* For compile time computations, this one here is preferred because std::array in C++11 -* misses some constexpr specifiers, forcing these methods to be called at runtime instead of compile time. -*/ + * This file is based on the std::array implementation of libstdc++ at + * https://gcc.gnu.org/onlinedocs/gcc-7.1.0/libstdc++/api/a01056_source.html + * + * Changes: + * - isolate, i.e. remove dependencies on internal libstdc++ stuff + * - use c++17 behavior even in c++11 or c++14 + * - remove std::swappable special case because that doesn't work with MSVC + * - constexpr more things + * - add some features like prepend/tail + * + * If using std::array at runtime, feel free to either keep using std::array or + * use this one - it doesn't really matter. For compile time computations, this + * one here is preferred because std::array in C++11 misses some constexpr + * specifiers, forcing these methods to be called at runtime instead of compile + * time. + */ // Copyright (C) 2007-2017 Free Software Foundation, Inc. // @@ -38,16 +40,17 @@ #pragma once -#include #include +#include #include #include #include -namespace c10 { namespace guts { +namespace c10 { +namespace guts { namespace detail { -template +template struct __array_traits final { using _Type = _Tp[_Nm]; @@ -60,7 +63,7 @@ struct __array_traits final { } }; -template +template struct __array_traits<_Tp, 0> final { struct _Type final {}; @@ -76,11 +79,11 @@ struct __array_traits<_Tp, 0> final { [[noreturn]] inline void __throw_out_of_range(std::string msg) { throw std::out_of_range(std::move(msg)); } -} +} // namespace detail -template +template class array final { -public: + public: using value_type = _Tp; using pointer = value_type*; using const_pointer = const value_type*; @@ -93,76 +96,99 @@ public: using reverse_iterator = std::reverse_iterator; using const_reverse_iterator = std::reverse_iterator; -private: + private: using _AT_Type = detail::__array_traits<_Tp, _Nm>; -public: // needs to be public member for aggregate initialization + + public: // needs to be public member for aggregate initialization typename _AT_Type::_Type _M_elems; -public: + public: // No explicit construct/copy/destroy for aggregate type. // DR 776. - constexpr void fill(const value_type& __u) - { std::fill_n(begin(), size(), __u); } + constexpr void fill(const value_type& __u) { + std::fill_n(begin(), size(), __u); + } - constexpr void swap(array& __other) - { std::swap_ranges(begin(), end(), __other.begin()); } + constexpr void swap(array& __other) { + std::swap_ranges(begin(), end(), __other.begin()); + } // Iterators. 
- constexpr iterator begin() noexcept - { return iterator(data()); } + constexpr iterator begin() noexcept { + return iterator(data()); + } - constexpr const_iterator begin() const noexcept - { return const_iterator(data()); } + constexpr const_iterator begin() const noexcept { + return const_iterator(data()); + } - constexpr iterator end() noexcept - { return iterator(data() + _Nm); } + constexpr iterator end() noexcept { + return iterator(data() + _Nm); + } - constexpr const_iterator end() const noexcept - { return const_iterator(data() + _Nm); } + constexpr const_iterator end() const noexcept { + return const_iterator(data() + _Nm); + } - constexpr reverse_iterator rbegin() noexcept - { return reverse_iterator(end()); } + constexpr reverse_iterator rbegin() noexcept { + return reverse_iterator(end()); + } - constexpr const_reverse_iterator rbegin() const noexcept - { return const_reverse_iterator(end()); } + constexpr const_reverse_iterator rbegin() const noexcept { + return const_reverse_iterator(end()); + } - constexpr reverse_iterator rend() noexcept - { return reverse_iterator(begin()); } + constexpr reverse_iterator rend() noexcept { + return reverse_iterator(begin()); + } - constexpr const_reverse_iterator rend() const noexcept - { return const_reverse_iterator(begin()); } + constexpr const_reverse_iterator rend() const noexcept { + return const_reverse_iterator(begin()); + } - constexpr const_iterator cbegin() const noexcept - { return const_iterator(data()); } + constexpr const_iterator cbegin() const noexcept { + return const_iterator(data()); + } - constexpr const_iterator cend() const noexcept - { return const_iterator(data() + _Nm); } + constexpr const_iterator cend() const noexcept { + return const_iterator(data() + _Nm); + } - constexpr const_reverse_iterator crbegin() const noexcept - { return const_reverse_iterator(end()); } + constexpr const_reverse_iterator crbegin() const noexcept { + return const_reverse_iterator(end()); + } - constexpr const_reverse_iterator crend() const noexcept - { return const_reverse_iterator(begin()); } + constexpr const_reverse_iterator crend() const noexcept { + return const_reverse_iterator(begin()); + } // Capacity. - constexpr size_type size() const noexcept { return _Nm; } + constexpr size_type size() const noexcept { + return _Nm; + } - constexpr size_type max_size() const noexcept { return _Nm; } + constexpr size_type max_size() const noexcept { + return _Nm; + } - constexpr bool empty() const noexcept { return size() == 0; } + constexpr bool empty() const noexcept { + return size() == 0; + } // Element access. 
- constexpr reference operator[](size_type __n) noexcept - { return _AT_Type::_S_ref(_M_elems, __n); } + constexpr reference operator[](size_type __n) noexcept { + return _AT_Type::_S_ref(_M_elems, __n); + } - constexpr const_reference operator[](size_type __n) const noexcept - { return _AT_Type::_S_ref(_M_elems, __n); } + constexpr const_reference operator[](size_type __n) const noexcept { + return _AT_Type::_S_ref(_M_elems, __n); + } constexpr reference at(size_type __n) { if (__n >= _Nm) { - detail::__throw_out_of_range(std::string() + - "array::at: __n (which is " + to_string(__n) + ") " + + detail::__throw_out_of_range( + std::string() + "array::at: __n (which is " + to_string(__n) + ") " + ">= _Nm (which is " + to_string(_Nm) + ")"); } return _AT_Type::_S_ref(_M_elems, __n); @@ -171,101 +197,133 @@ public: constexpr const_reference at(size_type __n) const { // Result of conditional expression must be an lvalue so use // boolean ? lvalue : (throw-expr, lvalue) - return __n < _Nm ? _AT_Type::_S_ref(_M_elems, __n) - : (detail::__throw_out_of_range(std::string() + - "array::at: __n (which is " + to_string(__n) + ") " + - ">= _Nm (which is " + to_string(_Nm) + ")"), - _AT_Type::_S_ref(_M_elems, 0)); - } - - constexpr reference front() noexcept - { return *begin(); } - - constexpr const_reference front() const noexcept - { return _AT_Type::_S_ref(_M_elems, 0); } - - constexpr reference back() noexcept - { return _Nm ? *(end() - 1) : *end(); } - - constexpr const_reference back() const noexcept - { - return _Nm ? _AT_Type::_S_ref(_M_elems, _Nm - 1) - : _AT_Type::_S_ref(_M_elems, 0); + return __n < _Nm + ? _AT_Type::_S_ref(_M_elems, __n) + : (detail::__throw_out_of_range( + std::string() + "array::at: __n (which is " + to_string(__n) + + ") " + ">= _Nm (which is " + to_string(_Nm) + ")"), + _AT_Type::_S_ref(_M_elems, 0)); } - constexpr pointer data() noexcept - { return _AT_Type::_S_ptr(_M_elems); } + constexpr reference front() noexcept { + return *begin(); + } - constexpr const_pointer data() const noexcept - { return _AT_Type::_S_ptr(_M_elems); } + constexpr const_reference front() const noexcept { + return _AT_Type::_S_ref(_M_elems, 0); + } + + constexpr reference back() noexcept { + return _Nm ? *(end() - 1) : *end(); + } + + constexpr const_reference back() const noexcept { + return _Nm ? _AT_Type::_S_ref(_M_elems, _Nm - 1) + : _AT_Type::_S_ref(_M_elems, 0); + } + + constexpr pointer data() noexcept { + return _AT_Type::_S_ptr(_M_elems); + } + + constexpr const_pointer data() const noexcept { + return _AT_Type::_S_ptr(_M_elems); + } }; #if defined(__cpp_deduction_guides) && __cpp_deduction_guides >= 201606 - template - array(_Tp, _Up...) -> - array::value && ...), _Tp>, 1 + sizeof...(_Up)>; +template +array(_Tp, _Up...) -> array< + std::enable_if_t<(std::is_same<_Tp, _Up>::value && ...), _Tp>, + 1 + sizeof...(_Up)>; #endif // Array comparisons. namespace detail { -template -constexpr inline bool array_equals_(const array& lhs, const array& rhs, size_t current_index) { +template +constexpr inline bool array_equals_( + const array& lhs, + const array& rhs, + size_t current_index) { return (current_index == N) - ? true - : (lhs.at(current_index) == rhs.at(current_index) && array_equals_(lhs, rhs, current_index + 1)); + ? 
true + : (lhs.at(current_index) == rhs.at(current_index) && + array_equals_(lhs, rhs, current_index + 1)); } -template -constexpr inline bool array_less_(const array& lhs, const array& rhs, size_t current_index) { +template +constexpr inline bool array_less_( + const array& lhs, + const array& rhs, + size_t current_index) { return (current_index == N) - ? false - : (lhs.at(current_index) < rhs.at(current_index) || array_less_(lhs, rhs, current_index + 1)); + ? false + : (lhs.at(current_index) < rhs.at(current_index) || + array_less_(lhs, rhs, current_index + 1)); } +} // namespace detail +template +constexpr inline bool operator==( + const array<_Tp, _Nm>& __one, + const array<_Tp, _Nm>& __two) { + return detail::array_equals_(__one, __two, 0); } -template -constexpr inline bool operator==(const array<_Tp, _Nm>& __one, const array<_Tp, _Nm>& __two) -{ return detail::array_equals_(__one, __two, 0); } -template -constexpr inline bool operator!=(const array<_Tp, _Nm>& __one, const array<_Tp, _Nm>& __two) -{ return !(__one == __two); } +template +constexpr inline bool operator!=( + const array<_Tp, _Nm>& __one, + const array<_Tp, _Nm>& __two) { + return !(__one == __two); +} -template -constexpr inline bool operator<(const array<_Tp, _Nm>& __a, const array<_Tp, _Nm>& __b) -{ return detail::array_less_(__a, __b, 0); } +template +constexpr inline bool operator<( + const array<_Tp, _Nm>& __a, + const array<_Tp, _Nm>& __b) { + return detail::array_less_(__a, __b, 0); +} -template -constexpr inline bool operator>(const array<_Tp, _Nm>& __one, const array<_Tp, _Nm>& __two) -{ return __two < __one; } +template +constexpr inline bool operator>( + const array<_Tp, _Nm>& __one, + const array<_Tp, _Nm>& __two) { + return __two < __one; +} -template -constexpr inline bool operator<=(const array<_Tp, _Nm>& __one, const array<_Tp, _Nm>& __two) -{ return !(__one > __two); } +template +constexpr inline bool operator<=( + const array<_Tp, _Nm>& __one, + const array<_Tp, _Nm>& __two) { + return !(__one > __two); +} -template -constexpr inline bool operator>=(const array<_Tp, _Nm>& __one, const array<_Tp, _Nm>& __two) -{ return !(__one < __two); } +template +constexpr inline bool operator>=( + const array<_Tp, _Nm>& __one, + const array<_Tp, _Nm>& __two) { + return !(__one < __two); +} // Specialized algorithms. 
-template -inline void swap(array<_Tp, _Nm>& __one, array<_Tp, _Nm>& __two) noexcept(noexcept(__one.swap(__two))) -{ __one.swap(__two); } - -template -constexpr _Tp& get(array<_Tp, _Nm>& __arr) noexcept { - static_assert(_Int < _Nm, "array index is within bounds"); - return detail::__array_traits<_Tp, _Nm>::_S_ref(__arr._M_elems, _Int); +template +inline void swap(array<_Tp, _Nm>& __one, array<_Tp, _Nm>& __two) noexcept( + noexcept(__one.swap(__two))) { + __one.swap(__two); } -template -constexpr _Tp&& get(array<_Tp, _Nm>&& __arr) noexcept -{ +template +constexpr _Tp& get(array<_Tp, _Nm>& __arr) noexcept { + static_assert(_Int < _Nm, "array index is within bounds"); + return detail::__array_traits<_Tp, _Nm>::_S_ref(__arr._M_elems, _Int); +} + +template +constexpr _Tp&& get(array<_Tp, _Nm>&& __arr) noexcept { static_assert(_Int < _Nm, "array index is within bounds"); return std::move(get<_Int>(__arr)); } -template -constexpr const _Tp& get(const array<_Tp, _Nm>& __arr) noexcept -{ +template +constexpr const _Tp& get(const array<_Tp, _Nm>& __arr) noexcept { static_assert(_Int < _Nm, "array index is within bounds"); return detail::__array_traits<_Tp, _Nm>::_S_ref(__arr._M_elems, _Int); } @@ -278,27 +336,34 @@ constexpr const _Tp& get(const array<_Tp, _Nm>& __arr) noexcept * prepend(2, {3, 4}) == {2, 3, 4} */ namespace detail { -template -constexpr inline array tail_(const array& arg, std::index_sequence) { - static_assert(sizeof...(INDEX) == N-1, "invariant"); - return {{get(arg)...}}; +template +constexpr inline array tail_( + const array& arg, + std::index_sequence) { + static_assert(sizeof...(INDEX) == N - 1, "invariant"); + return {{get(arg)...}}; } -} -template -constexpr inline array tail(const array& arg) { - static_assert(N > 0, "Can only call tail() on an array with at least one element"); - return detail::tail_(arg, std::make_index_sequence()); +} // namespace detail +template +constexpr inline array tail(const array& arg) { + static_assert( + N > 0, "Can only call tail() on an array with at least one element"); + return detail::tail_(arg, std::make_index_sequence()); } namespace detail { -template -constexpr inline array prepend_(T&& head, const array& tail, std::index_sequence) { +template +constexpr inline array prepend_( + T&& head, + const array& tail, + std::index_sequence) { return {{std::forward(head), get(tail)...}}; } -} -template -constexpr inline array prepend(T&& head, const array& tail) { - return detail::prepend_(std::forward(head), tail, std::make_index_sequence()); +} // namespace detail +template +constexpr inline array prepend(T&& head, const array& tail) { + return detail::prepend_( + std::forward(head), tail, std::make_index_sequence()); } /** @@ -309,15 +374,18 @@ constexpr inline array prepend(T&& head, const array& tail) { */ namespace detail { -template -constexpr array to_array_(const T (&arr)[N], std::index_sequence) { +template +constexpr array to_array_( + const T (&arr)[N], + std::index_sequence) { return {{arr[INDEX]...}}; } -} +} // namespace detail -template +template constexpr array to_array(const T (&arr)[N]) { return detail::to_array_(arr, std::make_index_sequence()); } -}} +} // namespace guts +} // namespace c10 diff --git a/c10/util/ArrayRef.h b/c10/util/ArrayRef.h index ee40c572187..d0bc5a207b0 100644 --- a/c10/util/ArrayRef.h +++ b/c10/util/ArrayRef.h @@ -15,10 +15,10 @@ #pragma once -#include #include -#include #include +#include +#include #include #include @@ -81,12 +81,15 @@ class ArrayRef final { : Data(Vec.data()), Length(Vec.size()) {} 
/// Construct an ArrayRef from a std::vector. - // The enable_if stuff here makes sure that this isn't used for std::vector, - // because ArrayRef can't work on a std::vector bitfield. + // The enable_if stuff here makes sure that this isn't used for + // std::vector, because ArrayRef can't work on a std::vector + // bitfield. template /* implicit */ ArrayRef(const std::vector& Vec) : Data(Vec.data()), Length(Vec.size()) { - static_assert(!std::is_same::value, "ArrayRef cannot be constructed from a std::vector bitfield."); + static_assert( + !std::is_same::value, + "ArrayRef cannot be constructed from a std::vector bitfield."); } /// Construct an ArrayRef from a std::array @@ -100,7 +103,9 @@ class ArrayRef final { /// Construct an ArrayRef from a std::initializer_list. /* implicit */ constexpr ArrayRef(const std::initializer_list& Vec) - : Data(std::begin(Vec) == std::end(Vec) ? static_cast(nullptr) : std::begin(Vec)), + : Data( + std::begin(Vec) == std::end(Vec) ? static_cast(nullptr) + : std::begin(Vec)), Length(Vec.size()) {} /// @} @@ -146,7 +151,8 @@ class ArrayRef final { /// front - Get the first element. C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& front() const { - TORCH_CHECK(!empty(), "ArrayRef: attempted to access front() of empty list"); + TORCH_CHECK( + !empty(), "ArrayRef: attempted to access front() of empty list"); return Data[0]; } @@ -162,7 +168,8 @@ class ArrayRef final { } /// slice(n, m) - Take M elements of the array starting at element N - C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef slice(size_t N, size_t M) const { + C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef slice(size_t N, size_t M) + const { TORCH_CHECK( N + M <= size(), "ArrayRef: invalid slice, N = ", @@ -224,10 +231,10 @@ class ArrayRef final { }; template -std::ostream& operator<<(std::ostream & out, ArrayRef list) { +std::ostream& operator<<(std::ostream& out, ArrayRef list) { int i = 0; out << "["; - for(auto e : list) { + for (auto e : list) { if (i++ > 0) out << ", "; out << e; diff --git a/c10/util/BFloat16-inl.h b/c10/util/BFloat16-inl.h index e683cd9df8e..9dd30d2d5af 100644 --- a/c10/util/BFloat16-inl.h +++ b/c10/util/BFloat16-inl.h @@ -7,7 +7,8 @@ namespace c10 { /// Constructors inline C10_HOST_DEVICE BFloat16::BFloat16(float value) { -#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && \ + defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 x = __bfloat16_as_ushort(__float2bfloat16(value)); #else // RNE by default @@ -37,8 +38,9 @@ inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const { #if defined(__CUDACC__) || defined(__HIPCC__) inline C10_DEVICE BFloat16 __ldg(const BFloat16* ptr) { -#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - return __ldg(reinterpret_cast(ptr)); +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && \ + defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + return __ldg(reinterpret_cast(ptr)); #else return *ptr; #endif @@ -47,19 +49,23 @@ inline C10_DEVICE BFloat16 __ldg(const BFloat16* ptr) { /// Arithmetic -inline C10_HOST_DEVICE BFloat16 operator+(const BFloat16& a, const BFloat16& b) { +inline C10_HOST_DEVICE BFloat16 +operator+(const BFloat16& a, const BFloat16& b) { return static_cast(a) + static_cast(b); } -inline C10_HOST_DEVICE BFloat16 operator-(const BFloat16& a, const BFloat16& b) { +inline C10_HOST_DEVICE BFloat16 +operator-(const BFloat16& a, const BFloat16& b) { return static_cast(a) - static_cast(b); 
} -inline C10_HOST_DEVICE BFloat16 operator*(const BFloat16& a, const BFloat16& b) { +inline C10_HOST_DEVICE BFloat16 +operator*(const BFloat16& a, const BFloat16& b) { return static_cast(a) * static_cast(b); } -inline C10_HOST_DEVICE BFloat16 operator/(const BFloat16& a, const BFloat16& b) __ubsan_ignore_float_divide_by_zero__ { +inline C10_HOST_DEVICE BFloat16 operator/(const BFloat16& a, const BFloat16& b) + __ubsan_ignore_float_divide_by_zero__ { return static_cast(a) / static_cast(b); } @@ -243,7 +249,7 @@ namespace std { template <> class numeric_limits { -public: + public: static constexpr bool is_signed = true; static constexpr bool is_specialized = true; static constexpr bool is_integer = false; diff --git a/c10/util/BFloat16-math.h b/c10/util/BFloat16-math.h index 73f04869549..ac625bade07 100644 --- a/c10/util/BFloat16-math.h +++ b/c10/util/BFloat16-math.h @@ -6,35 +6,89 @@ namespace std { /// Used by vec256::map -inline c10::BFloat16 acos(c10::BFloat16 a) { return std::acos(float(a));} -inline c10::BFloat16 asin(c10::BFloat16 a) { return std::asin(float(a));} -inline c10::BFloat16 atan(c10::BFloat16 a) { return std::atan(float(a));} -inline c10::BFloat16 erf(c10::BFloat16 a) { return std::erf(float(a));} -inline c10::BFloat16 erfc(c10::BFloat16 a) { return std::erfc(float(a));} -inline c10::BFloat16 exp(c10::BFloat16 a) { return std::exp(float(a));} -inline c10::BFloat16 expm1(c10::BFloat16 a) { return std::expm1(float(a));} -inline c10::BFloat16 log(c10::BFloat16 a) { return std::log(float(a));} -inline c10::BFloat16 log10(c10::BFloat16 a) { return std::log10(float(a));} -inline c10::BFloat16 log1p(c10::BFloat16 a) { return std::log1p(float(a));} -inline c10::BFloat16 log2(c10::BFloat16 a) { return std::log2(float(a));} -inline c10::BFloat16 ceil(c10::BFloat16 a) { return std::ceil(float(a));} -inline c10::BFloat16 cos(c10::BFloat16 a) { return std::cos(float(a));} -inline c10::BFloat16 floor(c10::BFloat16 a) { return std::floor(float(a));} -inline c10::BFloat16 nearbyint(c10::BFloat16 a) { return std::nearbyint(float(a));} -inline c10::BFloat16 sin(c10::BFloat16 a) { return std::sin(float(a));} -inline c10::BFloat16 tan(c10::BFloat16 a) { return std::tan(float(a));} -inline c10::BFloat16 tanh(c10::BFloat16 a) { return std::tanh(float(a));} -inline c10::BFloat16 trunc(c10::BFloat16 a) { return std::trunc(float(a));} -inline c10::BFloat16 lgamma(c10::BFloat16 a) { return std::lgamma(float(a));} -inline c10::BFloat16 sqrt(c10::BFloat16 a) { return std::sqrt(float(a));} -inline c10::BFloat16 rsqrt(c10::BFloat16 a) { return 1.0 / std::sqrt(float(a));} -inline c10::BFloat16 abs(c10::BFloat16 a) { return std::abs(float(a));} +inline c10::BFloat16 acos(c10::BFloat16 a) { + return std::acos(float(a)); +} +inline c10::BFloat16 asin(c10::BFloat16 a) { + return std::asin(float(a)); +} +inline c10::BFloat16 atan(c10::BFloat16 a) { + return std::atan(float(a)); +} +inline c10::BFloat16 erf(c10::BFloat16 a) { + return std::erf(float(a)); +} +inline c10::BFloat16 erfc(c10::BFloat16 a) { + return std::erfc(float(a)); +} +inline c10::BFloat16 exp(c10::BFloat16 a) { + return std::exp(float(a)); +} +inline c10::BFloat16 expm1(c10::BFloat16 a) { + return std::expm1(float(a)); +} +inline c10::BFloat16 log(c10::BFloat16 a) { + return std::log(float(a)); +} +inline c10::BFloat16 log10(c10::BFloat16 a) { + return std::log10(float(a)); +} +inline c10::BFloat16 log1p(c10::BFloat16 a) { + return std::log1p(float(a)); +} +inline c10::BFloat16 log2(c10::BFloat16 a) { + return std::log2(float(a)); +} 
+inline c10::BFloat16 ceil(c10::BFloat16 a) { + return std::ceil(float(a)); +} +inline c10::BFloat16 cos(c10::BFloat16 a) { + return std::cos(float(a)); +} +inline c10::BFloat16 floor(c10::BFloat16 a) { + return std::floor(float(a)); +} +inline c10::BFloat16 nearbyint(c10::BFloat16 a) { + return std::nearbyint(float(a)); +} +inline c10::BFloat16 sin(c10::BFloat16 a) { + return std::sin(float(a)); +} +inline c10::BFloat16 tan(c10::BFloat16 a) { + return std::tan(float(a)); +} +inline c10::BFloat16 tanh(c10::BFloat16 a) { + return std::tanh(float(a)); +} +inline c10::BFloat16 trunc(c10::BFloat16 a) { + return std::trunc(float(a)); +} +inline c10::BFloat16 lgamma(c10::BFloat16 a) { + return std::lgamma(float(a)); +} +inline c10::BFloat16 sqrt(c10::BFloat16 a) { + return std::sqrt(float(a)); +} +inline c10::BFloat16 rsqrt(c10::BFloat16 a) { + return 1.0 / std::sqrt(float(a)); +} +inline c10::BFloat16 abs(c10::BFloat16 a) { + return std::abs(float(a)); +} #if defined(_MSC_VER) && defined(__CUDACC__) -inline c10::BFloat16 pow(c10::BFloat16 a, double b) { return std::pow(float(a), float(b));} +inline c10::BFloat16 pow(c10::BFloat16 a, double b) { + return std::pow(float(a), float(b)); +} #else -inline c10::BFloat16 pow(c10::BFloat16 a, double b) { return std::pow(float(a), b);} +inline c10::BFloat16 pow(c10::BFloat16 a, double b) { + return std::pow(float(a), b); +} #endif -inline c10::BFloat16 pow(c10::BFloat16 a, c10::BFloat16 b) { return std::pow(float(a), float(b));} -inline c10::BFloat16 fmod(c10::BFloat16 a, c10::BFloat16 b) { return std::fmod(float(a), float(b));} +inline c10::BFloat16 pow(c10::BFloat16 a, c10::BFloat16 b) { + return std::pow(float(a), float(b)); +} +inline c10::BFloat16 fmod(c10::BFloat16 a, c10::BFloat16 b) { + return std::fmod(float(a), float(b)); +} } // namespace std diff --git a/c10/util/BFloat16.h b/c10/util/BFloat16.h index ce0d229b72c..5446eb94192 100644 --- a/c10/util/BFloat16.h +++ b/c10/util/BFloat16.h @@ -14,60 +14,60 @@ namespace c10 { namespace detail { - inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) { - float res = 0; - uint32_t tmp = src; - tmp <<= 16; +inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) { + float res = 0; + uint32_t tmp = src; + tmp <<= 16; #ifdef __HIP_PLATFORM_HCC__ - float* tempRes; + float* tempRes; - // We should be using memcpy in order to respect the strict aliasing rule - // but it fails in the HIP environment. - tempRes = reinterpret_cast(&tmp); - res = *tempRes; + // We should be using memcpy in order to respect the strict aliasing rule + // but it fails in the HIP environment. + tempRes = reinterpret_cast(&tmp); + res = *tempRes; #else - std::memcpy(&res, &tmp, sizeof(tmp)); + std::memcpy(&res, &tmp, sizeof(tmp)); #endif - return res; - } + return res; +} - inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) { - uint32_t res = 0; +inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) { + uint32_t res = 0; #ifdef __HIP_PLATFORM_HCC__ - // We should be using memcpy in order to respect the strict aliasing rule - // but it fails in the HIP environment. - uint32_t* tempRes = reinterpret_cast(&src); - res = *tempRes; + // We should be using memcpy in order to respect the strict aliasing rule + // but it fails in the HIP environment. 
+ uint32_t* tempRes = reinterpret_cast(&src); + res = *tempRes; #else - std::memcpy(&res, &src, sizeof(res)); + std::memcpy(&res, &src, sizeof(res)); #endif - return res >> 16; - } + return res >> 16; +} - inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) { +inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) { #if defined(__HIP_PLATFORM_HCC__) - if(src != src) { + if (src != src) { #elif defined(_MSC_VER) - if (isnan(src)) { + if (isnan(src)) { #else - if (std::isnan(src)) { + if (std::isnan(src)) { #endif - return UINT16_C(0x7FC0); - } else { - union { - uint32_t U32; - float F32; - }; + return UINT16_C(0x7FC0); + } else { + union { + uint32_t U32; + float F32; + }; - F32 = src; - uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); - return static_cast((U32 + rounding_bias) >> 16); - } + F32 = src; + uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); + return static_cast((U32 + rounding_bias) >> 16); } +} } // namespace detail struct alignas(2) BFloat16 { @@ -85,7 +85,8 @@ struct alignas(2) BFloat16 { return from_bits_t(); } - constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t) : x(bits){}; + constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t) + : x(bits){}; inline C10_HOST_DEVICE BFloat16(float value); inline C10_HOST_DEVICE operator float() const; @@ -97,5 +98,4 @@ struct alignas(2) BFloat16 { } // namespace c10 - #include diff --git a/c10/util/Backtrace.cpp b/c10/util/Backtrace.cpp index b3b84850b2d..d978f32cd00 100644 --- a/c10/util/Backtrace.cpp +++ b/c10/util/Backtrace.cpp @@ -251,7 +251,8 @@ std::string get_backtrace( return stream.str(); #elif defined(_MSC_VER) // !SUPPORTS_BACKTRACE // This backtrace retrieval is implemented on Windows via the Windows - // API using `CaptureStackBackTrace`, `SymFromAddr` and `SymGetLineFromAddr64`. + // API using `CaptureStackBackTrace`, `SymFromAddr` and + // `SymGetLineFromAddr64`. // https://stackoverflow.com/questions/5693192/win32-backtrace-from-c-code // https://stackoverflow.com/questions/26398064/counterpart-to-glibcs-backtrace-and-backtrace-symbols-on-windows // https://docs.microsoft.com/en-us/windows/win32/debug/capturestackbacktrace @@ -313,7 +314,8 @@ std::string get_backtrace( << back_trace[i_frame] << std::dec; if (with_symbol) { stream << std::setfill('0') << std::setw(16) << std::uppercase << std::hex - << p_symbol->Address << std::dec << " " << module << "!" << p_symbol->Name; + << p_symbol->Address << std::dec << " " << module << "!" + << p_symbol->Name; } else { stream << " " << module << "!"; } diff --git a/c10/util/Bitset.h b/c10/util/Bitset.h index 3e67169345c..6f7c4b9a1d7 100644 --- a/c10/util/Bitset.h +++ b/c10/util/Bitset.h @@ -21,15 +21,15 @@ namespace utils { * to exist is that std::bitset misses a find_first_set() method. */ struct bitset final { -private: - #if defined(_MSC_VER) - // MSVCs _BitScanForward64 expects int64_t - using bitset_type = int64_t; - #else - // POSIX ffsll expects long long int - using bitset_type = long long int; - #endif -public: + private: +#if defined(_MSC_VER) + // MSVCs _BitScanForward64 expects int64_t + using bitset_type = int64_t; +#else + // POSIX ffsll expects long long int + using bitset_type = long long int; +#endif + public: static constexpr size_t NUM_BITS() { return 8 * sizeof(bitset_type); } @@ -72,37 +72,38 @@ public: } } -private: + private: // Return the index of the first set bit. The returned index is one-indexed - // (i.e. 
if the very first bit is set, this function returns '1'), and a return - // of '0' means that there was no bit set. + // (i.e. if the very first bit is set, this function returns '1'), and a + // return of '0' means that there was no bit set. size_t find_first_set() const { - #if defined(_MSC_VER) && defined(_M_X64) - unsigned long result; - bool has_bits_set = (0 != _BitScanForward64(&result, bitset_)); +#if defined(_MSC_VER) && defined(_M_X64) + unsigned long result; + bool has_bits_set = (0 != _BitScanForward64(&result, bitset_)); + if (!has_bits_set) { + return 0; + } + return result + 1; +#elif defined(_MSC_VER) && defined(_M_IX86) + unsigned long result; + if (static_cast(bitset_) != 0) { + bool has_bits_set = + (0 != _BitScanForward(&result, static_cast(bitset_))); if (!has_bits_set) { return 0; } return result + 1; - #elif defined(_MSC_VER) && defined(_M_IX86) - unsigned long result; - if (static_cast(bitset_) != 0) { - bool has_bits_set = (0 != _BitScanForward(&result, static_cast(bitset_))); - if (!has_bits_set) { - return 0; - } - return result + 1; + } else { + bool has_bits_set = + (0 != _BitScanForward(&result, static_cast(bitset_ >> 32))); + if (!has_bits_set) { + return 32; } - else { - bool has_bits_set = (0 != _BitScanForward(&result, static_cast(bitset_ >> 32))); - if (!has_bits_set) { - return 32; - } - return result + 33; - } - #else - return __builtin_ffsll(bitset_); - #endif + return result + 33; + } +#else + return __builtin_ffsll(bitset_); +#endif } friend bool operator==(bitset lhs, bitset rhs) noexcept { diff --git a/c10/util/C++17.h b/c10/util/C++17.h index fa872c89592..90b04d896c9 100644 --- a/c10/util/C++17.h +++ b/c10/util/C++17.h @@ -2,48 +2,51 @@ #ifndef C10_UTIL_CPP17_H_ #define C10_UTIL_CPP17_H_ -#include -#include +#include +#include +#include #include #include #include -#include -#include -#include +#include +#include #if !defined(__clang__) && !defined(_MSC_VER) && defined(__GNUC__) && \ - __GNUC__ < 5 -#error "You're trying to build PyTorch with a too old version of GCC. We need GCC 5 or later." + __GNUC__ < 5 +#error \ + "You're trying to build PyTorch with a too old version of GCC. We need GCC 5 or later." #endif #if defined(__clang__) && __clang_major__ < 4 -#error "You're trying to build PyTorch with a too old version of Clang. We need Clang 4 or later." +#error \ + "You're trying to build PyTorch with a too old version of Clang. We need Clang 4 or later." #endif -#if (defined(_MSC_VER) && (!defined(_MSVC_LANG) || _MSVC_LANG < 201402L)) || (!defined(_MSC_VER) && __cplusplus < 201402L) +#if (defined(_MSC_VER) && (!defined(_MSVC_LANG) || _MSVC_LANG < 201402L)) || \ + (!defined(_MSC_VER) && __cplusplus < 201402L) #error You need C++14 to compile PyTorch #endif #if defined(_WIN32) && (defined(min) || defined(max)) -# error Macro clash with min and max -- define NOMINMAX when compiling your program on Windows +#error Macro clash with min and max -- define NOMINMAX when compiling your program on Windows #endif /* * This header adds some polyfills with C++17 functionality */ -namespace c10 { namespace guts { - +namespace c10 { +namespace guts { template -typename std::enable_if::value && !std::is_array::value && std::is_base_of::value, std::unique_ptr>::type +typename std::enable_if< + !std::is_array::value && !std::is_array::value && + std::is_base_of::value, + std::unique_ptr>::type make_unique_base(Args&&... 
args) { return std::unique_ptr(new Child(std::forward(args)...)); } - - - #if defined(__cpp_lib_logical_traits) && !(defined(_MSC_VER) && _MSC_VER < 1920) template @@ -58,42 +61,49 @@ using negation = std::negation; #else // Implementation taken from http://en.cppreference.com/w/cpp/types/conjunction -template struct conjunction : std::true_type { }; -template struct conjunction : B1 { }; -template +template +struct conjunction : std::true_type {}; +template +struct conjunction : B1 {}; +template struct conjunction : std::conditional_t, B1> {}; // Implementation taken from http://en.cppreference.com/w/cpp/types/disjunction -template struct disjunction : std::false_type { }; -template struct disjunction : B1 { }; -template +template +struct disjunction : std::false_type {}; +template +struct disjunction : B1 {}; +template struct disjunction - : std::conditional_t> { }; + : std::conditional_t> {}; -// Implementation taken from http://en.cppreference.com/w/cpp/types/integral_constant +// Implementation taken from +// http://en.cppreference.com/w/cpp/types/integral_constant template using bool_constant = std::integral_constant; // Implementation taken from http://en.cppreference.com/w/cpp/types/negation -template -struct negation : bool_constant { }; +template +struct negation : bool_constant {}; #endif - - - #ifdef __cpp_lib_void_t -template using void_t = std::void_t; +template +using void_t = std::void_t; #else // Implementation taken from http://en.cppreference.com/w/cpp/types/void_t // (it takes CWG1558 into account and also works for older compilers) -template struct make_void { typedef void type;}; -template using void_t = typename make_void::type; +template +struct make_void { + typedef void type; +}; +template +using void_t = typename make_void::type; #endif @@ -113,34 +123,44 @@ CUDA_HOST_DEVICE inline constexpr decltype(auto) apply(F&& f, Tuple&& t) { #else -// Implementation from http://en.cppreference.com/w/cpp/utility/apply (but modified) -// TODO This is an incomplete implementation of std::apply, not working for member functions. +// Implementation from http://en.cppreference.com/w/cpp/utility/apply (but +// modified) +// TODO This is an incomplete implementation of std::apply, not working for +// member functions. namespace detail { template #if defined(_MSC_VER) -// MSVC has a problem with the decltype() return type, but it also doesn't need it -C10_HOST_DEVICE constexpr auto apply_impl(F&& f, Tuple&& t, std::index_sequence) +// MSVC has a problem with the decltype() return type, but it also doesn't need +// it +C10_HOST_DEVICE constexpr auto apply_impl( + F&& f, + Tuple&& t, + std::index_sequence) #else // GCC/Clang need the decltype() return type -CUDA_HOST_DEVICE constexpr decltype(auto) apply_impl(F&& f, Tuple&& t, std::index_sequence) +CUDA_HOST_DEVICE constexpr decltype(auto) apply_impl( + F&& f, + Tuple&& t, + std::index_sequence) #endif { - return std::forward(f)(std::get(std::forward(t))...); + return std::forward(f)(std::get(std::forward(t))...); } -} // namespace detail +} // namespace detail template CUDA_HOST_DEVICE constexpr decltype(auto) apply(F&& f, Tuple&& t) { - return detail::apply_impl( - std::forward(f), std::forward(t), - std::make_index_sequence>::value>{}); + return detail::apply_impl( + std::forward(f), + std::forward(t), + std::make_index_sequence< + std::tuple_size>::value>{}); } #endif #undef CUDA_HOST_DEVICE - template typename std::enable_if< std::is_member_pointer::type>::value, @@ -157,63 +177,94 @@ invoke(Functor&& f, Args&&... 
args) { return std::forward(f)(std::forward(args)...); } - - - - - namespace detail { struct _identity final { - template + template using type_identity = T; - template + template decltype(auto) operator()(T&& arg) { return std::forward(arg); } }; -template +template struct function_takes_identity_argument : std::false_type {}; #if defined(_MSC_VER) -// For some weird reason, MSVC shows a compiler error when using guts::void_t instead of std::void_t. -// But we're only building on MSVC versions that have std::void_t, so let's just use that one. -template -struct function_takes_identity_argument()(_identity()))>> : std::true_type {}; +// For some weird reason, MSVC shows a compiler error when using guts::void_t +// instead of std::void_t. But we're only building on MSVC versions that have +// std::void_t, so let's just use that one. +template +struct function_takes_identity_argument< + Func, + std::void_t()(_identity()))>> : std::true_type { +}; #else -template -struct function_takes_identity_argument()(_identity()))>> : std::true_type {}; +template +struct function_takes_identity_argument< + Func, + void_t()(_identity()))>> : std::true_type {}; #endif -template +template struct _if_constexpr; -template<> +template <> struct _if_constexpr final { - template::value, void*> = nullptr> - static decltype(auto) call(ThenCallback&& thenCallback, ElseCallback&& /* elseCallback */) { - // The _identity instance passed in can be used to delay evaluation of an expression, - // because the compiler can't know that it's just the identity we're passing in. + template < + class ThenCallback, + class ElseCallback, + std::enable_if_t< + function_takes_identity_argument::value, + void*> = nullptr> + static decltype(auto) call( + ThenCallback&& thenCallback, + ElseCallback&& /* elseCallback */) { + // The _identity instance passed in can be used to delay evaluation of an + // expression, because the compiler can't know that it's just the identity + // we're passing in. return thenCallback(_identity()); } - template::value, void*> = nullptr> - static decltype(auto) call(ThenCallback&& thenCallback, ElseCallback&& /* elseCallback */) { + template < + class ThenCallback, + class ElseCallback, + std::enable_if_t< + !function_takes_identity_argument::value, + void*> = nullptr> + static decltype(auto) call( + ThenCallback&& thenCallback, + ElseCallback&& /* elseCallback */) { return thenCallback(); } }; -template<> +template <> struct _if_constexpr final { - template::value, void*> = nullptr> - static decltype(auto) call(ThenCallback&& /* thenCallback */, ElseCallback&& elseCallback) { - // The _identity instance passed in can be used to delay evaluation of an expression, - // because the compiler can't know that it's just the identity we're passing in. + template < + class ThenCallback, + class ElseCallback, + std::enable_if_t< + function_takes_identity_argument::value, + void*> = nullptr> + static decltype(auto) call( + ThenCallback&& /* thenCallback */, + ElseCallback&& elseCallback) { + // The _identity instance passed in can be used to delay evaluation of an + // expression, because the compiler can't know that it's just the identity + // we're passing in. 
return elseCallback(_identity()); } - template::value, void*> = nullptr> - static decltype(auto) call(ThenCallback&& /* thenCallback */, ElseCallback&& elseCallback) { + template < + class ThenCallback, + class ElseCallback, + std::enable_if_t< + !function_takes_identity_argument::value, + void*> = nullptr> + static decltype(auto) call( + ThenCallback&& /* thenCallback */, + ElseCallback&& elseCallback) { return elseCallback(); } }; @@ -249,38 +300,47 @@ struct _if_constexpr final { * template * int func(T t) { * return if_constexpr::value>( - * [&](auto _) { return _(t).value; }, // this code is invalid for T == MyClass2, so a regular non-constexpr if statement wouldn't compile - * [&](auto _) { return _(t).val; } // this code is invalid for T == MyClass1 + * [&](auto _) { return _(t).value; }, // this code is invalid for T == + * MyClass2, so a regular non-constexpr if statement wouldn't compile + * [&](auto _) { return _(t).val; } // this code is invalid for T == + * MyClass1 * ); * } * - * Note: The _ argument passed in Example 3 is the identity function, i.e. it does nothing. - * It is used to force the compiler to delay type checking, because the compiler - * doesn't know what kind of _ is passed in. Without it, the compiler would fail - * when you try to access t.value but the member doesn't exist. + * Note: The _ argument passed in Example 3 is the identity function, i.e. it + * does nothing. It is used to force the compiler to delay type checking, + * because the compiler doesn't know what kind of _ is passed in. Without it, + * the compiler would fail when you try to access t.value but the member doesn't + * exist. * - * Note: In Example 3, both branches return int, so func() returns int. This is not necessary. - * If func() had a return type of "auto", then both branches could return different - * types, say func() could return int and func() could return string. + * Note: In Example 3, both branches return int, so func() returns int. This is + * not necessary. If func() had a return type of "auto", then both branches + * could return different types, say func() could return int and + * func() could return string. * - * Note: if_constexpr is *eager* w.r.t. template expansion - meaning this - * polyfill does not behave like a true "if statement at compilation time". - * The `_` trick above only defers typechecking, which happens after templates - * have been expanded. (Of course this is all that's necessary for many use cases). + * Note: if_constexpr is *eager* w.r.t. template expansion - meaning + * this polyfill does not behave like a true "if statement at compilation time". + * The `_` trick above only defers typechecking, which happens after + * templates have been expanded. (Of course this is all that's necessary for + * many use cases). */ -template -decltype(auto) if_constexpr(ThenCallback&& thenCallback, ElseCallback&& elseCallback) { +template +decltype(auto) if_constexpr( + ThenCallback&& thenCallback, + ElseCallback&& elseCallback) { #if defined(__cpp_if_constexpr) - // If we have C++17, just use it's "if constexpr" feature instead of wrapping it. - // This will give us better error messages. - if constexpr(Condition) { - if constexpr (detail::function_takes_identity_argument::value) { + // If we have C++17, just use it's "if constexpr" feature instead of wrapping + // it. This will give us better error messages. 
+ if constexpr (Condition) { + if constexpr (detail::function_takes_identity_argument< + ThenCallback>::value) { return ::std::forward(thenCallback)(detail::_identity()); } else { return ::std::forward(thenCallback)(); } } else { - if constexpr (detail::function_takes_identity_argument::value) { + if constexpr (detail::function_takes_identity_argument< + ElseCallback>::value) { return ::std::forward(elseCallback)(detail::_identity()); } else { return ::std::forward(elseCallback)(); @@ -288,18 +348,20 @@ decltype(auto) if_constexpr(ThenCallback&& thenCallback, ElseCallback&& elseCall } #else // C++14 implementation of if constexpr - return detail::_if_constexpr::call(::std::forward(thenCallback), - ::std::forward(elseCallback)); + return detail::_if_constexpr::call( + ::std::forward(thenCallback), + ::std::forward(elseCallback)); #endif } -template +template decltype(auto) if_constexpr(ThenCallback&& thenCallback) { #if defined(__cpp_if_constexpr) - // If we have C++17, just use it's "if constexpr" feature instead of wrapping it. - // This will give us better error messages. - if constexpr(Condition) { - if constexpr (detail::function_takes_identity_argument::value) { + // If we have C++17, just use it's "if constexpr" feature instead of wrapping + // it. This will give us better error messages. + if constexpr (Condition) { + if constexpr (detail::function_takes_identity_argument< + ThenCallback>::value) { return ::std::forward(thenCallback)(detail::_identity()); } else { return ::std::forward(thenCallback)(); @@ -307,45 +369,53 @@ decltype(auto) if_constexpr(ThenCallback&& thenCallback) { } #else // C++14 implementation of if constexpr - return if_constexpr(::std::forward(thenCallback), [] (auto) {}); + return if_constexpr( + ::std::forward(thenCallback), [](auto) {}); #endif } - - -// GCC 4.8 doesn't define std::to_string, even though that's in C++11. Let's define it. +// GCC 4.8 doesn't define std::to_string, even though that's in C++11. Let's +// define it. namespace detail { class DummyClassForToString final {}; -}}} +} // namespace detail +} // namespace guts +} // namespace c10 namespace std { -// We use SFINAE to detect if std::to_string exists for a type, but that only works -// if the function name is defined. So let's define a std::to_string for a dummy type. -// If you're getting an error here saying that this overload doesn't match your -// std::to_string() call, then you're calling std::to_string() but should be calling -// c10::guts::to_string(). -inline std::string to_string(c10::guts::detail::DummyClassForToString) { return ""; } - +// We use SFINAE to detect if std::to_string exists for a type, but that only +// works if the function name is defined. So let's define a std::to_string for a +// dummy type. If you're getting an error here saying that this overload doesn't +// match your std::to_string() call, then you're calling std::to_string() but +// should be calling c10::guts::to_string(). 
+inline std::string to_string(c10::guts::detail::DummyClassForToString) { + return ""; } -namespace c10 { namespace guts { namespace detail { -template +} // namespace std +namespace c10 { +namespace guts { +namespace detail { + +template struct to_string_ final { - static std::string call(T value) { - std::ostringstream str; - str << value; - return str.str(); - } + static std::string call(T value) { + std::ostringstream str; + str << value; + return str.str(); + } }; // If a std::to_string exists, use that instead -template -struct to_string_()))>> final { - static std::string call(T value) { - return std::to_string(value); - } +template +struct to_string_()))>> + final { + static std::string call(T value) { + return std::to_string(value); + } }; -} -template inline std::string to_string(T value) { - return detail::to_string_::call(value); +} // namespace detail +template +inline std::string to_string(T value) { + return detail::to_string_::call(value); } template @@ -358,6 +428,7 @@ constexpr const T& max(const T& a, const T& b) { return (a < b) ? b : a; } -}} +} // namespace guts +} // namespace c10 #endif // C10_UTIL_CPP17_H_ diff --git a/c10/util/Deprecated.h b/c10/util/Deprecated.h index 545a9b1ab90..88440a0242e 100644 --- a/c10/util/Deprecated.h +++ b/c10/util/Deprecated.h @@ -16,26 +16,25 @@ // }; // NB: __cplusplus doesn't work for MSVC, so for now MSVC always uses -// the "__declspec(deprecated)" implementation and not the C++14 "[[deprecated]]" -// attribute. We tried enabling "[[deprecated]]" for C++14 on MSVC, but -// ran into issues with some older MSVC versions. +// the "__declspec(deprecated)" implementation and not the C++14 +// "[[deprecated]]" attribute. We tried enabling "[[deprecated]]" for C++14 on +// MSVC, but ran into issues with some older MSVC versions. #if (defined(__cplusplus) && __cplusplus >= 201402L) -# define C10_DEPRECATED [[deprecated]] -# define C10_DEPRECATED_MESSAGE(message) [[deprecated(message)]] +#define C10_DEPRECATED [[deprecated]] +#define C10_DEPRECATED_MESSAGE(message) [[deprecated(message)]] #elif defined(__GNUC__) -# define C10_DEPRECATED __attribute__((deprecated)) +#define C10_DEPRECATED __attribute__((deprecated)) // TODO Is there some way to implement this? -# define C10_DEPRECATED_MESSAGE(message) __attribute__((deprecated)) +#define C10_DEPRECATED_MESSAGE(message) __attribute__((deprecated)) #elif defined(_MSC_VER) -# define C10_DEPRECATED __declspec(deprecated) -# define C10_DEPRECATED_MESSAGE(message) __declspec(deprecated(message)) +#define C10_DEPRECATED __declspec(deprecated) +#define C10_DEPRECATED_MESSAGE(message) __declspec(deprecated(message)) #else -# warning "You need to implement C10_DEPRECATED for this compiler" -# define C10_DEPRECATED +#warning "You need to implement C10_DEPRECATED for this compiler" +#define C10_DEPRECATED #endif - // Sample usage: // // C10_DEFINE_DEPRECATED_USING(BadType, int) @@ -48,7 +47,8 @@ // many compilers. #if defined(__has_cpp_attribute) #if __has_cpp_attribute(deprecated) && !defined(__CUDACC__) -# define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) using TypeName [[deprecated]] = TypeThingy; +#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ + using TypeName [[deprecated]] = TypeThingy; #endif #endif @@ -61,18 +61,21 @@ // // So we just turn the macro off in this case. 
#if defined(C10_DEFINE_DEPRECATED_USING) -# undef C10_DEFINE_DEPRECATED_USING +#undef C10_DEFINE_DEPRECATED_USING #endif -# define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) using TypeName = TypeThingy; +#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ + using TypeName = TypeThingy; #else // [[deprecated]] does work in windows without nvcc, though msc doesn't support -// `__has_cpp_attribute` when c++14 is supported, otherwise __declspec(deprecated) -// is used as the alternative. +// `__has_cpp_attribute` when c++14 is supported, otherwise +// __declspec(deprecated) is used as the alternative. #ifndef C10_DEFINE_DEPRECATED_USING #if defined(_MSVC_LANG) && _MSVC_LANG >= 201402L -# define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) using TypeName [[deprecated]] = TypeThingy; +#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ + using TypeName [[deprecated]] = TypeThingy; #else -# define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) using TypeName = __declspec(deprecated) TypeThingy; +#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ + using TypeName = __declspec(deprecated) TypeThingy; #endif #endif #endif @@ -84,14 +87,16 @@ // attribute when not cuda, and when using a GCC compiler that doesn't support // the c++14 syntax we checked for above (available in __GNUC__ >= 5) #if !defined(__CUDACC__) -# define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) using TypeName __attribute__((deprecated)) = TypeThingy; +#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ + using TypeName __attribute__((deprecated)) = TypeThingy; #else // using cuda + gcc < 5, neither deprecated syntax is available so turning off. -# define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) using TypeName = TypeThingy; +#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ + using TypeName = TypeThingy; #endif #endif -#if ! defined(C10_DEFINE_DEPRECATED_USING) -# warning "You need to implement C10_DEFINE_DEPRECATED_USING for this compiler" -# define C10_DEFINE_DEPRECATED_USING +#if !defined(C10_DEFINE_DEPRECATED_USING) +#warning "You need to implement C10_DEFINE_DEPRECATED_USING for this compiler" +#define C10_DEFINE_DEPRECATED_USING #endif diff --git a/c10/util/Flags.h b/c10/util/Flags.h index b4352510c99..d3f750016d7 100644 --- a/c10/util/Flags.h +++ b/c10/util/Flags.h @@ -115,7 +115,7 @@ namespace gflags = google; // flags defined in C10. This is done via a global reference, so the flag // itself is not duplicated - under the hood it is the same global gflags flag. #define C10_GFLAGS_DEF_WRAPPER(type, real_type, name, default_value, help_str) \ - DEFINE_##type(name, default_value, help_str); \ + DEFINE_##type(name, default_value, help_str); #define C10_DEFINE_int(name, default_value, help_str) \ C10_GFLAGS_DEF_WRAPPER(int32, gflags::int32, name, default_value, help_str) @@ -131,8 +131,7 @@ namespace gflags = google; C10_GFLAGS_DEF_WRAPPER(string, ::fLS::clstring, name, default_value, help_str) // DECLARE_typed_var should be used in header files and in the global namespace. 
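A minimal sketch of how these flag wrappers are typically consumed. The flag name below is made up for illustration; the FLAGS_-prefixed variable and c10::ParseCommandLineFlags are assumed to come from this header.

#include <c10/util/Flags.h>

// Defined once in a single .cpp file; other translation units would put
// C10_DECLARE_int(example_cache_size); in a header to see FLAGS_example_cache_size.
C10_DEFINE_int(example_cache_size, 64, "Illustrative cache size flag");

int main(int argc, char** argv) {
  // Parse --example_cache_size=... style arguments, then read the flag.
  c10::ParseCommandLineFlags(&argc, &argv);
  return FLAGS_example_cache_size > 0 ? 0 : 1;
}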
-#define C10_GFLAGS_DECLARE_WRAPPER(type, real_type, name) \ - DECLARE_##type(name); \ +#define C10_GFLAGS_DECLARE_WRAPPER(type, real_type, name) DECLARE_##type(name); #define C10_DECLARE_int(name) \ C10_GFLAGS_DECLARE_WRAPPER(int32, gflags::int32, name) diff --git a/c10/util/FunctionRef.h b/c10/util/FunctionRef.h index 04b0ce09b65..929fd6552c2 100644 --- a/c10/util/FunctionRef.h +++ b/c10/util/FunctionRef.h @@ -30,40 +30,43 @@ namespace c10 { /// /// This class does not own the callable, so it is not in general safe to store /// a function_ref. -template class function_ref; +template +class function_ref; -template +template class function_ref { -Ret (*callback)(intptr_t callable, Params ...params) = nullptr; -intptr_t callable; + Ret (*callback)(intptr_t callable, Params... params) = nullptr; + intptr_t callable; -template -static Ret callback_fn(intptr_t callable, Params ...params) { + template + static Ret callback_fn(intptr_t callable, Params... params) { return (*reinterpret_cast(callable))( std::forward(params)...); -} + } -public: -function_ref() = default; -function_ref(std::nullptr_t) {} + public: + function_ref() = default; + function_ref(std::nullptr_t) {} -template -function_ref(Callable &&callable, - typename std::enable_if< - !std::is_same::type, - function_ref>::value>::type * = nullptr, - typename std::enable_if< - std::is_convertible< - typename std::result_of::type, - Ret>::value>::type * = nullptr) - : callback(callback_fn::type>), + template + function_ref( + Callable&& callable, + typename std::enable_if::type, + function_ref>::value>::type* = nullptr, + typename std::enable_if::type, + Ret>::value>::type* = nullptr) + : callback(callback_fn::type>), callable(reinterpret_cast(&callable)) {} -Ret operator()(Params ...params) const { + Ret operator()(Params... 
params) const { return callback(callable, std::forward(params)...); -} + } -operator bool() const { return callback; } + operator bool() const { + return callback; + } }; -} +} // namespace c10 diff --git a/c10/util/Half-inl.h b/c10/util/Half-inl.h index 2a39e007b33..554787474e9 100644 --- a/c10/util/Half-inl.h +++ b/c10/util/Half-inl.h @@ -1,8 +1,8 @@ #pragma once +#include #include #include -#include #ifdef __CUDACC__ #include @@ -48,7 +48,7 @@ inline C10_HOST_DEVICE Half::operator __half() const { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 350)) || \ (defined(__clang__) && defined(__CUDA__)) inline __device__ Half __ldg(const Half* ptr) { - return __ldg(reinterpret_cast(ptr)); + return __ldg(reinterpret_cast(ptr)); } #endif @@ -66,12 +66,14 @@ inline C10_HOST_DEVICE Half operator*(const Half& a, const Half& b) { return static_cast(a) * static_cast(b); } -inline C10_HOST_DEVICE Half operator/(const Half& a, const Half& b) __ubsan_ignore_float_divide_by_zero__ { +inline C10_HOST_DEVICE Half operator/(const Half& a, const Half& b) + __ubsan_ignore_float_divide_by_zero__ { return static_cast(a) / static_cast(b); } inline C10_HOST_DEVICE Half operator-(const Half& a) { -#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(__HIP_DEVICE_COMPILE__) +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ + defined(__HIP_DEVICE_COMPILE__) return __hneg(a); #else return -static_cast(a); @@ -109,7 +111,8 @@ inline C10_HOST_DEVICE float operator-(Half a, float b) { inline C10_HOST_DEVICE float operator*(Half a, float b) { return static_cast(a) * b; } -inline C10_HOST_DEVICE float operator/(Half a, float b) __ubsan_ignore_float_divide_by_zero__ { +inline C10_HOST_DEVICE float operator/(Half a, float b) + __ubsan_ignore_float_divide_by_zero__ { return static_cast(a) / b; } @@ -122,7 +125,8 @@ inline C10_HOST_DEVICE float operator-(float a, Half b) { inline C10_HOST_DEVICE float operator*(float a, Half b) { return a * static_cast(b); } -inline C10_HOST_DEVICE float operator/(float a, Half b) __ubsan_ignore_float_divide_by_zero__ { +inline C10_HOST_DEVICE float operator/(float a, Half b) + __ubsan_ignore_float_divide_by_zero__ { return a / static_cast(b); } @@ -150,7 +154,8 @@ inline C10_HOST_DEVICE double operator-(Half a, double b) { inline C10_HOST_DEVICE double operator*(Half a, double b) { return static_cast(a) * b; } -inline C10_HOST_DEVICE double operator/(Half a, double b) __ubsan_ignore_float_divide_by_zero__ { +inline C10_HOST_DEVICE double operator/(Half a, double b) + __ubsan_ignore_float_divide_by_zero__ { return static_cast(a) / b; } @@ -163,7 +168,8 @@ inline C10_HOST_DEVICE double operator-(double a, Half b) { inline C10_HOST_DEVICE double operator*(double a, Half b) { return a * static_cast(b); } -inline C10_HOST_DEVICE double operator/(double a, Half b) __ubsan_ignore_float_divide_by_zero__ { +inline C10_HOST_DEVICE double operator/(double a, Half b) + __ubsan_ignore_float_divide_by_zero__ { return a / static_cast(b); } diff --git a/c10/util/Half.h b/c10/util/Half.h index 86c7d0a7d47..c22db1fab24 100644 --- a/c10/util/Half.h +++ b/c10/util/Half.h @@ -26,8 +26,8 @@ #endif #include -#include #include +#include #include #include #include @@ -54,285 +54,313 @@ namespace c10 { namespace detail { - C10_DEVICE_HOST_FUNCTION inline float fp32_from_bits(uint32_t w) { - #if defined(__OPENCL_VERSION__) - return as_float(w); - #elif defined(__CUDA_ARCH__) - return __uint_as_float((unsigned int)w); - #elif defined(__INTEL_COMPILER) - return _castu32_f32(w); - #else - union { - 
uint32_t as_bits; - float as_value; - } fp32 = {w}; - return fp32.as_value; - #endif - } - - C10_DEVICE_HOST_FUNCTION inline uint32_t fp32_to_bits(float f) { - #if defined(__OPENCL_VERSION__) - return as_uint(f); - #elif defined(__CUDA_ARCH__) - return (uint32_t)__float_as_uint(f); - #elif defined(__INTEL_COMPILER) - return _castf32_u32(f); - #else - union { - float as_value; - uint32_t as_bits; - } fp32 = {f}; - return fp32.as_bits; - #endif - } - - /* - * Convert a 16-bit floating-point number in IEEE half-precision format, in bit representation, to - * a 32-bit floating-point number in IEEE single-precision format, in bit representation. - * - * @note The implementation doesn't use any floating-point operations. - */ - inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) { - /* - * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word: - * +---+-----+------------+-------------------+ - * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| - * +---+-----+------------+-------------------+ - * Bits 31 26-30 16-25 0-15 - * - * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - zero bits. - */ - const uint32_t w = (uint32_t) h << 16; - /* - * Extract the sign of the input number into the high bit of the 32-bit word: - * - * +---+----------------------------------+ - * | S |0000000 00000000 00000000 00000000| - * +---+----------------------------------+ - * Bits 31 0-31 - */ - const uint32_t sign = w & UINT32_C(0x80000000); - /* - * Extract mantissa and biased exponent of the input number into the bits 0-30 of the 32-bit word: - * - * +---+-----+------------+-------------------+ - * | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| - * +---+-----+------------+-------------------+ - * Bits 30 27-31 17-26 0-16 - */ - const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); - /* - * Renorm shift is the number of bits to shift mantissa left to make the half-precision number normalized. - * If the initial number is normalized, some of its high 6 bits (sign == 0 and 5-bit exponent) equals one. - * In this case renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note that if we shift - * denormalized nonsign by renorm_shift, the unit bit of mantissa will shift into exponent, turning the - * biased exponent into 1, and making mantissa normalized (i.e. without leading 1). - */ -#ifdef _MSC_VER - unsigned long nonsign_bsr; - _BitScanReverse(&nonsign_bsr, (unsigned long)nonsign); - uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31; +C10_DEVICE_HOST_FUNCTION inline float fp32_from_bits(uint32_t w) { +#if defined(__OPENCL_VERSION__) + return as_float(w); +#elif defined(__CUDA_ARCH__) + return __uint_as_float((unsigned int)w); +#elif defined(__INTEL_COMPILER) + return _castu32_f32(w); #else - uint32_t renorm_shift = __builtin_clz(nonsign); + union { + uint32_t as_bits; + float as_value; + } fp32 = {w}; + return fp32.as_value; #endif - renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0; - /* - * Iff half-precision number has exponent of 15, the addition overflows - * it into bit 31, and the subsequent shift turns the high 9 bits - * into 1. Thus inf_nan_mask == 0x7F800000 if the half-precision number - * had exponent of 15 (i.e. was NaN or infinity) 0x00000000 otherwise - */ - const int32_t inf_nan_mask = - ((int32_t)(nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000); - /* - * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31 - * into 1. Otherwise, bit 31 remains 0. 
The signed shift right by 31 - * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask == - * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h) - * 0x00000000 otherwise - */ - const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31; - /* - * 1. Shift nonsign left by renorm_shift to normalize it (if the input - * was denormal) - * 2. Shift nonsign right by 3 so the exponent (5 bits originally) - * becomes an 8-bit field and 10-bit mantissa shifts into the 10 high - * bits of the 23-bit mantissa of IEEE single-precision number. - * 3. Add 0x70 to the exponent (starting at bit 23) to compensate the - * different in exponent bias (0x7F for single-precision number less 0xF - * for half-precision number). - * 4. Subtract renorm_shift from the exponent (starting at bit 23) to - * account for renormalization. As renorm_shift is less than 0x70, this - * can be combined with step 3. - * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the - * input was NaN or infinity. - * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent - * into zero if the input was zero. - * 7. Combine with the sign of the input number. - */ - return sign | - ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) | - inf_nan_mask) & - ~zero_mask); - } +} + +C10_DEVICE_HOST_FUNCTION inline uint32_t fp32_to_bits(float f) { +#if defined(__OPENCL_VERSION__) + return as_uint(f); +#elif defined(__CUDA_ARCH__) + return (uint32_t)__float_as_uint(f); +#elif defined(__INTEL_COMPILER) + return _castf32_u32(f); +#else + union { + float as_value; + uint32_t as_bits; + } fp32 = {f}; + return fp32.as_bits; +#endif +} + +/* + * Convert a 16-bit floating-point number in IEEE half-precision format, in bit + * representation, to a 32-bit floating-point number in IEEE single-precision + * format, in bit representation. + * + * @note The implementation doesn't use any floating-point operations. + */ +inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the + * upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 + * - zero bits. + */ + const uint32_t w = (uint32_t)h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the bits 0-30 + * of the 32-bit word: + * + * +---+-----+------------+-------------------+ + * | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 30 27-31 17-26 0-16 + */ + const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); + /* + * Renorm shift is the number of bits to shift mantissa left to make the + * half-precision number normalized. If the initial number is normalized, some + * of its high 6 bits (sign == 0 and 5-bit exponent) equals one. In this case + * renorm_shift == 0. If the number is denormalize, renorm_shift > 0. 
Note + * that if we shift denormalized nonsign by renorm_shift, the unit bit of + * mantissa will shift into exponent, turning the biased exponent into 1, and + * making mantissa normalized (i.e. without leading 1). + */ +#ifdef _MSC_VER + unsigned long nonsign_bsr; + _BitScanReverse(&nonsign_bsr, (unsigned long)nonsign); + uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31; +#else + uint32_t renorm_shift = __builtin_clz(nonsign); +#endif + renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0; + /* + * Iff half-precision number has exponent of 15, the addition overflows + * it into bit 31, and the subsequent shift turns the high 9 bits + * into 1. Thus inf_nan_mask == 0x7F800000 if the half-precision number + * had exponent of 15 (i.e. was NaN or infinity) 0x00000000 otherwise + */ + const int32_t inf_nan_mask = + ((int32_t)(nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000); + /* + * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31 + * into 1. Otherwise, bit 31 remains 0. The signed shift right by 31 + * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask == + * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h) + * 0x00000000 otherwise + */ + const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31; + /* + * 1. Shift nonsign left by renorm_shift to normalize it (if the input + * was denormal) + * 2. Shift nonsign right by 3 so the exponent (5 bits originally) + * becomes an 8-bit field and 10-bit mantissa shifts into the 10 high + * bits of the 23-bit mantissa of IEEE single-precision number. + * 3. Add 0x70 to the exponent (starting at bit 23) to compensate the + * different in exponent bias (0x7F for single-precision number less 0xF + * for half-precision number). + * 4. Subtract renorm_shift from the exponent (starting at bit 23) to + * account for renormalization. As renorm_shift is less than 0x70, this + * can be combined with step 3. + * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the + * input was NaN or infinity. + * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent + * into zero if the input was zero. + * 7. Combine with the sign of the input number. + */ + return sign | + ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) | + inf_nan_mask) & + ~zero_mask); +} + +/* + * Convert a 16-bit floating-point number in IEEE half-precision format, in bit + * representation, to a 32-bit floating-point number in IEEE single-precision + * format. + * + * @note The implementation relies on IEEE-like (no assumption about rounding + * mode and no operations on denormals) floating-point operations and bitcasts + * between integer and floating-point variables. + */ +inline float fp16_ieee_to_fp32_value(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the + * upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 + * - zero bits. 
+ */ + const uint32_t w = (uint32_t)h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the high bits + * of the 32-bit word: + * + * +-----+------------+---------------------+ + * |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000| + * +-----+------------+---------------------+ + * Bits 27-31 17-26 0-16 + */ + const uint32_t two_w = w + w; /* - * Convert a 16-bit floating-point number in IEEE half-precision format, in bit representation, to - * a 32-bit floating-point number in IEEE single-precision format. + * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become + * mantissa and exponent of a single-precision floating-point number: * - * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals) - * floating-point operations and bitcasts between integer and floating-point variables. + * S|Exponent | Mantissa + * +-+---+-----+------------+----------------+ + * |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000| + * +-+---+-----+------------+----------------+ + * Bits | 23-31 | 0-22 + * + * Next, there are some adjustments to the exponent: + * - The exponent needs to be corrected by the difference in exponent bias + * between single-precision and half-precision formats (0x7F - 0xF = 0x70) + * - Inf and NaN values in the inputs should become Inf and NaN values after + * conversion to the single-precision number. Therefore, if the biased + * exponent of the half-precision input was 0x1F (max possible value), the + * biased exponent of the single-precision output must be 0xFF (max possible + * value). We do this correction in two steps: + * - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset + * below) rather than by 0x70 suggested by the difference in the exponent bias + * (see above). + * - Then we multiply the single-precision result of exponent adjustment by + * 2**(-112) to reverse the effect of exponent adjustment by 0xE0 less the + * necessary exponent adjustment by 0x70 due to difference in exponent bias. + * The floating-point multiplication hardware would ensure than Inf and + * NaN would retain their value on at least partially IEEE754-compliant + * implementations. + * + * Note that the above operations do not handle denormal inputs (where biased + * exponent == 0). However, they also do not operate on denormal inputs, and + * do not produce denormal results. */ - inline float fp16_ieee_to_fp32_value(uint16_t h) { - /* - * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word: - * +---+-----+------------+-------------------+ - * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| - * +---+-----+------------+-------------------+ - * Bits 31 26-30 16-25 0-15 - * - * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - zero bits. 
- */ - const uint32_t w = (uint32_t) h << 16; - /* - * Extract the sign of the input number into the high bit of the 32-bit word: - * - * +---+----------------------------------+ - * | S |0000000 00000000 00000000 00000000| - * +---+----------------------------------+ - * Bits 31 0-31 - */ - const uint32_t sign = w & UINT32_C(0x80000000); - /* - * Extract mantissa and biased exponent of the input number into the high bits of the 32-bit word: - * - * +-----+------------+---------------------+ - * |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000| - * +-----+------------+---------------------+ - * Bits 27-31 17-26 0-16 - */ - const uint32_t two_w = w + w; - - /* - * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become mantissa and exponent - * of a single-precision floating-point number: - * - * S|Exponent | Mantissa - * +-+---+-----+------------+----------------+ - * |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000| - * +-+---+-----+------------+----------------+ - * Bits | 23-31 | 0-22 - * - * Next, there are some adjustments to the exponent: - * - The exponent needs to be corrected by the difference in exponent bias between single-precision and half-precision - * formats (0x7F - 0xF = 0x70) - * - Inf and NaN values in the inputs should become Inf and NaN values after conversion to the single-precision number. - * Therefore, if the biased exponent of the half-precision input was 0x1F (max possible value), the biased exponent - * of the single-precision output must be 0xFF (max possible value). We do this correction in two steps: - * - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset below) rather than by 0x70 suggested - * by the difference in the exponent bias (see above). - * - Then we multiply the single-precision result of exponent adjustment by 2**(-112) to reverse the effect of - * exponent adjustment by 0xE0 less the necessary exponent adjustment by 0x70 due to difference in exponent bias. - * The floating-point multiplication hardware would ensure than Inf and NaN would retain their value on at least - * partially IEEE754-compliant implementations. - * - * Note that the above operations do not handle denormal inputs (where biased exponent == 0). However, they also do not - * operate on denormal inputs, and do not produce denormal results. - */ - const uint32_t exp_offset = UINT32_C(0xE0) << 23; - // const float exp_scale = 0x1.0p-112f; - uint32_t scale_bits = (uint32_t) 15 << 23; - float exp_scale_val; - std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val)); - const float exp_scale = exp_scale_val; - const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; - - /* - * Convert denormalized half-precision inputs into single-precision results (always normalized). - * Zero inputs are also handled here. - * - * In a denormalized number the biased exponent is zero, and mantissa has on-zero bits. - * First, we shift mantissa into bits 0-9 of the 32-bit word. - * - * zeros | mantissa - * +---------------------------+------------+ - * |0000 0000 0000 0000 0000 00|MM MMMM MMMM| - * +---------------------------+------------+ - * Bits 10-31 0-9 - * - * Now, remember that denormalized half-precision numbers are represented as: - * FP16 = mantissa * 2**(-24). - * The trick is to construct a normalized single-precision number with the same mantissa and thehalf-precision input - * and with an exponent which would scale the corresponding mantissa bits to 2**(-24). 
- * A normalized single-precision floating-point number is represented as: - * FP32 = (1 + mantissa * 2**(-23)) * 2**(exponent - 127) - * Therefore, when the biased exponent is 126, a unit change in the mantissa of the input denormalized half-precision - * number causes a change of the constructud single-precision number by 2**(-24), i.e. the same amount. - * - * The last step is to adjust the bias of the constructed single-precision number. When the input half-precision number - * is zero, the constructed single-precision number has the value of - * FP32 = 1 * 2**(126 - 127) = 2**(-1) = 0.5 - * Therefore, we need to subtract 0.5 from the constructed single-precision number to get the numerical equivalent of - * the input half-precision number. - */ - const uint32_t magic_mask = UINT32_C(126) << 23; - const float magic_bias = 0.5f; - const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; - - /* - * - Choose either results of conversion of input as a normalized number, or as a denormalized number, depending on the - * input exponent. The variable two_w contains input exponent in bits 27-31, therefore if its smaller than 2**27, the - * input is either a denormal number, or zero. - * - Combine the result of conversion of exponent and mantissa with the sign of the input number. - */ - const uint32_t denormalized_cutoff = UINT32_C(1) << 27; - const uint32_t result = sign | - (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); - return fp32_from_bits(result); - } + const uint32_t exp_offset = UINT32_C(0xE0) << 23; + // const float exp_scale = 0x1.0p-112f; + uint32_t scale_bits = (uint32_t)15 << 23; + float exp_scale_val; + std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val)); + const float exp_scale = exp_scale_val; + const float normalized_value = + fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; /* - * Convert a 32-bit floating-point number in IEEE single-precision format to a 16-bit floating-point number in - * IEEE half-precision format, in bit representation. + * Convert denormalized half-precision inputs into single-precision results + * (always normalized). Zero inputs are also handled here. * - * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals) - * floating-point operations and bitcasts between integer and floating-point variables. + * In a denormalized number the biased exponent is zero, and mantissa has + * on-zero bits. First, we shift mantissa into bits 0-9 of the 32-bit word. + * + * zeros | mantissa + * +---------------------------+------------+ + * |0000 0000 0000 0000 0000 00|MM MMMM MMMM| + * +---------------------------+------------+ + * Bits 10-31 0-9 + * + * Now, remember that denormalized half-precision numbers are represented as: + * FP16 = mantissa * 2**(-24). + * The trick is to construct a normalized single-precision number with the + * same mantissa and thehalf-precision input and with an exponent which would + * scale the corresponding mantissa bits to 2**(-24). A normalized + * single-precision floating-point number is represented as: FP32 = (1 + + * mantissa * 2**(-23)) * 2**(exponent - 127) Therefore, when the biased + * exponent is 126, a unit change in the mantissa of the input denormalized + * half-precision number causes a change of the constructud single-precision + * number by 2**(-24), i.e. the same amount. + * + * The last step is to adjust the bias of the constructed single-precision + * number. 
When the input half-precision number is zero, the constructed + * single-precision number has the value of FP32 = 1 * 2**(126 - 127) = + * 2**(-1) = 0.5 Therefore, we need to subtract 0.5 from the constructed + * single-precision number to get the numerical equivalent of the input + * half-precision number. */ - inline uint16_t fp16_ieee_from_fp32_value(float f) { - // const float scale_to_inf = 0x1.0p+112f; - // const float scale_to_zero = 0x1.0p-110f; - uint32_t scale_to_inf_bits = (uint32_t) 239 << 23; - uint32_t scale_to_zero_bits = (uint32_t) 17 << 23; - float scale_to_inf_val, scale_to_zero_val; - std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val)); - std::memcpy(&scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val)); - const float scale_to_inf = scale_to_inf_val; - const float scale_to_zero = scale_to_zero_val; + const uint32_t magic_mask = UINT32_C(126) << 23; + const float magic_bias = 0.5f; + const float denormalized_value = + fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + /* + * - Choose either results of conversion of input as a normalized number, or + * as a denormalized number, depending on the input exponent. The variable + * two_w contains input exponent in bits 27-31, therefore if its smaller than + * 2**27, the input is either a denormal number, or zero. + * - Combine the result of conversion of exponent and mantissa with the sign + * of the input number. + */ + const uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = sign | + (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) + : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +} + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 16-bit floating-point number in IEEE half-precision format, in bit + * representation. + * + * @note The implementation relies on IEEE-like (no assumption about rounding + * mode and no operations on denormals) floating-point operations and bitcasts + * between integer and floating-point variables. + */ +inline uint16_t fp16_ieee_from_fp32_value(float f) { + // const float scale_to_inf = 0x1.0p+112f; + // const float scale_to_zero = 0x1.0p-110f; + uint32_t scale_to_inf_bits = (uint32_t)239 << 23; + uint32_t scale_to_zero_bits = (uint32_t)17 << 23; + float scale_to_inf_val, scale_to_zero_val; + std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val)); + std::memcpy( + &scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val)); + const float scale_to_inf = scale_to_inf_val; + const float scale_to_zero = scale_to_zero_val; #if defined(_MSC_VER) && _MSC_VER == 1916 - float base = ((signbit(f) != 0 ? -f : f) * scale_to_inf) * scale_to_zero; + float base = ((signbit(f) != 0 ? 
-f : f) * scale_to_inf) * scale_to_zero; #else - float base = (fabsf(f) * scale_to_inf) * scale_to_zero; + float base = (fabsf(f) * scale_to_inf) * scale_to_zero; #endif - const uint32_t w = fp32_to_bits(f); - const uint32_t shl1_w = w + w; - const uint32_t sign = w & UINT32_C(0x80000000); - uint32_t bias = shl1_w & UINT32_C(0xFF000000); - if (bias < UINT32_C(0x71000000)) { - bias = UINT32_C(0x71000000); - } - - base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; - const uint32_t bits = fp32_to_bits(base); - const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); - const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); - const uint32_t nonsign = exp_bits + mantissa_bits; - return static_cast( - (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign) - ); + const uint32_t w = fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); } + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return static_cast( + (sign >> 16) | + (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign)); +} + } // namespace detail struct alignas(2) Half { @@ -362,7 +390,7 @@ struct alignas(2) Half { // This is just a placeholder for whatever complex representation we // end up deciding to use for half-precision complex numbers. -template<> +template <> struct alignas(4) complex { using value_type = Half; Half real_; @@ -389,10 +417,10 @@ struct alignas(4) complex { // C4804: unsafe use of type 'bool' in operation // It can be addressed by disabling the following warning. #ifdef _MSC_VER -#pragma warning( push ) -#pragma warning( disable : 4146 ) -#pragma warning( disable : 4804 ) -#pragma warning( disable : 4018 ) +#pragma warning(push) +#pragma warning(disable : 4146) +#pragma warning(disable : 4804) +#pragma warning(disable : 4018) #endif // The overflow checks may involve float to int conversion which may @@ -416,8 +444,10 @@ typename std::enable_if::value, bool>::type overflows( // skip isnan and isinf check for integral types template -typename std::enable_if::value && !std::is_same::value, bool>::type overflows( - From f) { +typename std::enable_if< + std::is_integral::value && !std::is_same::value, + bool>::type +overflows(From f) { using limit = std::numeric_limits::type>; if (!limit::is_signed && std::numeric_limits::is_signed) { // allow for negative numbers to wrap using two's complement arithmetic. 
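A quick way to sanity-check the bit-level helpers above is a float -> half -> float round trip. This is an illustrative sketch only; the sample values are arbitrary, and it assumes the c10 headers are on the include path.

#include <cstdint>
#include <cstdio>
#include <c10/util/Half.h>

int main() {
  // 65504 is the largest finite fp16 value; all four samples are exactly
  // representable in half precision, so the round trip should reproduce
  // them bit-for-bit.
  const float samples[] = {0.0f, 1.0f, -2.5f, 65504.0f};
  for (float x : samples) {
    const uint16_t h = c10::detail::fp16_ieee_from_fp32_value(x);
    const float roundtripped = c10::detail::fp16_ieee_to_fp32_value(h);
    std::printf("%g -> 0x%04x -> %g\n", x, (unsigned)h, roundtripped);
  }
  return 0;
}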
@@ -448,12 +478,11 @@ overflows(From f) { #endif #ifdef _MSC_VER -#pragma warning( pop ) +#pragma warning(pop) #endif template -typename std::enable_if::value, bool>::type overflows( - From f) { +typename std::enable_if::value, bool>::type overflows(From f) { // casts from complex to real are considered to overflow if the // imaginary component is non-zero if (!is_complex::value && f.imag() != 0) { diff --git a/c10/util/IdWrapper.h b/c10/util/IdWrapper.h index dc28141e539..a22a60cb9fc 100644 --- a/c10/util/IdWrapper.h +++ b/c10/util/IdWrapper.h @@ -66,12 +66,12 @@ class IdWrapper { } // namespace c10 -#define C10_DEFINE_HASH_FOR_IDWRAPPER(ClassName)\ - namespace std { \ - template <> \ - struct hash { \ - size_t operator()(ClassName x) const { \ - return hash_value(x); \ - } \ - }; \ +#define C10_DEFINE_HASH_FOR_IDWRAPPER(ClassName) \ + namespace std { \ + template <> \ + struct hash { \ + size_t operator()(ClassName x) const { \ + return hash_value(x); \ + } \ + }; \ } diff --git a/c10/util/LeftRight.h b/c10/util/LeftRight.h index bb93ec512c7..4081540c58a 100644 --- a/c10/util/LeftRight.h +++ b/c10/util/LeftRight.h @@ -1,30 +1,31 @@ +#include +#include #include #include #include #include -#include -#include namespace c10 { namespace detail { struct IncrementRAII final { -public: - explicit IncrementRAII(std::atomic *counter): _counter(counter) { - _counter->fetch_add(1); - } + public: + explicit IncrementRAII(std::atomic* counter) : _counter(counter) { + _counter->fetch_add(1); + } - ~IncrementRAII() { - _counter->fetch_sub(1); - } -private: - std::atomic *_counter; + ~IncrementRAII() { + _counter->fetch_sub(1); + } - C10_DISABLE_COPY_AND_ASSIGN(IncrementRAII); + private: + std::atomic* _counter; + + C10_DISABLE_COPY_AND_ASSIGN(IncrementRAII); }; -} +} // namespace detail // LeftRight wait-free readers synchronization primitive // https://hal.archives-ouvertes.fr/hal-01207881/document @@ -48,135 +49,142 @@ private: // template class LeftRight final { -public: - template - explicit LeftRight(const Args& ...args) - : _counters{{{0}, {0}}} - , _foregroundCounterIndex(0) - , _foregroundDataIndex(0) - , _data{{T{args...}, T{args...}}} - , _writeMutex() - {} + public: + template + explicit LeftRight(const Args&... args) + : _counters{{{0}, {0}}}, + _foregroundCounterIndex(0), + _foregroundDataIndex(0), + _data{{T{args...}, T{args...}}}, + _writeMutex() {} - // Copying and moving would not be threadsafe. - // Needs more thought and careful design to make that work. - LeftRight(const LeftRight&) = delete; - LeftRight(LeftRight&&) noexcept = delete; - LeftRight& operator=(const LeftRight&) = delete; - LeftRight& operator=(LeftRight&&) noexcept= delete; + // Copying and moving would not be threadsafe. + // Needs more thought and careful design to make that work. 
+ LeftRight(const LeftRight&) = delete; + LeftRight(LeftRight&&) noexcept = delete; + LeftRight& operator=(const LeftRight&) = delete; + LeftRight& operator=(LeftRight&&) noexcept = delete; - ~LeftRight() { - // wait until any potentially running writers are finished - { - std::unique_lock lock(_writeMutex); - } + ~LeftRight() { + // wait until any potentially running writers are finished + { std::unique_lock lock(_writeMutex); } - // wait until any potentially running readers are finished - while (_counters[0].load() != 0 || _counters[1].load() != 0) { - std::this_thread::yield(); - } + // wait until any potentially running readers are finished + while (_counters[0].load() != 0 || _counters[1].load() != 0) { + std::this_thread::yield(); } + } - template - auto read(F&& readFunc) const -> typename std::result_of::type { - detail::IncrementRAII _increment_counter(&_counters[_foregroundCounterIndex.load()]); + template + auto read(F&& readFunc) const -> typename std::result_of::type { + detail::IncrementRAII _increment_counter( + &_counters[_foregroundCounterIndex.load()]); - return readFunc(_data[_foregroundDataIndex.load()]); + return readFunc(_data[_foregroundDataIndex.load()]); + } + + // Throwing an exception in writeFunc is ok but causes the state to be either + // the old or the new state, depending on if the first or the second call to + // writeFunc threw. + template + auto write(F&& writeFunc) -> typename std::result_of::type { + std::unique_lock lock(_writeMutex); + + return _write(writeFunc); + } + + private: + template + auto _write(const F& writeFunc) -> typename std::result_of::type { + /* + * Assume, A is in background and B in foreground. In simplified terms, we + * want to do the following: + * 1. Write to A (old background) + * 2. Switch A/B + * 3. Write to B (new background) + * + * More detailed algorithm (explanations on why this is important are below + * in code): + * 1. Write to A + * 2. Switch A/B data pointers + * 3. Wait until A counter is zero + * 4. Switch A/B counters + * 5. Wait until B counter is zero + * 6. Write to B + */ + + auto localDataIndex = _foregroundDataIndex.load(); + + // 1. Write to A + _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex); + + // 2. Switch A/B data pointers + localDataIndex = localDataIndex ^ 1; + _foregroundDataIndex = localDataIndex; + + /* + * 3. Wait until A counter is zero + * + * In the previous write run, A was foreground and B was background. + * There was a time after switching _foregroundDataIndex (B to foreground) + * and before switching _foregroundCounterIndex, in which new readers could + * have read B but incremented A's counter. + * + * In this current run, we just switched _foregroundDataIndex (A back to + * foreground), but before writing to the new background B, we have to make + * sure A's counter was zero briefly, so all these old readers are gone. + */ + auto localCounterIndex = _foregroundCounterIndex.load(); + _waitForBackgroundCounterToBeZero(localCounterIndex); + + /* + * 4. Switch A/B counters + * + * Now that we know all readers on B are really gone, we can switch the + * counters and have new readers increment A's counter again, which is the + * correct counter since they're reading A. + */ + localCounterIndex = localCounterIndex ^ 1; + _foregroundCounterIndex = localCounterIndex; + + /* + * 5. Wait until B counter is zero + * + * This waits for all the readers on B that came in while both data and + * counter for B was in foreground, i.e. 
normal readers that happened + * outside of that brief gap between switching data and counter. + */ + _waitForBackgroundCounterToBeZero(localCounterIndex); + + // 6. Write to B + return _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex); + } + + template + auto _callWriteFuncOnBackgroundInstance( + const F& writeFunc, + uint8_t localDataIndex) -> typename std::result_of::type { + try { + return writeFunc(_data[localDataIndex ^ 1]); + } catch (...) { + // recover invariant by copying from the foreground instance + _data[localDataIndex ^ 1] = _data[localDataIndex]; + // rethrow + throw; } + } - // Throwing an exception in writeFunc is ok but causes the state to be either the old or the new state, - // depending on if the first or the second call to writeFunc threw. - template - auto write(F&& writeFunc) -> typename std::result_of::type { - std::unique_lock lock(_writeMutex); - - return _write(writeFunc); + void _waitForBackgroundCounterToBeZero(uint8_t counterIndex) { + while (_counters[counterIndex ^ 1].load() != 0) { + std::this_thread::yield(); } + } -private: - template - auto _write(const F& writeFunc) -> typename std::result_of::type { - /* - * Assume, A is in background and B in foreground. In simplified terms, we want to do the following: - * 1. Write to A (old background) - * 2. Switch A/B - * 3. Write to B (new background) - * - * More detailed algorithm (explanations on why this is important are below in code): - * 1. Write to A - * 2. Switch A/B data pointers - * 3. Wait until A counter is zero - * 4. Switch A/B counters - * 5. Wait until B counter is zero - * 6. Write to B - */ - - auto localDataIndex = _foregroundDataIndex.load(); - - // 1. Write to A - _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex); - - // 2. Switch A/B data pointers - localDataIndex = localDataIndex ^ 1; - _foregroundDataIndex = localDataIndex; - - /* - * 3. Wait until A counter is zero - * - * In the previous write run, A was foreground and B was background. - * There was a time after switching _foregroundDataIndex (B to foreground) and before switching _foregroundCounterIndex, - * in which new readers could have read B but incremented A's counter. - * - * In this current run, we just switched _foregroundDataIndex (A back to foreground), but before writing to - * the new background B, we have to make sure A's counter was zero briefly, so all these old readers are gone. - */ - auto localCounterIndex = _foregroundCounterIndex.load(); - _waitForBackgroundCounterToBeZero(localCounterIndex); - - /* - * 4. Switch A/B counters - * - * Now that we know all readers on B are really gone, we can switch the counters and have new readers - * increment A's counter again, which is the correct counter since they're reading A. - */ - localCounterIndex = localCounterIndex ^ 1; - _foregroundCounterIndex = localCounterIndex; - - /* - * 5. Wait until B counter is zero - * - * This waits for all the readers on B that came in while both data and counter for B was in foreground, - * i.e. normal readers that happened outside of that brief gap between switching data and counter. - */ - _waitForBackgroundCounterToBeZero(localCounterIndex); - - // 6. Write to B - return _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex); - } - - template - auto _callWriteFuncOnBackgroundInstance(const F& writeFunc, uint8_t localDataIndex) -> typename std::result_of::type { - try { - return writeFunc(_data[localDataIndex ^ 1]); - } catch (...) 
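// --- Illustrative usage sketch for c10::LeftRight (not part of this diff) ---
// A minimal example of the read()/write() API shown above; the include path
// c10/util/LeftRight.h is assumed from this file's location. Readers run
// wait-free against the foreground copy; writers are serialized and apply the
// functor to both copies, following steps 1-6 documented in _write().
#include <c10/util/LeftRight.h>
#include <iostream>
#include <vector>

int main() {
  // Both internal copies of the std::vector are value-constructed.
  c10::LeftRight<std::vector<int>> numbers;

  // write() takes the functor under the write mutex and invokes it once on
  // each copy (background first, then the other copy after the pointer swap).
  numbers.write([](std::vector<int>& v) { v.push_back(42); });

  // read() is wait-free: it only increments the counter of the current
  // foreground instance and reads from it; the functor sees a const reference.
  int front =
      numbers.read([](const std::vector<int>& v) { return v.front(); });
  std::cout << front << std::endl; // prints 42
  return 0;
}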
{ - // recover invariant by copying from the foreground instance - _data[localDataIndex ^ 1] = _data[localDataIndex]; - // rethrow - throw; - } - } - - void _waitForBackgroundCounterToBeZero(uint8_t counterIndex) { - while (_counters[counterIndex ^ 1].load() != 0) { - std::this_thread::yield(); - } - } - - mutable std::array, 2> _counters; - std::atomic _foregroundCounterIndex; - std::atomic _foregroundDataIndex; - std::array _data; - std::mutex _writeMutex; + mutable std::array, 2> _counters; + std::atomic _foregroundCounterIndex; + std::atomic _foregroundDataIndex; + std::array _data; + std::mutex _writeMutex; }; -} +} // namespace c10 diff --git a/c10/util/MathConstants.h b/c10/util/MathConstants.h index d90d5d7effe..e169293f8a8 100644 --- a/c10/util/MathConstants.h +++ b/c10/util/MathConstants.h @@ -9,22 +9,24 @@ namespace c10 { namespace detail { template C10_HOST_DEVICE inline constexpr T pi() { - return static_cast(3.14159265358979323846L); + return static_cast(3.14159265358979323846L); } -template<> +template <> C10_HOST_DEVICE inline constexpr BFloat16 pi() { - // According to https://en.wikipedia.org/wiki/Bfloat16_floating-point_format#Special_values pi is encoded as 4049 + // According to + // https://en.wikipedia.org/wiki/Bfloat16_floating-point_format#Special_values + // pi is encoded as 4049 return BFloat16(0x4049, BFloat16::from_bits()); } -template<> +template <> C10_HOST_DEVICE inline constexpr Half pi() { return Half(0x4248, Half::from_bits()); } } // namespace detail // TODO: Replace me with std::numbers::pi when C++20 is there -template +template constexpr T pi = c10::detail::pi(); } // namespace c10 diff --git a/c10/util/MaybeOwned.h b/c10/util/MaybeOwned.h index 49ca2186104..f4b48b613b0 100644 --- a/c10/util/MaybeOwned.h +++ b/c10/util/MaybeOwned.h @@ -68,20 +68,21 @@ class MaybeOwned final { }; /// Don't use this; use borrowed() instead. - explicit MaybeOwned(const owned_type& t) : isBorrowed_(true), borrow_(MaybeOwnedTraits::createBorrow(t)) {} + explicit MaybeOwned(const owned_type& t) + : isBorrowed_(true), borrow_(MaybeOwnedTraits::createBorrow(t)) {} /// Don't use this; use owned() instead. - explicit MaybeOwned(T&& t) noexcept(std::is_nothrow_move_constructible::value) - : isBorrowed_(false), own_(std::move(t)) {} + explicit MaybeOwned(T&& t) noexcept( + std::is_nothrow_move_constructible::value) + : isBorrowed_(false), own_(std::move(t)) {} /// Don't use this; use owned() instead. template explicit MaybeOwned(in_place_t, Args&&... args) - : isBorrowed_(false) - , own_(std::forward(args)...) {} + : isBorrowed_(false), own_(std::forward(args)...) {} public: - explicit MaybeOwned(): isBorrowed_(true), borrow_() {} + explicit MaybeOwned() : isBorrowed_(true), borrow_() {} // Copying a borrow yields another borrow of the original, as with a // T*. 
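// --- Illustrative sketch for c10::pi (not part of this diff) ---
// The variable template above yields pi in the requested type; the BFloat16
// and Half specializations are built from the raw bit patterns 0x4049 and
// 0x4248 mentioned in the comments. Include paths are assumed from this
// repository's layout.
#include <c10/util/Half.h>
#include <c10/util/MathConstants.h>
#include <iostream>

int main() {
  std::cout << c10::pi<double> << std::endl;                        // 3.14159...
  std::cout << static_cast<float>(c10::pi<c10::Half>) << std::endl; // ~3.140625
  return 0;
}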
Copying an owned T yields another owned T for safety: no @@ -120,8 +121,9 @@ class MaybeOwned final { return *this; } - MaybeOwned(MaybeOwned&& rhs) noexcept(std::is_nothrow_move_constructible::value) - : isBorrowed_(rhs.isBorrowed_) { + MaybeOwned(MaybeOwned&& rhs) noexcept( + std::is_nothrow_move_constructible::value) + : isBorrowed_(rhs.isBorrowed_) { if (C10_LIKELY(rhs.isBorrowed_)) { MaybeOwnedTraits::assignBorrow(borrow_, rhs.borrow_); } else { @@ -129,15 +131,16 @@ class MaybeOwned final { } } - MaybeOwned& operator=(MaybeOwned&& rhs) noexcept(std::is_nothrow_move_assignable::value) { + MaybeOwned& operator=(MaybeOwned&& rhs) noexcept( + std::is_nothrow_move_assignable::value) { if (this == &rhs) { return *this; } if (C10_UNLIKELY(!isBorrowed_)) { if (rhs.isBorrowed_) { - own_.~T(); - MaybeOwnedTraits::assignBorrow(borrow_, rhs.borrow_); - isBorrowed_ = true; + own_.~T(); + MaybeOwnedTraits::assignBorrow(borrow_, rhs.borrow_); + isBorrowed_ = true; } else { own_ = std::move(rhs.own_); } @@ -158,7 +161,8 @@ class MaybeOwned final { return MaybeOwned(t); } - static MaybeOwned owned(T&& t) noexcept(std::is_nothrow_move_constructible::value) { + static MaybeOwned owned(T&& t) noexcept( + std::is_nothrow_move_constructible::value) { return MaybeOwned(std::move(t)); } @@ -175,22 +179,24 @@ class MaybeOwned final { } } - const T& operator*() const & { + const T& operator*() const& { if (isBorrowed_) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(MaybeOwnedTraits::debugBorrowIsValid(borrow_)); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + MaybeOwnedTraits::debugBorrowIsValid(borrow_)); } return C10_LIKELY(isBorrowed_) - ? MaybeOwnedTraits::referenceFromBorrow(borrow_) - : own_; + ? MaybeOwnedTraits::referenceFromBorrow(borrow_) + : own_; } const T* operator->() const { if (isBorrowed_) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(MaybeOwnedTraits::debugBorrowIsValid(borrow_)); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + MaybeOwnedTraits::debugBorrowIsValid(borrow_)); } return C10_LIKELY(isBorrowed_) - ? MaybeOwnedTraits::pointerFromBorrow(borrow_) - : &own_; + ? MaybeOwnedTraits::pointerFromBorrow(borrow_) + : &own_; } // If borrowed, copy the underlying T. If owned, move from @@ -199,7 +205,8 @@ class MaybeOwned final { // T. T operator*() && { if (isBorrowed_) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(MaybeOwnedTraits::debugBorrowIsValid(borrow_)); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + MaybeOwnedTraits::debugBorrowIsValid(borrow_)); return MaybeOwnedTraits::referenceFromBorrow(borrow_); } else { return std::move(own_); @@ -207,5 +214,4 @@ class MaybeOwned final { } }; - } // namespace c10 diff --git a/c10/util/Metaprogramming.h b/c10/util/Metaprogramming.h index 545240c2745..30f6d7c590a 100644 --- a/c10/util/Metaprogramming.h +++ b/c10/util/Metaprogramming.h @@ -1,25 +1,30 @@ #pragma once -#include +#include +#include #include #include -#include -#include +#include -namespace c10 { namespace guts { +namespace c10 { +namespace guts { /** * Access information about result type or arguments from a function type. 
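// --- Illustrative sketch for c10::MaybeOwned (not part of this diff) ---
// MaybeOwned<T> requires a MaybeOwnedTraits<T> specialization. The pointer
// based specialization below is written purely for illustration (PyTorch
// itself provides traits for the types it actually borrows, e.g. Tensor);
// check c10/util/MaybeOwned.h for the exact member set it expects.
#include <c10/util/MaybeOwned.h>
#include <iostream>
#include <string>

namespace c10 {
template <>
struct MaybeOwnedTraits<std::string> {
  using owned_type = std::string;
  using borrow_type = const std::string*;

  static borrow_type createBorrow(const owned_type& from) {
    return &from;
  }
  static void assignBorrow(borrow_type& lhs, borrow_type rhs) {
    lhs = rhs;
  }
  static void destroyBorrow(borrow_type& /*borrow*/) {}
  static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
    return *borrow;
  }
  static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
    return borrow;
  }
  static bool debugBorrowIsValid(const borrow_type& borrow) {
    return borrow != nullptr;
  }
};
} // namespace c10

int main() {
  std::string hello = "hello";

  // borrowed() stores only a handle; `hello` must outlive `b`.
  auto b = c10::MaybeOwned<std::string>::borrowed(hello);

  // owned() moves the value into the MaybeOwned instance itself.
  auto o = c10::MaybeOwned<std::string>::owned(std::string("world"));

  // Both flavors are read through the same operators.
  std::cout << *b << " " << o->size() << std::endl; // "hello 5"
  return 0;
}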
* Example: * using A = function_traits::return_type // A == int - * using A = function_traits::parameter_types::tuple_type // A == tuple + * using A = function_traits::parameter_types::tuple_type + * // A == tuple */ -template struct function_traits { - static_assert(!std::is_same::value, "In function_traits, Func must be a plain function type."); +template +struct function_traits { + static_assert( + !std::is_same::value, + "In function_traits, Func must be a plain function type."); }; -template -struct function_traits { - using func_type = Result (Args...); +template +struct function_traits { + using func_type = Result(Args...); using return_type = Result; using parameter_types = typelist::typelist; static constexpr auto number_of_parameters = sizeof...(Args); @@ -33,7 +38,8 @@ struct function_traits { template struct infer_function_traits { - using type = function_traits>; + using type = function_traits< + c10::guts::detail::strip_class_t>; }; template @@ -42,7 +48,7 @@ struct infer_function_traits { }; template -struct infer_function_traits { +struct infer_function_traits { using type = function_traits; }; @@ -56,10 +62,14 @@ using infer_function_traits_t = typename infer_function_traits::type; * Example: * bool f(int, int); * - * infer_function_traits_t == make_function_traits_t> + * infer_function_traits_t == make_function_traits_t> */ -template struct make_function_traits { - static_assert(false_t::value, "In guts::make_function_traits, the ArgList argument must be typelist<...>."); +template +struct make_function_traits { + static_assert( + false_t::value, + "In guts::make_function_traits, the ArgList argument must be typelist<...>."); }; template @@ -68,7 +78,8 @@ struct make_function_traits> { }; template -using make_function_traits_t = typename make_function_traits::type; +using make_function_traits_t = + typename make_function_traits::type; /** * Use extract_arg_by_filtered_index to return the i-th argument whose @@ -77,80 +88,156 @@ using make_function_traits_t = typename make_function_traits::t * Example: * std::string arg1 = "Hello"; * std::string arg2 = "World"; - * std::string&& result = extract_arg_by_filtered_index(0, arg1, 2.0, std::move(arg2)); + * std::string&& result = extract_arg_by_filtered_index(0, + * arg1, 2.0, std::move(arg2)); * - * Warning: Taking the result by rvalue reference can cause segfaults because ownership will not be passed on - * from the original reference. The original reference dies after the expression and the resulting + * Warning: Taking the result by rvalue reference can cause segfaults because + * ownership will not be passed on from the original reference. The original + * reference dies after the expression and the resulting */ namespace detail { -template
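// --- Illustrative sketch for c10::guts::function_traits (not part of this diff) ---
// Compile-time checks of the members shown above (return_type,
// parameter_types, number_of_parameters) plus infer_function_traits_t and
// make_function_traits_t; include paths are assumed from this repository.
#include <c10/util/Metaprogramming.h>
#include <c10/util/TypeList.h>
#include <type_traits>

bool f(int, double);

using traits = c10::guts::function_traits<bool(int, double)>;
static_assert(std::is_same<traits::return_type, bool>::value, "");
static_assert(traits::number_of_parameters == 2, "");
static_assert(
    std::is_same<
        traits::parameter_types,
        c10::guts::typelist::typelist<int, double>>::value,
    "");

// The same traits can be inferred from a function pointer type...
static_assert(
    std::is_same<
        c10::guts::infer_function_traits_t<decltype(&f)>::return_type,
        bool>::value,
    "");

// ...or assembled from a return type and a typelist of parameter types.
static_assert(
    std::is_same<
        c10::guts::make_function_traits_t<
            bool,
            c10::guts::typelist::typelist<int, double>>,
        traits>::value,
    "");

int main() {
  return 0;
}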