diff --git a/.circleci/cimodel/data/binary_build_data.py b/.circleci/cimodel/data/binary_build_data.py index 914645dfab9..ead0240fb80 100644 --- a/.circleci/cimodel/data/binary_build_data.py +++ b/.circleci/cimodel/data/binary_build_data.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python3 - """ This module models the tree of configuration variants for "smoketest" builds. diff --git a/.circleci/cimodel/data/binary_build_definitions.py b/.circleci/cimodel/data/binary_build_definitions.py index 6e848e6ddc1..a2b086dc920 100644 --- a/.circleci/cimodel/data/binary_build_definitions.py +++ b/.circleci/cimodel/data/binary_build_definitions.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python3 - from collections import OrderedDict import cimodel.data.binary_build_data as binary_build_data diff --git a/.circleci/cimodel/data/caffe2_build_data.py b/.circleci/cimodel/data/caffe2_build_data.py index 705feb41436..4a9c23e708c 100644 --- a/.circleci/cimodel/data/caffe2_build_data.py +++ b/.circleci/cimodel/data/caffe2_build_data.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python3 - from cimodel.lib.conf_tree import ConfigNode, X, XImportant from cimodel.lib.conf_tree import Ver diff --git a/.circleci/cimodel/data/caffe2_build_definitions.py b/.circleci/cimodel/data/caffe2_build_definitions.py index c58492a8e1a..1668094ec99 100644 --- a/.circleci/cimodel/data/caffe2_build_definitions.py +++ b/.circleci/cimodel/data/caffe2_build_definitions.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python3 - from collections import OrderedDict import cimodel.data.dimensions as dimensions diff --git a/.circleci/cimodel/data/dimensions.py b/.circleci/cimodel/data/dimensions.py index 9db03aa3c7d..810bf830f4b 100644 --- a/.circleci/cimodel/data/dimensions.py +++ b/.circleci/cimodel/data/dimensions.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python3 - - PHASES = ["build", "test"] CUDA_VERSIONS = [ diff --git a/.circleci/cimodel/data/pytorch_build_data.py b/.circleci/cimodel/data/pytorch_build_data.py index 864f175bb0f..1d3ee8f0517 100644 --- a/.circleci/cimodel/data/pytorch_build_data.py +++ b/.circleci/cimodel/data/pytorch_build_data.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python3 - from cimodel.lib.conf_tree import ConfigNode, X, XImportant diff --git a/.circleci/cimodel/data/pytorch_build_definitions.py b/.circleci/cimodel/data/pytorch_build_definitions.py index abf0b119561..9455170f485 100644 --- a/.circleci/cimodel/data/pytorch_build_definitions.py +++ b/.circleci/cimodel/data/pytorch_build_definitions.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python3 - from collections import OrderedDict from cimodel.data.pytorch_build_data import TopLevelNode, CONFIG_TREE_DATA diff --git a/.circleci/cimodel/lib/conf_tree.py b/.circleci/cimodel/lib/conf_tree.py index 54c8b15aed1..db2a30f29c2 100644 --- a/.circleci/cimodel/lib/conf_tree.py +++ b/.circleci/cimodel/lib/conf_tree.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python3 - - from dataclasses import dataclass, field from typing import Optional, Dict diff --git a/.circleci/cimodel/lib/miniutils.py b/.circleci/cimodel/lib/miniutils.py index b10fc448dff..7cc3e5ac525 100644 --- a/.circleci/cimodel/lib/miniutils.py +++ b/.circleci/cimodel/lib/miniutils.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python3 - - def quote(s): return sandwich('"', s) diff --git a/.circleci/cimodel/lib/miniyaml.py b/.circleci/cimodel/lib/miniyaml.py index ccd888ab2b0..94c3430b6f1 100644 --- a/.circleci/cimodel/lib/miniyaml.py +++ b/.circleci/cimodel/lib/miniyaml.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python3 - - from collections import OrderedDict diff --git a/.circleci/cimodel/lib/visualization.py b/.circleci/cimodel/lib/visualization.py index c583fe4d4a3..10addb54b1c 100644 --- a/.circleci/cimodel/lib/visualization.py +++ b/.circleci/cimodel/lib/visualization.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python3 - """ This module encapsulates dependencies on pygraphviz """ diff --git a/.circleci/scripts/cpp_doc_push_script.sh b/.circleci/scripts/cpp_doc_push_script.sh index c051080ca5e..797914d4ba0 100755 --- a/.circleci/scripts/cpp_doc_push_script.sh +++ b/.circleci/scripts/cpp_doc_push_script.sh @@ -53,7 +53,7 @@ sudo apt-get -y install doxygen # Generate ATen files pushd "${pt_checkout}" pip install -r requirements.txt -time GEN_TO_SOURCE=1 python aten/src/ATen/gen.py \ +time python aten/src/ATen/gen.py \ -s aten/src/ATen \ -d build/aten/src/ATen \ aten/src/ATen/Declarations.cwrap \ diff --git a/.flake8 b/.flake8 index d5d0a454467..37455ad18c9 100644 --- a/.flake8 +++ b/.flake8 @@ -5,10 +5,8 @@ max-line-length = 120 # E501 is not flexible enough, we're using B950 instead ignore = E203,E305,E402,E501,E721,E741,F403,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303, - # EXE001 is skipped for now because some files use shebang to determine Python version. - EXE001, # these ignores are from flake8-bugbear; please fix! B007,B008, # these ignores are from flake8-comprehensions; please fix! C400,C401,C402,C403,C404,C405,C407,C411, -exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,torch/lib/include,torch/lib/tmp_install,build,torch/include,*.pyi +exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,torch/lib/include,torch/lib/tmp_install,build,torch/include,*.pyi,.git diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 548f16f9c02..ebabe66c553 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -440,6 +440,38 @@ ccache -F 0 # deploy (and add to ~/.bashrc for later) export PATH="/usr/lib/ccache:$PATH" ``` + +It is also possible to install `ccache` via `conda` by installing it from the +community-maintained `conda-forge` channel. Here is how to set up `ccache` this +way: + +```bash +# install ccache +conda install -c conda-forge ccache + +# set up ccache compiler symlinks +mkdir ~/ccache +mkdir ~/ccache/lib +mkdir ~/ccache/cuda +ln -s $CONDA_PREFIX/bin/ccache ~/ccache/lib/cc +ln -s $CONDA_PREFIX/bin/ccache ~/ccache/lib/c++ +ln -s $CONDA_PREFIX/bin/ccache ~/ccache/lib/gcc +ln -s $CONDA_PREFIX/bin/ccache ~/ccache/lib/g++ +ln -s $CONDA_PREFIX/bin/ccache ~/ccache/cuda/nvcc + +# update PATH to reflect symlink locations, consider +# adding this to your .bashrc +export PATH=~/ccache/lib:$PATH +export CUDA_NVCC_EXECUTABLE=~/ccache/cuda/nvcc + +# increase ccache cache size to 25 GiB +ccache -M 25Gi +``` + +To check this is working, do two clean builds of pytorch in a row. The second +build should be substantially and noticeably faster than the first build. + + #### Use a faster linker If you are editing a single file and rebuilding in a tight loop, the time spent linking will dominate. The system linker available in most Linux distributions diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h index 84b2fb0a6d9..55d15d09d55 100644 --- a/aten/src/ATen/Dispatch.h +++ b/aten/src/ATen/Dispatch.h @@ -146,10 +146,24 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} switch (_st) { \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::ComplexDouble, std::complex, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::ComplexFloat, std::complex, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::ComplexHalf, std::complex, __VA_ARGS__) \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ + } \ + }() + + #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(SCALARTYPE, TYPE, NAME, ...) \ + [&] { \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op */ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::ComplexDouble, std::complex, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::ComplexFloat, std::complex, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(SCALARTYPE, decltype(c10::impl::ScalarTypeToCPPType::t), __VA_ARGS__) \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ diff --git a/aten/src/ATen/NumericUtils.h b/aten/src/ATen/NumericUtils.h index 75d186f4c11..55eed44b907 100644 --- a/aten/src/ATen/NumericUtils.h +++ b/aten/src/ATen/NumericUtils.h @@ -7,6 +7,7 @@ #include #include #include +#include #include namespace at { @@ -31,6 +32,12 @@ inline C10_HOST_DEVICE bool _isnan(T val) { #endif } +template ::value, int>::type = 0> +inline bool _isnan(T val) { + return std::isnan(std::real(val)) || std::isnan(std::imag(val)); +} + inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) { return at::_isnan(float(val)); } diff --git a/aten/src/ATen/core/ATenDispatch.h b/aten/src/ATen/core/ATenDispatch.h index 69fa23cac75..37498d84cc1 100644 --- a/aten/src/ATen/core/ATenDispatch.h +++ b/aten/src/ATen/core/ATenDispatch.h @@ -40,7 +40,8 @@ namespace impl { // question is whether or not we have access to all the relevant TLS at this // point. static inline TensorTypeId dispatchTypeId(TensorTypeSet ts) { - return (ts - c10::impl::tls_excluded_tensor_type_set()).highestPriorityTypeId(); + c10::impl::LocalTensorTypeSet local = c10::impl::tls_local_tensor_type_set(); + return ((ts | local.included_) - local.excluded_).highestPriorityTypeId(); } } diff --git a/aten/src/ATen/core/LegacyTypeDispatch.h b/aten/src/ATen/core/LegacyTypeDispatch.h index e034e1b2b36..de47fff5c6d 100644 --- a/aten/src/ATen/core/LegacyTypeDispatch.h +++ b/aten/src/ATen/core/LegacyTypeDispatch.h @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -47,22 +48,19 @@ class CAFFE2_API LegacyTypeDispatch { CAFFE2_API LegacyTypeDispatch& globalLegacyTypeDispatch(); -// A RAII, thread local (!) guard that has the following effect: -// -// Upon construction: sets NonVariableTypeMode_enabled for the current thread to -// control whether we are in non-Variable-type mode. -// -// Upon destruction: sets NonVariableTypeMode_enabled back to the original value. +// A RAII, thread local (!) guard that will disable dispatch to variable +// handler. // // See NOTE [ Treating Variables as non-Variables in type dispatch ] for details. struct CAFFE2_API AutoNonVariableTypeMode { - AutoNonVariableTypeMode(bool enabled) : prev_mode(NonVariableTypeMode::is_enabled()) { - NonVariableTypeMode::set_enabled(enabled); + // NB: The enabled parameter must ALWAYS be black, as Henry Ford used to say. + // TODO: Eliminate this parameter entirely + AutoNonVariableTypeMode(bool enabled = true) : + guard_(TensorTypeId::VariableTensorId) { + + TORCH_INTERNAL_ASSERT(enabled); } - ~AutoNonVariableTypeMode() { - NonVariableTypeMode::set_enabled(prev_mode); - } - bool prev_mode; + c10::impl::ExcludeTensorTypeIdGuard guard_; }; } // namespace at diff --git a/aten/src/ATen/cpu/vec256/vec256.h b/aten/src/ATen/cpu/vec256/vec256.h index b7a4abbf0b6..db89c243276 100644 --- a/aten/src/ATen/cpu/vec256/vec256.h +++ b/aten/src/ATen/cpu/vec256/vec256.h @@ -7,6 +7,8 @@ #include #include #include +#include +#include #include #include diff --git a/aten/src/ATen/cpu/vec256/vec256_base.h b/aten/src/ATen/cpu/vec256/vec256_base.h index 3b9eb5a8f4c..81da4510d5f 100644 --- a/aten/src/ATen/cpu/vec256/vec256_base.h +++ b/aten/src/ATen/cpu/vec256/vec256_base.h @@ -12,6 +12,7 @@ #include #include #include +#include #if defined(__GNUC__) #define __at_align32__ __attribute__((aligned(32))) @@ -169,12 +170,19 @@ public: } return ret; } - template ::value, int>::type = 0> + Vec256 map(T (*f)(const T &)) const { + Vec256 ret; + for (int64_t i = 0; i != size(); i++) { + ret[i] = f(values[i]); + } + return ret; + } + template ::value && !std::is_complex_t::value, int>::type = 0> Vec256 abs() const { - // non_float_t is for SFINAE and clarity. Make sure it is not changed. - static_assert(std::is_same::value, "non_float_t must be T"); - return map([](T x) -> T { return x < static_cast(0) ? -x : x; }); + // other_t is for SFINAE and clarity. Make sure it is not changed. + static_assert(std::is_same::value, "other_t must be T"); + return map([](T x) -> T { return x < static_cast(0) ? -x : x; }); } template ::value, int>::type = 0> @@ -185,6 +193,26 @@ public: // 0.0) properly. return map(std::abs); } + template ::value, int>::type = 0> + Vec256 abs() const { + // complex_t is for SFINAE and clarity. Make sure it is not changed. + static_assert(std::is_same::value, "complex_t must be T"); + // Specifically map() does not perform the type conversion needed by abs. + return map([](T x) { return (T)std::abs(x); }); + } + Vec256 angle() const { + return *this; + } + Vec256 real() const { + return *this; + } + Vec256 imag() const { + return *this; + } + Vec256 conj() const { + return *this; + } Vec256 acos() const { return map(std::acos); } @@ -232,7 +260,7 @@ public: return map(std::log2); } Vec256 ceil() const { - return map(std::ceil); + return map(at::native::ceil_impl); } Vec256 cos() const { return map(std::cos); @@ -241,7 +269,7 @@ public: return map(std::cosh); } Vec256 floor() const { - return map(std::floor); + return map(at::native::floor_impl); } Vec256 neg() const { // NB: the trailing return type is needed because we need to coerce the @@ -251,7 +279,7 @@ public: } Vec256 round() const { // We do not use std::round because we would like to round midway numbers to the nearest even integer. - return map(std::nearbyint); + return map(at::native::round_impl); } Vec256 sin() const { return map(std::sin); @@ -266,7 +294,7 @@ public: return map(std::tanh); } Vec256 trunc() const { - return map(std::trunc); + return map(at::native::trunc_impl); } Vec256 lgamma() const { return map(std::lgamma); @@ -278,7 +306,7 @@ public: return map([](T x) { return (T)(1) / x; }); } Vec256 rsqrt() const { - return map([](T x) { return 1 / std::sqrt(x); }); + return map([](T x) { return (T)1 / std::sqrt(x); }); } Vec256 pow(const Vec256 &exp) const { Vec256 ret; @@ -352,7 +380,9 @@ template Vec256 inline operator||( // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if // either input is a NaN. -template Vec256 inline maximum(const Vec256 &a, const Vec256 &b) { +template ::value, int>::type = 0> +Vec256 inline maximum(const Vec256 &a, const Vec256 &b) { Vec256 c = Vec256(); for (int i = 0; i != Vec256::size(); i++) { c[i] = (a[i] > b[i]) ? a[i] : b[i]; @@ -366,6 +396,22 @@ template Vec256 inline maximum(const Vec256 &a, const Vec256 return c; } +template ::value, int>::type = 0> +Vec256 inline maximum(const Vec256 &a, const Vec256 &b) { + Vec256 c = Vec256(); + for (int i = 0; i != Vec256::size(); i++) { + c[i] = (std::abs(a[i]) > std::abs(b[i])) ? a[i] : b[i]; + if (_isnan(a[i])) { + // If either input is NaN, propagate a NaN. + // NOTE: The case where b[i] was NaN is handled correctly by the naive + // ternary operator above. + c[i] = a[i]; + } + } + return c; +} + template inline T maximum(const T& a, const T& b) { T c = (a > b) ? a : b; @@ -377,7 +423,9 @@ inline T maximum(const T& a, const T& b) { // Implements the IEEE 754 201X `minimum` operation, which propagates NaN if // either input is a NaN. -template Vec256 inline minimum(const Vec256 &a, const Vec256 &b) { +template ::value, int>::type = 0> +Vec256 inline minimum(const Vec256 &a, const Vec256 &b) { Vec256 c = Vec256(); for (int i = 0; i != Vec256::size(); i++) { c[i] = (a[i] < b[i]) ? a[i] : b[i]; @@ -391,6 +439,22 @@ template Vec256 inline minimum(const Vec256 &a, const Vec256 return c; } +template ::value, int>::type = 0> +Vec256 inline minimum(const Vec256 &a, const Vec256 &b) { + Vec256 c = Vec256(); + for (int i = 0; i != Vec256::size(); i++) { + c[i] = (std::abs(a[i]) < std::abs(b[i])) ? a[i] : b[i]; + if (_isnan(a[i])) { + // If either input is NaN, propagate a NaN. + // NOTE: The case where b[i] was NaN is handled correctly by the naive + // ternary operator above. + c[i] = a[i]; + } + } + return c; +} + template inline T minimum(const T& a, const T& b) { T c = (a < b) ? a : b; @@ -401,7 +465,9 @@ inline T minimum(const T& a, const T& b) { } // To save BC, it will not propagate NaN based on IEEE 754 201X -template Vec256 inline clamp(const Vec256 &a, const Vec256 &min_vec, const Vec256 &max_vec) { +template ::value, int>::type = 0> +Vec256 inline clamp(const Vec256 &a, const Vec256 &min_vec, const Vec256 &max_vec) { Vec256 c = Vec256(); for (int i = 0; i != Vec256::size(); i++) { c[i] = a[i] < min_vec[i] ? min_vec[i] : (a[i] > max_vec[i] ? max_vec[i] : a[i]); @@ -409,7 +475,19 @@ template Vec256 inline clamp(const Vec256 &a, const Vec256 &m return c; } -template Vec256 inline clamp_max(const Vec256 &a, const Vec256 &max_vec) { +template ::value, int>::type = 0> +Vec256 inline clamp(const Vec256 &a, const Vec256 &min_vec, const Vec256 &max_vec) { + Vec256 c = Vec256(); + for (int i = 0; i != Vec256::size(); i++) { + c[i] = std::abs(a[i]) < std::abs(min_vec[i]) ? min_vec[i] : (std::abs(a[i]) > std::abs(max_vec[i]) ? max_vec[i] : a[i]); + } + return c; +} + +template ::value, int>::type = 0> +Vec256 inline clamp_max(const Vec256 &a, const Vec256 &max_vec) { Vec256 c = Vec256(); for (int i = 0; i != Vec256::size(); i++) { c[i] = a[i] > max_vec[i] ? max_vec[i] : a[i]; @@ -417,7 +495,19 @@ template Vec256 inline clamp_max(const Vec256 &a, const Vec256 Vec256 inline clamp_min(const Vec256 &a, const Vec256 &min_vec) { +template ::value, int>::type = 0> +Vec256 inline clamp_max(const Vec256 &a, const Vec256 &max_vec) { + Vec256 c = Vec256(); + for (int i = 0; i != Vec256::size(); i++) { + c[i] = std::abs(a[i]) > std::abs(max_vec[i]) ? max_vec[i] : a[i]; + } + return c; +} + +template ::value, int>::type = 0> +Vec256 inline clamp_min(const Vec256 &a, const Vec256 &min_vec) { Vec256 c = Vec256(); for (int i = 0; i != Vec256::size(); i++) { c[i] = a[i] < min_vec[i] ? min_vec[i] : a[i]; @@ -425,6 +515,16 @@ template Vec256 inline clamp_min(const Vec256 &a, const Vec256::value, int>::type = 0> +Vec256 inline clamp_min(const Vec256 &a, const Vec256 &min_vec) { + Vec256 c = Vec256(); + for (int i = 0; i != Vec256::size(); i++) { + c[i] = std::abs(a[i]) < std::abs(min_vec[i]) ? min_vec[i] : a[i]; + } + return c; +} + #define DEFINE_BITWISE_OP(op) \ template \ Vec256 inline operator op(const Vec256 &a, const Vec256 &b) { \ diff --git a/aten/src/ATen/cpu/vec256/vec256_complex_double.h b/aten/src/ATen/cpu/vec256/vec256_complex_double.h new file mode 100644 index 00000000000..ce4162a4964 --- /dev/null +++ b/aten/src/ATen/cpu/vec256/vec256_complex_double.h @@ -0,0 +1,369 @@ +#pragma once + +#include +#include +#if defined(__AVX__) && !defined(_MSC_VER) +#include +#endif + +namespace at { +namespace vec256 { +// See Note [Acceptable use of anonymous namespace in header] +namespace { + +#if defined(__AVX__) && !defined(_MSC_VER) + +template <> class Vec256> { +private: + __m256d values; +public: + using value_type = std::complex; + static constexpr int size() { + return 2; + } + Vec256() {} + Vec256(__m256d v) : values(v) {} + Vec256(std::complex val) { + double real_value = std::real(val); + double imag_value = std::imag(val); + values = _mm256_setr_pd(real_value, imag_value, + real_value, imag_value); + } + Vec256(std::complex val1, std::complex val2) { + values = _mm256_setr_pd(std::real(val1), std::imag(val1), + std::real(val2), std::imag(val2)); + } + operator __m256d() const { + return values; + } + template + static Vec256> blend(const Vec256>& a, const Vec256>& b) { + // convert std::complex index mask to V index mask: xy -> xxyy + switch (mask) { + case 0: + return a; + case 1: + return _mm256_blend_pd(a.values, b.values, 0x03); + case 2: + return _mm256_blend_pd(a.values, b.values, 0x0c); + } + return b; + } + static Vec256> blendv(const Vec256>& a, const Vec256>& b, + const Vec256>& mask) { + // convert std::complex index mask to V index mask: xy -> xxyy + auto mask_ = _mm256_unpacklo_pd(mask.values, mask.values); + return _mm256_blendv_pd(a.values, b.values, mask_); + + } + static Vec256> arange(std::complex base = 0., std::complex step = 1.) { + return Vec256>(base, + base + step); + } + static Vec256> set(const Vec256>& a, const Vec256>& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + } + return b; + } + static Vec256> loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return _mm256_loadu_pd(reinterpret_cast(ptr)); + + __at_align32__ double tmp_values[2*size()]; + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(std::complex)); + return _mm256_load_pd(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm256_storeu_pd(reinterpret_cast(ptr), values); + } else if (count > 0) { + double tmp_values[2*size()]; + _mm256_storeu_pd(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(std::complex)); + } + } + const std::complex& operator[](int idx) const = delete; + std::complex& operator[](int idx) = delete; + Vec256> map(std::complex (*f)(const std::complex &)) const { + __at_align32__ std::complex tmp[size()]; + store(tmp); + for (int i = 0; i < size(); i++) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + __m256d abs_2_() const { + auto val_2 = _mm256_mul_pd(values, values); // a*a b*b + return _mm256_hadd_pd(val_2, val_2); // a*a+b*b a*a+b*b + } + __m256d abs_() const { + return _mm256_sqrt_pd(abs_2_()); // abs abs + } + Vec256> abs() const { + const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0xFFFFFFFFFFFFFFFF, 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, 0x0000000000000000)); + return _mm256_and_pd(abs_(), real_mask); // abs 0 + } + __m256d angle_() const { + //angle = atan2(b/a) + auto b_a = _mm256_permute_pd(values, 0x05); // b a + return Sleef_atan2d4_u10(values, b_a); // 90-angle angle + } + Vec256> angle() const { + const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0xFFFFFFFFFFFFFFFF, 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, 0x0000000000000000)); + auto angle = _mm256_permute_pd(angle_(), 0x05); // angle 90-angle + return _mm256_and_pd(angle, real_mask); // angle 0 + } + __m256d real_() const { + const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0xFFFFFFFFFFFFFFFF, 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, 0x0000000000000000)); + return _mm256_and_pd(values, real_mask); + } + Vec256> real() const { + return real_(); + } + __m256d imag_() const { + const __m256d imag_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0x0000000000000000, 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, 0xFFFFFFFFFFFFFFFF)); + return _mm256_and_pd(values, imag_mask); + } + Vec256> imag() const { + return _mm256_permute_pd(imag_(), 0x05); //b a + } + __m256d conj_() const { + const __m256d conj_mask = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + return _mm256_mul_pd(values, conj_mask); //a -b + } + Vec256> conj() const { + return conj_(); + } + Vec256> acos() const { + return map(std::acos); + } + Vec256> asin() const { + return map(std::asin); + } + Vec256> atan() const { + return map(std::atan); + } + Vec256> atan2(const Vec256> &b) const { + AT_ERROR("not supported for complex numbers"); + } + Vec256> erf() const { + AT_ERROR("not supported for complex numbers"); + } + Vec256> erfc() const { + AT_ERROR("not supported for complex numbers"); + } + Vec256> exp() const { + return map(std::exp); + } + Vec256> expm1() const { + AT_ERROR("not supported for complex numbers"); + } + Vec256> log() const { + return map(std::log); + } + Vec256> log2() const { + AT_ERROR("not supported for complex numbers"); + } + Vec256> log10() const { + return map(std::log10); + } + Vec256> log1p() const { + AT_ERROR("not supported for complex numbers"); + } + Vec256> sin() const { + return map(std::sin); + } + Vec256> sinh() const { + return map(std::sinh); + } + Vec256> cos() const { + return map(std::cos); + } + Vec256> cosh() const { + return map(std::cosh); + } + Vec256> ceil() const { + return _mm256_ceil_pd(values); + } + Vec256> floor() const { + return _mm256_floor_pd(values); + } + Vec256> neg() const { + auto zero = _mm256_setzero_pd(); + return _mm256_sub_pd(zero, values); + } + Vec256> round() const { + return _mm256_round_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vec256> tan() const { + return map(std::tan); + } + Vec256> tanh() const { + return map(std::tanh); + } + Vec256> trunc() const { + return _mm256_round_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vec256> sqrt() const { + return map(std::sqrt); + } + Vec256> reciprocal() const; + Vec256> rsqrt() const { + return map([](const std::complex &x) { return (std::complex)(1)/std::sqrt(x); }); + } + Vec256> pow(const Vec256> &exp) const { + AT_ERROR("not supported for complex numbers"); + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vec256> operator==(const Vec256>& other) const { + return _mm256_cmp_pd(values, other.values, _CMP_EQ_OQ); + } + Vec256> operator!=(const Vec256>& other) const { + return _mm256_cmp_pd(values, other.values, _CMP_NEQ_OQ); + } + Vec256> operator<(const Vec256>& other) const { + AT_ERROR("not supported for complex numbers"); + } + Vec256> operator<=(const Vec256>& other) const { + AT_ERROR("not supported for complex numbers"); + } + Vec256> operator>(const Vec256>& other) const { + AT_ERROR("not supported for complex numbers"); + } + Vec256> operator>=(const Vec256>& other) const { + AT_ERROR("not supported for complex numbers"); + } +}; + +template <> Vec256> inline operator+(const Vec256> &a, const Vec256> &b) { + return _mm256_add_pd(a, b); +} + +template <> Vec256> inline operator-(const Vec256> &a, const Vec256> &b) { + return _mm256_sub_pd(a, b); +} + +template <> Vec256> inline operator*(const Vec256> &a, const Vec256> &b) { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i + const __m256d neg = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + auto ac_bd = _mm256_mul_pd(a, b); //ac bd + + auto d_c = _mm256_permute_pd(b, 0x05); //d c + d_c = _mm256_mul_pd(neg, d_c); //d -c + auto ad_bc = _mm256_mul_pd(a, d_c); //ad -bc + + auto ret = _mm256_hsub_pd(ac_bd, ad_bc); //ac - bd ad + bc + return ret; +} + +template <> Vec256> inline operator/(const Vec256> &a, const Vec256> &b) { + //re + im*i = (a + bi) / (c + di) + //re = (ac + bd)/abs_2() + //im = (bc - ad)/abs_2() + const __m256d neg = _mm256_setr_pd(-1.0, 1.0, -1.0, 1.0); + auto ac_bd = _mm256_mul_pd(a, b); //ac bd + + auto d_c = _mm256_permute_pd(b, 0x05); //d c + d_c = _mm256_mul_pd(neg, d_c); //-d c + auto ad_bc = _mm256_mul_pd(a, d_c); //-ad bc + + auto re_im = _mm256_hadd_pd(ac_bd, ad_bc);//ac + bd bc - ad + return _mm256_div_pd(re_im, b.abs_2_()); +} + +// reciprocal. Implement this here so we can use multiplication. +Vec256> Vec256>::reciprocal() const{ + //re + im*i = (a + bi) / (c + di) + //re = (ac + bd)/abs_2() = c/abs_2() + //im = (bc - ad)/abs_2() = d/abs_2() + const __m256d neg = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + auto c_d = _mm256_mul_pd(neg, values); //c -d + return _mm256_div_pd(c_d, abs_2_()); +} + +template <> +Vec256> inline maximum(const Vec256>& a, const Vec256>& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm256_cmp_pd(abs_a, abs_b, _CMP_LT_OQ); + auto max = _mm256_blendv_pd(a, b, mask); + // Exploit the fact that all-ones is a NaN. + auto isnan = _mm256_cmp_pd(abs_a, abs_b, _CMP_UNORD_Q); + return _mm256_or_pd(max, isnan); +} + +template <> +Vec256> inline minimum(const Vec256>& a, const Vec256>& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm256_cmp_pd(abs_a, abs_b, _CMP_GT_OQ); + auto min = _mm256_blendv_pd(a, b, mask); + // Exploit the fact that all-ones is a NaN. + auto isnan = _mm256_cmp_pd(abs_a, abs_b, _CMP_UNORD_Q); + return _mm256_or_pd(min, isnan); +} + +template <> +Vec256> inline clamp(const Vec256>& a, const Vec256>& min, const Vec256>& max) { + auto abs_a = a.abs_2_(); + auto abs_min = min.abs_2_(); + auto max_mask = _mm256_cmp_pd(abs_a, abs_min, _CMP_LT_OQ); + auto abs_max = max.abs_2_(); + auto min_mask = _mm256_cmp_pd(abs_a, abs_max, _CMP_GT_OQ); + return _mm256_blendv_pd(_mm256_blendv_pd(a, min, max_mask), max, min_mask); +} + +template <> +Vec256> inline clamp_min(const Vec256>& a, const Vec256>& min) { + auto abs_a = a.abs_2_(); + auto abs_min = min.abs_2_(); + auto max_mask = _mm256_cmp_pd(abs_a, abs_min, _CMP_LT_OQ); + return _mm256_blendv_pd(a, min, max_mask); +} + +template <> +Vec256> inline clamp_max(const Vec256>& a, const Vec256>& max) { + auto abs_a = a.abs_2_(); + auto abs_max = max.abs_2_(); + auto min_mask = _mm256_cmp_pd(abs_a, abs_max, _CMP_GT_OQ); + return _mm256_blendv_pd(a, max, min_mask); +} + +template <> +Vec256> inline operator&(const Vec256>& a, const Vec256>& b) { + return _mm256_and_pd(a, b); +} + +template <> +Vec256> inline operator|(const Vec256>& a, const Vec256>& b) { + return _mm256_or_pd(a, b); +} + +template <> +Vec256> inline operator^(const Vec256>& a, const Vec256>& b) { + return _mm256_xor_pd(a, b); +} + +#ifdef __AVX2__ +template <> inline Vec256> fmadd(const Vec256>& a, const Vec256>& b, const Vec256>& c) { + return a * b + c; +} +#endif + +#endif + +}}} diff --git a/aten/src/ATen/cpu/vec256/vec256_complex_float.h b/aten/src/ATen/cpu/vec256/vec256_complex_float.h new file mode 100644 index 00000000000..2e388914f2b --- /dev/null +++ b/aten/src/ATen/cpu/vec256/vec256_complex_float.h @@ -0,0 +1,405 @@ +#pragma once + +#include +#include +#if defined(__AVX__) && !defined(_MSC_VER) +#include +#endif + +namespace at { +namespace vec256 { +// See Note [Acceptable use of anonymous namespace in header] +namespace { + +#if defined(__AVX__) && !defined(_MSC_VER) + +template <> class Vec256> { +private: + __m256 values; +public: + using value_type = std::complex; + static constexpr int size() { + return 4; + } + Vec256() {} + Vec256(__m256 v) : values(v) {} + Vec256(std::complex val) { + float real_value = std::real(val); + float imag_value = std::imag(val); + values = _mm256_setr_ps(real_value, imag_value, + real_value, imag_value, + real_value, imag_value, + real_value, imag_value + ); + } + Vec256(std::complex val1, std::complex val2, std::complex val3, std::complex val4) { + values = _mm256_setr_ps(std::real(val1), std::imag(val1), + std::real(val2), std::imag(val2), + std::real(val3), std::imag(val3), + std::real(val4), std::imag(val4) + ); + } + operator __m256() const { + return values; + } + template + static Vec256> blend(const Vec256>& a, const Vec256>& b) { + // convert std::complex index mask to V index mask: xy -> xxyy + switch (mask) { + case 0: + return a; + case 1: + return _mm256_blend_ps(a.values, b.values, 0x03); //b0000 0001 = b0000 0011 + case 2: + return _mm256_blend_ps(a.values, b.values, 0x0C); //b0000 0010 = b0000 1100 + case 3: + return _mm256_blend_ps(a.values, b.values, 0x0F); //b0000 0011 = b0000 1111 + case 4: + return _mm256_blend_ps(a.values, b.values, 0x30); //b0000 0100 = b0011 0000 + case 5: + return _mm256_blend_ps(a.values, b.values, 0x33); //b0000 0101 = b0011 0011 + case 6: + return _mm256_blend_ps(a.values, b.values, 0x3C); //b0000 0110 = b0011 1100 + case 7: + return _mm256_blend_ps(a.values, b.values, 0x3F); //b0000 0111 = b0011 1111 + case 8: + return _mm256_blend_ps(a.values, b.values, 0xC0); //b0000 1000 = b1100 0000 + case 9: + return _mm256_blend_ps(a.values, b.values, 0xC3); //b0000 1001 = b1100 0011 + case 10: + return _mm256_blend_ps(a.values, b.values, 0xCC); //b0000 1010 = b1100 1100 + case 11: + return _mm256_blend_ps(a.values, b.values, 0xCF); //b0000 1011 = b1100 1111 + case 12: + return _mm256_blend_ps(a.values, b.values, 0xF0); //b0000 1100 = b1111 0000 + case 13: + return _mm256_blend_ps(a.values, b.values, 0xF3); //b0000 1101 = b1111 0011 + case 14: + return _mm256_blend_ps(a.values, b.values, 0xFC); //b0000 1110 = b1111 1100 + } + return b; + } + static Vec256> blendv(const Vec256>& a, const Vec256>& b, + const Vec256>& mask) { + // convert std::complex index mask to V index mask: xy -> xxyy + auto mask_ = _mm256_unpacklo_ps(mask.values, mask.values); + return _mm256_blendv_ps(a.values, b.values, mask_); + + } + static Vec256> arange(std::complex base = 0., std::complex step = 1.) { + return Vec256>(base, + base + step, + base + std::complex(2)*step, + base + std::complex(3)*step); + } + static Vec256> set(const Vec256>& a, const Vec256>& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + return b; + } + static Vec256> loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return _mm256_loadu_ps(reinterpret_cast(ptr)); + + __at_align32__ float tmp_values[2*size()]; + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(std::complex)); + return _mm256_load_ps(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm256_storeu_ps(reinterpret_cast(ptr), values); + } else if (count > 0) { + float tmp_values[2*size()]; + _mm256_storeu_ps(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(std::complex)); + } + } + const std::complex& operator[](int idx) const = delete; + std::complex& operator[](int idx) = delete; + Vec256> map(std::complex (*f)(const std::complex &)) const { + __at_align32__ std::complex tmp[size()]; + store(tmp); + for (int i = 0; i < size(); i++) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + __m256 abs_2_() const { + auto val_2 = _mm256_mul_ps(values, values); // a*a b*b + return _mm256_hadd_ps(val_2, val_2); // a*a+b*b a*a+b*b + } + __m256 abs_() const { + return _mm256_sqrt_ps(abs_2_()); // abs abs + } + Vec256> abs() const { + const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, + 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000)); + return _mm256_and_ps(abs_(), real_mask); // abs 0 + } + __m256 angle_() const { + //angle = atan2(b/a) + auto b_a = _mm256_permute_ps(values, 0x55); // b a + return Sleef_atan2f8_u10(values, b_a); // 90-angle angle + } + Vec256> angle() const { + const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, + 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000)); + auto angle = _mm256_permute_ps(angle_(), 0x55); // angle 90-angle + return _mm256_and_ps(angle, real_mask); // angle 0 + } + __m256 real_() const { + const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, + 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000)); + return _mm256_and_ps(values, real_mask); + } + Vec256> real() const { + return real_(); + } + __m256 imag_() const { + const __m256 imag_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, + 0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF)); + return _mm256_and_ps(values, imag_mask); + } + Vec256> imag() const { + return _mm256_permute_ps(imag_(), 0x55); //b a + } + __m256 conj_() const { + const __m256 conj_mask = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + return _mm256_mul_ps(values, conj_mask); //a -b + } + Vec256> conj() const { + return conj_(); + } + Vec256> acos() const { + return map(std::acos); + } + Vec256> asin() const { + return map(std::asin); + } + Vec256> atan() const { + return map(std::atan); + } + Vec256> atan2(const Vec256> &b) const { + AT_ERROR("not supported for complex numbers"); + } + Vec256> erf() const { + AT_ERROR("not supported for complex numbers"); + } + Vec256> erfc() const { + AT_ERROR("not supported for complex numbers"); + } + Vec256> exp() const { + return map(std::exp); + } + Vec256> expm1() const { + AT_ERROR("not supported for complex numbers"); + } + Vec256> log() const { + return map(std::log); + } + Vec256> log2() const { + AT_ERROR("not supported for complex numbers"); + } + Vec256> log10() const { + return map(std::log10); + } + Vec256> log1p() const { + AT_ERROR("not supported for complex numbers"); + } + Vec256> sin() const { + return map(std::sin); + } + Vec256> sinh() const { + return map(std::sinh); + } + Vec256> cos() const { + return map(std::cos); + } + Vec256> cosh() const { + return map(std::cosh); + } + Vec256> ceil() const { + return _mm256_ceil_ps(values); + } + Vec256> floor() const { + return _mm256_floor_ps(values); + } + Vec256> neg() const { + auto zero = _mm256_setzero_ps(); + return _mm256_sub_ps(zero, values); + } + Vec256> round() const { + return _mm256_round_ps(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vec256> tan() const { + return map(std::tan); + } + Vec256> tanh() const { + return map(std::tanh); + } + Vec256> trunc() const { + return _mm256_round_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vec256> sqrt() const { + return map(std::sqrt); + } + Vec256> reciprocal() const; + Vec256> rsqrt() const { + return map([](const std::complex &x) { return (std::complex)(1)/std::sqrt(x); }); + } + Vec256> pow(const Vec256> &exp) const { + AT_ERROR("not supported for complex numbers"); + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vec256> operator==(const Vec256>& other) const { + return _mm256_cmp_ps(values, other.values, _CMP_EQ_OQ); + } + Vec256> operator!=(const Vec256>& other) const { + return _mm256_cmp_ps(values, other.values, _CMP_NEQ_OQ); + } + Vec256> operator<(const Vec256>& other) const { + AT_ERROR("not supported for complex numbers"); + } + Vec256> operator<=(const Vec256>& other) const { + AT_ERROR("not supported for complex numbers"); + } + Vec256> operator>(const Vec256>& other) const { + AT_ERROR("not supported for complex numbers"); + } + Vec256> operator>=(const Vec256>& other) const { + AT_ERROR("not supported for complex numbers"); + } +}; + +template <> Vec256> inline operator+(const Vec256> &a, const Vec256> &b) { + return _mm256_add_ps(a, b); +} + +template <> Vec256> inline operator-(const Vec256> &a, const Vec256> &b) { + return _mm256_sub_ps(a, b); +} + +template <> Vec256> inline operator*(const Vec256> &a, const Vec256> &b) { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i + const __m256 neg = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + auto ac_bd = _mm256_mul_ps(a, b); //ac bd + + auto d_c = _mm256_permute_ps(b, 0x55); //d c + d_c = _mm256_mul_ps(neg, d_c); //d -c + auto ad_bc = _mm256_mul_ps(a, d_c); //ad -bc + + auto ret = _mm256_hsub_ps(ac_bd, ad_bc); //ac - bd ad + bc + return ret; +} + +template <> Vec256> inline operator/(const Vec256> &a, const Vec256> &b) { + //re + im*i = (a + bi) / (c + di) + //re = (ac + bd)/abs_2() + //im = (bc - ad)/abs_2() + const __m256 neg = _mm256_setr_ps(-1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0); + auto ac_bd = _mm256_mul_ps(a, b); //ac bd + + auto d_c = _mm256_permute_ps(b, 0x05); //d c + d_c = _mm256_mul_ps(neg, d_c); //-d c + auto ad_bc = _mm256_mul_ps(a, d_c); //-ad bc + + auto re_im = _mm256_hadd_ps(ac_bd, ad_bc);//ac + bd bc - ad + return _mm256_div_ps(re_im, b.abs_2_()); +} + +// reciprocal. Implement this here so we can use multiplication. +Vec256> Vec256>::reciprocal() const { + //re + im*i = (a + bi) / (c + di) + //re = (ac + bd)/abs_2() = c/abs_2() + //im = (bc - ad)/abs_2() = d/abs_2() + const __m256 neg = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + auto c_d = _mm256_mul_ps(neg, values); //c -d + return _mm256_div_ps(c_d, abs_2_()); +} + +template <> +Vec256> inline maximum(const Vec256>& a, const Vec256>& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_LT_OQ); + auto max = _mm256_blendv_ps(a, b, mask); + // Exploit the fact that all-ones is a NaN. + auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + return _mm256_or_ps(max, isnan); +} + +template <> +Vec256> inline minimum(const Vec256>& a, const Vec256>& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_GT_OQ); + auto min = _mm256_blendv_ps(a, b, mask); + // Exploit the fact that all-ones is a NaN. + auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + return _mm256_or_ps(min, isnan); +} + +template <> +Vec256> inline clamp(const Vec256>& a, const Vec256>& min, const Vec256>& max) { + auto abs_a = a.abs_2_(); + auto abs_min = min.abs_2_(); + auto max_mask = _mm256_cmp_ps(abs_a, abs_min, _CMP_LT_OQ); + auto abs_max = max.abs_2_(); + auto min_mask = _mm256_cmp_ps(abs_a, abs_max, _CMP_GT_OQ); + return _mm256_blendv_ps(_mm256_blendv_ps(a, min, max_mask), max, min_mask); +} + +template <> +Vec256> inline clamp_min(const Vec256>& a, const Vec256>& min) { + auto abs_a = a.abs_2_(); + auto abs_min = min.abs_2_(); + auto max_mask = _mm256_cmp_ps(abs_a, abs_min, _CMP_LT_OQ); + return _mm256_blendv_ps(a, min, max_mask); +} + +template <> +Vec256> inline clamp_max(const Vec256>& a, const Vec256>& max) { + auto abs_a = a.abs_2_(); + auto abs_max = max.abs_2_(); + auto min_mask = _mm256_cmp_ps(abs_a, abs_max, _CMP_GT_OQ); + return _mm256_blendv_ps(a, max, min_mask); +} + +template <> +Vec256> inline operator&(const Vec256>& a, const Vec256>& b) { + return _mm256_and_ps(a, b); +} + +template <> +Vec256> inline operator|(const Vec256>& a, const Vec256>& b) { + return _mm256_or_ps(a, b); +} + +template <> +Vec256> inline operator^(const Vec256>& a, const Vec256>& b) { + return _mm256_xor_ps(a, b); +} + +#ifdef __AVX2__ +template <> inline Vec256> fmadd(const Vec256>& a, const Vec256>& b, const Vec256>& c) { + return a * b + c; +} +#endif + +#endif + +}}} diff --git a/aten/src/ATen/cpu/vec256/vec256_double.h b/aten/src/ATen/cpu/vec256/vec256_double.h index 7a570f9dd18..5bb8028a88f 100644 --- a/aten/src/ATen/cpu/vec256/vec256_double.h +++ b/aten/src/ATen/cpu/vec256/vec256_double.h @@ -91,6 +91,18 @@ public: auto mask = _mm256_set1_pd(-0.f); return _mm256_andnot_pd(mask, values); } + Vec256 angle() const { + return _mm256_set1_pd(0); + } + Vec256 real() const { + return *this; + } + Vec256 imag() const { + return _mm256_set1_pd(0); + } + Vec256 conj() const { + return *this; + } Vec256 acos() const { return Vec256(Sleef_acosd4_u10(values)); } diff --git a/aten/src/ATen/cpu/vec256/vec256_float.h b/aten/src/ATen/cpu/vec256/vec256_float.h index 564328267ac..67aa83dce16 100644 --- a/aten/src/ATen/cpu/vec256/vec256_float.h +++ b/aten/src/ATen/cpu/vec256/vec256_float.h @@ -99,6 +99,18 @@ public: auto mask = _mm256_set1_ps(-0.f); return _mm256_andnot_ps(mask, values); } + Vec256 angle() const { + return _mm256_set1_ps(0); + } + Vec256 real() const { + return *this; + } + Vec256 imag() const { + return _mm256_set1_ps(0); + } + Vec256 conj() const { + return *this; + } Vec256 acos() const { return Vec256(Sleef_acosf8_u10(values)); } diff --git a/aten/src/ATen/cpu/vec256/vec256_int.h b/aten/src/ATen/cpu/vec256/vec256_int.h index a3c56113290..33ae0f93892 100644 --- a/aten/src/ATen/cpu/vec256/vec256_int.h +++ b/aten/src/ATen/cpu/vec256/vec256_int.h @@ -97,6 +97,19 @@ struct Vec256 : public Vec256i { auto inverse = _mm256_xor_si256(values, is_larger); return _mm256_sub_epi64(inverse, is_larger); } + Vec256 angle() const { + return _mm256_set1_epi64x(0); + } + Vec256 real() const { + return *this; + } + Vec256 imag() const { + return _mm256_set1_epi64x(0); + } + Vec256 conj() const { + return *this; + } + Vec256 frac() const; Vec256 neg() const; Vec256 operator==(const Vec256& other) const { return _mm256_cmpeq_epi64(values, other.values); @@ -194,6 +207,19 @@ struct Vec256 : public Vec256i { Vec256 abs() const { return _mm256_abs_epi32(values); } + Vec256 angle() const { + return _mm256_set1_epi32(0); + } + Vec256 real() const { + return *this; + } + Vec256 imag() const { + return _mm256_set1_epi32(0); + } + Vec256 conj() const { + return *this; + } + Vec256 frac() const; Vec256 neg() const; Vec256 operator==(const Vec256& other) const { return _mm256_cmpeq_epi32(values, other.values); @@ -380,6 +406,19 @@ struct Vec256 : public Vec256i { Vec256 abs() const { return _mm256_abs_epi16(values); } + Vec256 angle() const { + return _mm256_set1_epi16(0); + } + Vec256 real() const { + return *this; + } + Vec256 imag() const { + return _mm256_set1_epi16(0); + } + Vec256 conj() const { + return *this; + } + Vec256 frac() const; Vec256 neg() const; Vec256 operator==(const Vec256& other) const { return _mm256_cmpeq_epi16(values, other.values); diff --git a/aten/src/ATen/cpu/vml.h b/aten/src/ATen/cpu/vml.h index 1dfbfc6dc4a..45523166b9a 100644 --- a/aten/src/ATen/cpu/vml.h +++ b/aten/src/ATen/cpu/vml.h @@ -39,9 +39,10 @@ // https://bugs.launchpad.net/ubuntu/+source/glibc/+bug/1663280. Calling zeroall // when using AVX/AVX2 code resolves this. #if defined(__AVX__) && defined(__GLIBC__) && __GLIBC_MINOR__ == 23 -#define DL_RUNTIME_BUG(op, type) \ - volatile type x = (type)(1); \ - x = std::op(x); \ +#define DL_RUNTIME_BUG(op, type) \ + using value_t = typename at::native::ztype::value_t; \ + volatile value_t x = (value_t)(1); \ + x = std::op(x); \ _mm256_zeroall(); #else #define DL_RUNTIME_BUG(op, type) diff --git a/aten/src/ATen/native/NNPACK.cpp b/aten/src/ATen/native/NNPACK.cpp index 82654f9a2d5..4638077c9b1 100644 --- a/aten/src/ATen/native/NNPACK.cpp +++ b/aten/src/ATen/native/NNPACK.cpp @@ -55,53 +55,67 @@ bool _nnpack_available() { #include "nnpack.h" -#include - -#include -#include +#include "caffe2/utils/threadpool/ThreadPoolMobile.h" namespace at { namespace native { -// Stolen from Caffe2 -static pthreadpool_t nnpack_threadpool_ = nullptr; -static bool called_nnpack_threadpool_ = false; +static bool init_nnpack() { + static std::once_flag once_; + static bool nnpack_successfully_initialized_ = false; + + std::call_once(once_, []() { + const nnp_status nnpack_status = nnp_initialize(); + nnpack_successfully_initialized_ = (nnp_status_success == nnpack_status); -pthreadpool_t nnpack_threadpool() { - if (! called_nnpack_threadpool_) { - called_nnpack_threadpool_ = true; - enum nnp_status nnpack_status = nnp_initialize(); if (nnpack_status != nnp_status_success) { if (nnpack_status == nnp_status_out_of_memory) { - throw std::runtime_error("could not initialize NNPack (out of memory)"); + LOG(WARNING) << "Could not initialize NNPACK! Reason: Out of memory."; } else if (nnpack_status == nnp_status_unsupported_hardware) { - throw std::runtime_error("could not initialize NNPack (unsupported hardware)"); + LOG(WARNING) << "Could not initialize NNPACK! Reason: Unsupported hardware."; } else { - throw std::runtime_error("could not initialize NNPack (unknown error)"); + LOG(WARNING) << "Could not initialize NNPACK! Reason: Unknown error!"; } } - unsigned int threads; -#ifdef INTRA_OP_PARALLEL - threads = at::get_num_threads(); + }); + + return nnpack_successfully_initialized_; +} + +static pthreadpool_t nnpack_threadpool() { + // Try initializing a threadpool for NNPACK's use. If we fail to + // successfully initialize an implementation, return nullptr which will + // instruct NNPACK to run single threaded. + +#ifdef C10_MOBILE + // If building for mobile, use Caffe 2's mobile-friendly threadpool. + return caffe2::mobile_pthreadpool(); #else - threads = std::thread::hardware_concurrency(); + // Otherwise, try using pthreadpool if we manage to initialize it successfully. + static pthreadpool_t nnpack_threadpool_ = nullptr; + static bool called_nnpack_threadpool_ = false; + + if (!called_nnpack_threadpool_) { + called_nnpack_threadpool_ = true; + +#ifdef INTRA_OP_PARALLEL + const uint32_t threads = at::get_num_threads(); +#else + const uint32_t threads = std::thread::hardware_concurrency(); #endif + nnpack_threadpool_ = pthreadpool_create(threads); - if (nnpack_threadpool_ == nullptr) { - throw std::runtime_error("could not initialize NNPack's pthreadpool"); + if (!nnpack_threadpool_) { + LOG(WARNING) << "Failed to initialize pthreadpool! Running NNPACK in single-threaded mode."; } } + return nnpack_threadpool_; +#endif } bool _nnpack_available() { - if (! called_nnpack_threadpool_) { - try { - return nnpack_threadpool() != nullptr; - } catch (std::runtime_error e) { - } - } - return nnpack_threadpool() != nullptr; + return init_nnpack(); } // Make thread_local for safety in cases where we have multiple threads running diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index 3b26cd85f26..3988b5b866d 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -246,7 +246,8 @@ Tensor& mvlgamma_(Tensor& self, int64_t p) { return at::op##_out(self, self); \ } \ Tensor& _##op##_out_##prefix(Tensor& result, const Tensor& self) { \ - checkBackend(#op, result, Backend::device); \ + checkDeviceType(#op, result, DeviceType::device); \ + checkLayout(#op, result, Layout::Strided); \ auto iter = TensorIterator::unary_op(result, self, \ /*check_mem_overlap=*/true); \ op##_stub(iter.device_type(), iter); \ @@ -263,6 +264,10 @@ Tensor& mvlgamma_(Tensor& self, int64_t p) { IMPLEMENT_UNARY_OP_OUT_INPLACE(op, cuda, CUDA) IMPLEMENT_UNARY_OP_VEC(abs) +IMPLEMENT_UNARY_OP_VEC(angle) +IMPLEMENT_UNARY_OP_VEC(real) +IMPLEMENT_UNARY_OP_VEC(imag) +IMPLEMENT_UNARY_OP_VEC(conj) IMPLEMENT_UNARY_OP_VEC(acos) IMPLEMENT_UNARY_OP_VEC(asin) IMPLEMENT_UNARY_OP_VEC(atan) @@ -285,6 +290,10 @@ IMPLEMENT_UNARY_OP_VEC(tanh) IMPLEMENT_UNARY_OP_VEC_CUDA(lgamma) DEFINE_DISPATCH(abs_stub); +DEFINE_DISPATCH(angle_stub); +DEFINE_DISPATCH(real_stub); +DEFINE_DISPATCH(imag_stub); +DEFINE_DISPATCH(conj_stub); DEFINE_DISPATCH(acos_stub); DEFINE_DISPATCH(asin_stub); DEFINE_DISPATCH(atan_stub); diff --git a/aten/src/ATen/native/UnaryOps.h b/aten/src/ATen/native/UnaryOps.h index 0640bca5fd2..06cc3f39b1c 100644 --- a/aten/src/ATen/native/UnaryOps.h +++ b/aten/src/ATen/native/UnaryOps.h @@ -13,6 +13,10 @@ using unary_fn = void(*)(TensorIterator&); using unary_fn_with_scalar = void(*)(TensorIterator&, Scalar a); DECLARE_DISPATCH(unary_fn, abs_stub); +DECLARE_DISPATCH(unary_fn, angle_stub); +DECLARE_DISPATCH(unary_fn, real_stub); +DECLARE_DISPATCH(unary_fn, imag_stub); +DECLARE_DISPATCH(unary_fn, conj_stub); DECLARE_DISPATCH(unary_fn, acos_stub); DECLARE_DISPATCH(unary_fn, asin_stub); DECLARE_DISPATCH(unary_fn, atan_stub); diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index c1437bac680..73f8bdd27a5 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -14,14 +14,13 @@ namespace { using namespace vec256; void add_kernel(TensorIterator& iter, Scalar alpha_scalar) { - if (iter.dtype() == ScalarType::Bool || isComplexType(iter.dtype())) { - AT_DISPATCH_COMPLEX_TYPES_AND(kBool, iter.dtype(), "add_cpu/sub_cpu", [&]() { + if (iter.dtype() == ScalarType::Bool) { + using scalar_t = bool; auto alpha = alpha_scalar.to(); cpu_kernel(iter, [=](scalar_t a, scalar_t b) -> scalar_t { return a + alpha * b; }); - }); } else { - AT_DISPATCH_ALL_TYPES_AND(kBFloat16, iter.dtype(), "add_cpu/sub_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, iter.dtype(), "add_cpu/sub_cpu", [&]() { auto alpha = alpha_scalar.to(); auto alpha_vec = Vec256(alpha); cpu_kernel_vec(iter, @@ -51,13 +50,8 @@ void sub_kernel(TensorIterator& iter, Scalar alpha_scalar) { void mul_kernel(TensorIterator& iter) { if (iter.dtype() == ScalarType::Bool) { cpu_kernel(iter, [=](bool a, bool b) -> bool { return a && b; }); - } else if (isComplexType(iter.dtype())) { - AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "mul_cpu", [&]() { - cpu_kernel(iter, - [=](scalar_t a, scalar_t b) -> scalar_t { return a * b; }); - }); } else { - AT_DISPATCH_ALL_TYPES_AND(kBFloat16, iter.dtype(), "mul_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, iter.dtype(), "mul_cpu", [&]() { cpu_kernel_vec(iter, [=](scalar_t a, scalar_t b) -> scalar_t { return a * b; }, [=](Vec256 a, Vec256 b) { @@ -78,9 +72,12 @@ void div_kernel(TensorIterator& iter) { }); } else if (isComplexType(iter.dtype())) { AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "div_cpu", [&]() { - cpu_kernel(iter, + cpu_kernel_vec(iter, [=](scalar_t a, scalar_t b) __ubsan_ignore_float_divide_by_zero__ -> scalar_t { return a / b; + }, + [=](Vec256 a, Vec256 b) { + return a / b; }); }); } else { diff --git a/aten/src/ATen/native/cpu/Loops.h b/aten/src/ATen/native/cpu/Loops.h index 5d2ae2020c9..ef25328fda9 100644 --- a/aten/src/ATen/native/cpu/Loops.h +++ b/aten/src/ATen/native/cpu/Loops.h @@ -42,13 +42,13 @@ namespace at { namespace native { namespace { using namespace vec256; -template +template typename traits::ArgsTuple dereference_impl(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, - c10::guts::index_sequence) { + c10::guts::index_sequence) { return std::make_tuple( - *(typename traits::template arg::type*) - (data[I] + i * strides[I])...); + *(typename traits::template arg::type*) + (data[INDEX] + i * strides[INDEX])...); } template @@ -58,19 +58,19 @@ dereference(char* C10_RESTRICT data[], const int64_t* strides, int64_t i) { return dereference_impl(data, strides, i, Indices{}); } -template +template typename traits::ArgsTuple dereference_vec_impl(char* C10_RESTRICT data[], const typename traits::result_type& opt_scalar, size_t S, int64_t i, - c10::guts::index_sequence) { + c10::guts::index_sequence) { using Vec = typename traits::result_type; using scalar_t = typename Vec::value_type; return std::make_tuple( - S == I + 1 ? + S == INDEX + 1 ? opt_scalar : - Vec::loadu(data[I] + i * sizeof(scalar_t))...); + Vec::loadu(data[INDEX] + i * sizeof(scalar_t))...); } template @@ -171,15 +171,15 @@ static inline void unroll_contiguous_scalar_checks( cb(0); } -template +template static inline void unroll_contiguous_scalar_checks( const int64_t* strides, - c10::guts::index_sequence, + c10::guts::index_sequence, const cb_t& cb) { - if (is_contiguous_scalar(strides)) { - cb(I0 + 1); + if (is_contiguous_scalar(strides)) { + cb(INDEX0 + 1); } else { - unroll_contiguous_scalar_checks(strides, c10::guts::index_sequence{}, cb); + unroll_contiguous_scalar_checks(strides, c10::guts::index_sequence{}, cb); } } diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp index ddc3d55da49..33e1b5574b3 100644 --- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp @@ -16,6 +16,7 @@ #include #include +#include #include @@ -29,10 +30,10 @@ namespace { using namespace vec256; static void sigmoid_kernel(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "sigmoid_cpu", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "sigmoid_cpu", [&]() { cpu_kernel_vec( iter, - [=](scalar_t a) -> scalar_t { return (1 / (1 + std::exp((-a)))); }, + [=](scalar_t a) -> scalar_t { return ((scalar_t)(1) / ((scalar_t)(1) + std::exp((-a)))); }, [=](Vec256 a) { a = Vec256((scalar_t)(0)) - a; a = a.exp(); @@ -53,7 +54,7 @@ uint8_t abs_impl(uint8_t v) { } static void abs_kernel(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES(iter.dtype(), "abs_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "abs_cpu", [&]() { cpu_kernel_vec( iter, [=](scalar_t a) -> scalar_t { return abs_impl(a); }, @@ -61,6 +62,42 @@ static void abs_kernel(TensorIterator& iter) { }); } +static void angle_kernel(TensorIterator& iter) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "angle_cpu", [&]() { + cpu_kernel_vec( + iter, + [=](scalar_t a) -> scalar_t { return angle_impl(a); }, + [=](Vec256 a) { return a.angle(); }); + }); +} + +static void real_kernel(TensorIterator& iter) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "real_cpu", [&]() { + cpu_kernel_vec( + iter, + [=](scalar_t a) -> scalar_t { return real_impl(a); }, + [=](Vec256 a) { return a.real(); }); + }); +} + +static void imag_kernel(TensorIterator& iter) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "imag_cpu", [&]() { + cpu_kernel_vec( + iter, + [=](scalar_t a) -> scalar_t { return imag_impl(a); }, + [=](Vec256 a) { return a.imag(); }); + }); +} + +static void conj_kernel(TensorIterator& iter) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "conj_cpu", [&]() { + cpu_kernel_vec( + iter, + [=](scalar_t a) -> scalar_t { return conj_impl(a); }, + [=](Vec256 a) { return a.conj(); }); + }); +} + static void bitwise_not_kernel(TensorIterator& iter) { if (iter.dtype() == ScalarType::Bool) { // Boolean type does not work with ~ (bitwise NOT) in C++. bitwise_not wraps this operation for both Boolean and @@ -100,7 +137,7 @@ static void logical_not_kernel(TensorIterator& iter) { } static void reciprocal_kernel(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "reciprocal_cpu", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "reciprocal_cpu", [&]() { cpu_kernel_vec( iter, [=](scalar_t a) -> scalar_t { return decltype(a)(1.0) / a; }, @@ -109,7 +146,7 @@ static void reciprocal_kernel(TensorIterator& iter) { } static void neg_kernel(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES(iter.dtype(), "neg_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "neg_cpu", [&]() { cpu_kernel_vec( iter, [=](scalar_t a) -> scalar_t { return -a; }, @@ -141,7 +178,7 @@ static void sign_kernel(TensorIterator& iter){ } static void sinh_kernel(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "sinh_cpu", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "sinh_cpu", [&]() { cpu_kernel( iter, [=](scalar_t a) -> scalar_t { return std::sinh(a); }); @@ -149,7 +186,7 @@ static void sinh_kernel(TensorIterator& iter) { } static void cosh_kernel(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "cosh_cpu", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "cosh_cpu", [&]() { cpu_kernel( iter, [=](scalar_t a) -> scalar_t { return std::cosh(a); }); @@ -181,33 +218,36 @@ static void polygamma_kernel(TensorIterator& iter, int64_t n) { } static void clamp_kernel(TensorIterator& iter, Scalar min_scalar, Scalar max_scalar) { - AT_DISPATCH_ALL_TYPES(iter.dtype(), "clamp_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "clamp_cpu", [&]() { + ztype::value_t (*zabs_)(scalar_t) = zabs; auto min = min_scalar.to(); auto max = max_scalar.to(); auto min_vec = Vec256(min); auto max_vec = Vec256(max); cpu_kernel_vec(iter, - [=](scalar_t a) -> scalar_t { return a < min ? min : (a > max ? max : a); }, + [=](scalar_t a) -> scalar_t { return zabs_(a) < zabs_(min) ? min : (zabs_(a) > zabs_(max) ? max : a); }, [=](Vec256 a) { return vec256::clamp(a, min_vec, max_vec); }); }); } static void clamp_max_kernel(TensorIterator& iter, Scalar max_scalar) { - AT_DISPATCH_ALL_TYPES(iter.dtype(), "clamp_max_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "clamp_max_cpu", [&]() { + ztype::value_t (*zabs_)(scalar_t) = zabs; auto max = max_scalar.to(); auto max_vec = Vec256(max); cpu_kernel_vec(iter, - [=](scalar_t a) -> scalar_t { return a > max ? max : a; }, + [=](scalar_t a) -> scalar_t { return zabs_(a) > zabs_(max) ? max : a; }, [=](Vec256 a) { return vec256::clamp_max(a, max_vec); }); }); } static void clamp_min_kernel(TensorIterator& iter, Scalar min_scalar) { - AT_DISPATCH_ALL_TYPES(iter.dtype(), "clamp_min_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "clamp_min_cpu", [&]() { + ztype::value_t (*zabs_)(scalar_t) = zabs; auto min = min_scalar.to(); auto min_vec = Vec256(min); cpu_kernel_vec(iter, - [=](scalar_t a) -> scalar_t { return a < min ? min : a; }, + [=](scalar_t a) -> scalar_t { return zabs_(a) < zabs_(min) ? min : a; }, [=](Vec256 a) { return vec256::clamp_min(a, min_vec); }); }); } @@ -272,7 +312,7 @@ void bernoulli_mkl_kernel(Tensor &self, const double p, Generator* gen) { #endif static void rsqrt_kernel(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "rsqrt_cpu", [&] { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "rsqrt_cpu", [&] { cpu_kernel_vec( iter, [=](scalar_t a) -> scalar_t { @@ -315,12 +355,47 @@ static void rsqrt_kernel(TensorIterator& iter) { } \ REGISTER_DISPATCH(op##_stub, &op##_kernel) +#define IMPLEMENT_COMPLEX_KERNEL(dispatchtypes, op) \ + static void op##_kernel(TensorIterator& iter) { \ + TORCH_INTERNAL_ASSERT(iter.ntensors() == 2); \ + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), op##_vml_cpu, [&]() {\ + iter.serial_for_each( \ + [&](char** data_, const int64_t* strides, int64_t n) { \ + scalar_t* out_data = reinterpret_cast(data_[0]); \ + scalar_t* in_data = reinterpret_cast(data_[1]); \ + int64_t out_stride = strides[0] / sizeof(scalar_t); \ + int64_t in_stride = strides[1] / sizeof(scalar_t); \ + if (out_stride == 1 && in_stride == 1) { \ + vml::v##op(out_data, in_data, n); \ + } else { \ + static constexpr int64_t WIDTH = 131072 / sizeof(scalar_t); \ + for (int64_t i = 0; i < n; i += WIDTH) { \ + scalar_t buffer[WIDTH]; \ + int64_t width = WIDTH; \ + width = std::min(width, n - i); \ + for (int64_t j = 0; j < width; j++) \ + buffer[j] = in_data[in_stride * (i + j)]; \ + vml::v##op(buffer, buffer, width); \ + for (int64_t j = 0; j < width; j++) \ + out_data[out_stride * (i + j)] = buffer[j]; \ + } \ + } \ + }, \ + {0, iter.numel()}); \ + }); \ + } \ + REGISTER_DISPATCH(op##_stub, &op##_kernel) + } // anonymous namespace REGISTER_DISPATCH(rsqrt_stub, &rsqrt_kernel); REGISTER_DISPATCH(sigmoid_stub, &sigmoid_kernel); REGISTER_DISPATCH(bernoulli_mkl_stub, &bernoulli_mkl_kernel); REGISTER_DISPATCH(abs_stub, &abs_kernel); +REGISTER_DISPATCH(angle_stub, &angle_kernel); +REGISTER_DISPATCH(real_stub, &real_kernel); +REGISTER_DISPATCH(imag_stub, &imag_kernel); +REGISTER_DISPATCH(conj_stub, &conj_kernel); REGISTER_DISPATCH(bitwise_not_stub, &bitwise_not_kernel); REGISTER_DISPATCH(logical_not_stub, &logical_not_kernel); REGISTER_DISPATCH(frac_stub, &frac_kernel); @@ -338,29 +413,29 @@ REGISTER_DISPATCH(clamp_min_stub, &clamp_min_kernel); // IMPLEMENT_FLOAT_KERNEL(ALL, abs) -IMPLEMENT_FLOAT_KERNEL(FLOATING, acos) -IMPLEMENT_FLOAT_KERNEL(FLOATING, asin) -IMPLEMENT_FLOAT_KERNEL(FLOATING, atan) -IMPLEMENT_FLOAT_KERNEL(FLOATING, ceil) -IMPLEMENT_FLOAT_KERNEL(FLOATING, cos) +IMPLEMENT_COMPLEX_KERNEL(FLOATING, acos) +IMPLEMENT_COMPLEX_KERNEL(FLOATING, asin) +IMPLEMENT_COMPLEX_KERNEL(FLOATING, atan) +IMPLEMENT_COMPLEX_KERNEL(FLOATING, ceil) +IMPLEMENT_COMPLEX_KERNEL(FLOATING, cos) // IMPLEMENT_FLOAT_KERNEL(FLOATING, cosh) IMPLEMENT_FLOAT_KERNEL(FLOATING, erf) IMPLEMENT_FLOAT_KERNEL(FLOATING, erfc) IMPLEMENT_FLOAT_KERNEL(FLOATING, erfinv) -IMPLEMENT_FLOAT_KERNEL(FLOATING, exp) +IMPLEMENT_COMPLEX_KERNEL(FLOATING, exp) IMPLEMENT_FLOAT_KERNEL(FLOATING, expm1) -IMPLEMENT_FLOAT_KERNEL(FLOATING, floor) -IMPLEMENT_FLOAT_KERNEL(FLOATING, log) -IMPLEMENT_FLOAT_KERNEL(FLOATING, log10) +IMPLEMENT_COMPLEX_KERNEL(FLOATING, floor) +IMPLEMENT_COMPLEX_KERNEL(FLOATING, log) +IMPLEMENT_COMPLEX_KERNEL(FLOATING, log10) IMPLEMENT_FLOAT_KERNEL(FLOATING, log1p) IMPLEMENT_FLOAT_KERNEL(FLOATING, log2) -IMPLEMENT_FLOAT_KERNEL(FLOATING, round) -IMPLEMENT_FLOAT_KERNEL(FLOATING, sin) +IMPLEMENT_COMPLEX_KERNEL(FLOATING, round) +IMPLEMENT_COMPLEX_KERNEL(FLOATING, sin) // IMPLEMENT_FLOAT_KERNEL(FLOATING, sinh) -IMPLEMENT_FLOAT_KERNEL(FLOATING, sqrt) -IMPLEMENT_FLOAT_KERNEL(FLOATING, tan) -IMPLEMENT_FLOAT_KERNEL(FLOATING, tanh) -IMPLEMENT_FLOAT_KERNEL(FLOATING, trunc) +IMPLEMENT_COMPLEX_KERNEL(FLOATING, sqrt) +IMPLEMENT_COMPLEX_KERNEL(FLOATING, tan) +IMPLEMENT_COMPLEX_KERNEL(FLOATING, tanh) +IMPLEMENT_COMPLEX_KERNEL(FLOATING, trunc) IMPLEMENT_FLOAT_KERNEL(FLOATING, lgamma) }} // namespace at::native diff --git a/aten/src/ATen/native/cpu/zmath.h b/aten/src/ATen/native/cpu/zmath.h new file mode 100644 index 00000000000..9ccfc1c1597 --- /dev/null +++ b/aten/src/ATen/native/cpu/zmath.h @@ -0,0 +1,160 @@ +#pragma once + +// Complex number math operations that act as no-ops for other dtypes. +#include + +namespace at { namespace native { +namespace { + +template +struct ztype { + using value_t = TYPE; +}; + +template <> +struct ztype> { + using value_t = double; +}; + +template <> +struct ztype> { + using value_t = float; +}; + +template +inline VALUE_TYPE zabs (SCALAR_TYPE z) { + return z; +} + +template<> +inline float zabs > (std::complex z) { + return std::abs(z); +} + +template<> +inline double zabs > (std::complex z) { + return std::abs(z); +} + +template +inline TYPE angle_impl (TYPE z) { + return 0; +} + +template<> +inline std::complex angle_impl > (std::complex z) { + return std::complex(std::arg(z), 0.0); +} + +template<> +inline std::complex angle_impl > (std::complex z) { + return std::complex(std::arg(z), 0.0); +} + +template +inline TYPE real_impl (TYPE z) { + return z; //No-Op +} + +template<> +inline std::complex real_impl > (std::complex z) { + return std::complex(std::real(z), 0.0); +} + +template<> +inline std::complex real_impl > (std::complex z) { + return std::complex(std::real(z), 0.0); +} + +template +inline TYPE imag_impl (TYPE z) { + return 0; +} + +template<> +inline std::complex imag_impl > (std::complex z) { + return std::complex(std::imag(z), 0.0); +} + +template<> +inline std::complex imag_impl > (std::complex z) { + return std::complex(std::imag(z), 0.0); +} + +template +inline TYPE conj_impl (TYPE z) { + return z; //No-Op +} + +template<> +inline std::complex conj_impl > (std::complex z) { + return std::complex(std::real(z), -std::imag(z)); +} + +template<> +inline std::complex conj_impl > (std::complex z) { + return std::complex(std::real(z), -std::imag(z)); +} + +template +inline TYPE ceil_impl (TYPE z) { + return std::ceil(z); +} + +template <> +inline std::complex ceil_impl (std::complex z) { + return std::complex(std::ceil(std::real(z)), std::ceil(std::imag(z))); +} + +template <> +inline std::complex ceil_impl (std::complex z) { + return std::complex(std::ceil(std::real(z)), std::ceil(std::imag(z))); +} + +template +inline TYPE floor_impl (TYPE z) { + return std::floor(z); +} + +template <> +inline std::complex floor_impl (std::complex z) { + return std::complex(std::floor(std::real(z)), std::floor(std::imag(z))); +} + +template <> +inline std::complex floor_impl (std::complex z) { + return std::complex(std::floor(std::real(z)), std::floor(std::imag(z))); +} + +template +inline TYPE round_impl (TYPE z) { + return std::nearbyint(z); +} + +template <> +inline std::complex round_impl (std::complex z) { + return std::complex(std::nearbyint(std::real(z)), std::nearbyint(std::imag(z))); +} + +template <> +inline std::complex round_impl (std::complex z) { + return std::complex(std::nearbyint(std::real(z)), std::nearbyint(std::imag(z))); +} + +template +inline TYPE trunc_impl (TYPE z) { + return std::trunc(z); +} + +template <> +inline std::complex trunc_impl (std::complex z) { + return std::complex(std::trunc(std::real(z)), std::trunc(std::imag(z))); +} + +template <> +inline std::complex trunc_impl (std::complex z) { + return std::complex(std::trunc(std::real(z)), std::trunc(std::imag(z))); +} + +} // end namespace +}} //end at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 0799f015bbf..bb812f3adc8 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -196,6 +196,54 @@ CPU: _abs_out_cpu CUDA: _abs_out_cuda +- func: angle(Tensor self) -> Tensor + variants: function, method + supports_named_tensor: True + named_guard: False + +- func: angle.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + named_guard: False + supports_named_tensor: True + dispatch: + CPU: _angle_out_cpu + CUDA: _abs_out_cuda + +- func: real(Tensor self) -> Tensor + variants: function, method + named_guard: False + supports_named_tensor: True + +- func: real.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + named_guard: False + supports_named_tensor: True + dispatch: + CPU: _real_out_cpu + CUDA: _abs_out_cuda + +- func: imag(Tensor self) -> Tensor + variants: function, method + named_guard: False + supports_named_tensor: True + +- func: imag.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + named_guard: False + supports_named_tensor: True + dispatch: + CPU: _imag_out_cpu + CUDA: _abs_out_cuda + +- func: conj(Tensor self) -> Tensor + variants: function, method + named_guard: False + supports_named_tensor: True + +- func: conj.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + named_guard: False + supports_named_tensor: True + dispatch: + CPU: _conj_out_cpu + CUDA: _abs_out_cuda + - func: acos(Tensor self) -> Tensor use_c10_dispatcher: full supports_named_tensor: True @@ -3556,12 +3604,6 @@ SparseCUDA: copy_sparse_ requires_tensor: True -- func: numel(Tensor self) -> int - use_c10_dispatcher: full - variants: function, method - device_guard: False - supports_named_tensor: True - - func: unbind.int(Tensor(a) self, int dim=0) -> Tensor(a)[] use_c10_dispatcher: unboxed_only variants: function, method diff --git a/aten/src/ATen/templates/Functions.h b/aten/src/ATen/templates/Functions.h index 59634806be0..e7e312c5967 100644 --- a/aten/src/ATen/templates/Functions.h +++ b/aten/src/ATen/templates/Functions.h @@ -69,6 +69,10 @@ inline Tensor from_blob( return from_blob(data, sizes, detail::defaultStrides(sizes), [](void*) {}, options); } +inline int64_t numel(const Tensor& tensor) { + return tensor.numel(); +} + // function definitions are all static inline because // they are one-line statically dispatched functions that // invoke the actual dynamic dispatch on the correct argument diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h index f8a0e6f2aa4..02f2b16370a 100644 --- a/aten/src/ATen/templates/TensorBody.h +++ b/aten/src/ATen/templates/TensorBody.h @@ -211,6 +211,10 @@ class CAFFE2_API Tensor { return impl_->numel() * impl_->itemsize(); } + int64_t numel() const { + return impl_->numel(); + } + // Length of one array element in bytes. This is the traditional // Numpy naming. size_t itemsize() const { diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt index 80fdf4fc0f7..239f33372c0 100644 --- a/aten/src/ATen/test/CMakeLists.txt +++ b/aten/src/ATen/test/CMakeLists.txt @@ -26,10 +26,12 @@ list(APPEND ATen_CPU_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/weakref_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/extension_backend_test.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/boxed_fallback_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/xla_tensor_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tensor_iterator_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpu_generator_test.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/pow_test.cpp) + ${CMAKE_CURRENT_SOURCE_DIR}/pow_test.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/variant_test.cpp) list(APPEND ATen_CUDA_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cuda_integer_divider_test.cu diff --git a/aten/src/ATen/test/boxed_fallback_test.cpp b/aten/src/ATen/test/boxed_fallback_test.cpp new file mode 100644 index 00000000000..e34882c80d1 --- /dev/null +++ b/aten/src/ATen/test/boxed_fallback_test.cpp @@ -0,0 +1,146 @@ +#include + +#include + +#include +#include +#include +#include + +#include + +using namespace at; + +// This test file gives an example of a simple use case for "wrapper" +// and "mode" style tensor type ids. In both cases, the implementation +// of the wrapper/mode simply passes through the call to underlying JIT +// implementation (so the wrapper/mode doesn't actually do anything), +// but this could be used as a starting point to do more interesting things. + +// TODO: This to be rewritten when bwasti sets up direct access to +// JIT data structures +std::shared_ptr getOperator(const char* schema_str) { + auto schema = torch::jit::parseSchema(schema_str); + auto s = Symbol::fromQualString(schema.name()); + auto operators = torch::jit::getAllOperatorsFor(s); + // Find the exact match + std::shared_ptr op; + for (const auto& candidate_op : operators) { + auto candidate_schema = candidate_op->schema(); + // NB: this is a VERY slow equality test + if (candidate_schema == schema) { + op = candidate_op; + break; + } + } + TORCH_INTERNAL_ASSERT(op); + return op; +} + +// Global counter for ease of testing +static int64_t override_call_count = 0; + +// Mode implementation + +void generic_mode_fallback(const char* schema_str, torch::jit::Stack* stack) { + override_call_count++; + auto operation = getOperator(schema_str)->getOperation(); + c10::impl::ExcludeTensorTypeIdGuard guard(TensorTypeId::TESTING_ONLY_GenericModeTensorId); + auto offset = operation(*stack); + TORCH_INTERNAL_ASSERT(offset == 0); +} + +// Wrapper implementation + +struct GenericWrapperTensorImpl : public c10::TensorImpl { + explicit GenericWrapperTensorImpl(at::Tensor rep) + : TensorImpl( + c10::TensorTypeSet(c10::TensorTypeId::TESTING_ONLY_GenericWrapperTensorId), + rep.dtype(), + rep.device() + // TODO: propagate size! + ) + , rep_(std::move(rep)) {} + + at::Tensor rep_; +}; + +void generic_wrapper_fallback(const char* schema_str, torch::jit::Stack* stack) { + override_call_count++; + auto op = getOperator(schema_str); + auto operation = op->getOperation(); + + const auto& schema = op->schema(); + auto num_arguments = schema.arguments().size(); + auto num_returns = schema.returns().size(); + + // Unwrap all arguments + auto args = torch::jit::pop(*stack, num_arguments); + for (size_t i = 0; i < num_arguments; i++) { + // TODO: Handle tensor list + if (args[i].isTensor()) { + auto* impl = args[i].unsafeToTensorImpl(); + if (impl->type_set().has(TensorTypeId::TESTING_ONLY_GenericWrapperTensorId)) { + auto* wrapper = static_cast(impl); + torch::jit::push(*stack, wrapper->rep_); // no move! + } else { + torch::jit::push(*stack, std::move(args[i])); + } + } else { + torch::jit::push(*stack, std::move(args[i])); + } + } + + auto offset = operation(*stack); + + // Rewrap outputs + auto rets = torch::jit::pop(*stack, num_returns); + for (size_t i = 0; i < num_returns; i++) { + // TODO: Handle tensor list + if (args[i].isTensor()) { + torch::jit::push(*stack, at::detail::make_tensor(std::move(std::move(args[i]).toTensor())) ); // yes move! + } else { + torch::jit::push(*stack, std::move(args[i])); + } + } + + TORCH_INTERNAL_ASSERT(offset == 0); +} + +// As the current API does not support unregistering fallback boxed ops, +// settings of these values are PROCESS global. Therefore the environment +// here. +class Environment : public ::testing::Environment { + public: + virtual ~Environment() {} + + void SetUp() override { + globalATenDispatch().registerFallbackBoxedOp(TensorTypeId::TESTING_ONLY_GenericWrapperTensorId, &generic_wrapper_fallback); + globalATenDispatch().registerFallbackBoxedOp(TensorTypeId::TESTING_ONLY_GenericModeTensorId, &generic_mode_fallback); + } + + void TearDown() override {} +}; + +::testing::Environment* const env = + ::testing::AddGlobalTestEnvironment(new Environment); + +// There's a case to be made that a more comprehensive test suite would be able +// to capture many more edge cases. This test suite is just to show that +// basic functionality works. + +TEST(BoxedFallbackTest, TestBoxedFallbackWithMode) { + c10::impl::IncludeTensorTypeIdGuard guard(TensorTypeId::TESTING_ONLY_GenericModeTensorId); + + override_call_count = 0; + Tensor a = ones({5, 5}, kDouble); + Tensor b = batch_norm(a, {}, {}, {}, {}, true, 0.1, 1e-05, false); + ASSERT_EQ(override_call_count, 2); +} + +TEST(BoxedFallbackTest, TestBoxedFallbackWithWrapper) { + override_call_count = 0; + Tensor a = at::detail::make_tensor(ones({5, 5}, kDouble)); + Tensor b = batch_norm(a, {}, {}, {}, {}, true, 0.1, 1e-05, false); + ASSERT_EQ(override_call_count, 1); +} diff --git a/aten/src/ATen/test/extension_backend_test.cpp b/aten/src/ATen/test/extension_backend_test.cpp index c7ac6ed042c..be8cc60ed7d 100644 --- a/aten/src/ATen/test/extension_backend_test.cpp +++ b/aten/src/ATen/test/extension_backend_test.cpp @@ -4,6 +4,10 @@ #include #include +#include + +#include + using namespace at; static int test_int; diff --git a/aten/src/ATen/test/variant_test.cpp b/aten/src/ATen/test/variant_test.cpp new file mode 100644 index 00000000000..2f97ea66e53 --- /dev/null +++ b/aten/src/ATen/test/variant_test.cpp @@ -0,0 +1,67 @@ +#include + +#include + +namespace testns { + +namespace enumtype { + // NOTE: We need to provide the default constructor for each struct, + // otherwise Clang 3.8 would complain: + // ``` + // error: default initialization of an object of const type 'const enumtype::Enum1' + // without a user-provided default constructor + // ``` + struct Enum1 { Enum1() {}; }; + struct Enum2 { Enum2() {}; }; + struct Enum3 { Enum3() {}; }; +} // namespace enumtype + +struct enum_name { + std::string operator()(enumtype::Enum1& v) const { + return "Enum1"; + } + std::string operator()(enumtype::Enum2& v) const { + return "Enum2"; + } + std::string operator()(enumtype::Enum3& v) const { + return "Enum3"; + } +}; + +const enumtype::Enum1 kEnum1; +const enumtype::Enum2 kEnum2; +const enumtype::Enum3 kEnum3; + +} // namespace testns + +std::string func(c10::variant v) { + if (c10::get_if(&v)) { + return "Enum1"; + } else if (c10::get_if(&v)) { + return "Enum2"; + } else if (c10::get_if(&v)) { + return "Enum3"; + } else { + return "Unsupported enum"; + } +} + +TEST(VariantTest, Basic) { + ASSERT_EQ(func(testns::kEnum1), "Enum1"); + ASSERT_EQ(func(testns::kEnum2), "Enum2"); + ASSERT_EQ(func(testns::kEnum3), "Enum3"); + + c10::variant v; + { + v = testns::kEnum1; + ASSERT_EQ(c10::visit(testns::enum_name{}, v), "Enum1"); + } + { + v = testns::kEnum2; + ASSERT_EQ(c10::visit(testns::enum_name{}, v), "Enum2"); + } + { + v = testns::kEnum3; + ASSERT_EQ(c10::visit(testns::enum_name{}, v), "Enum3"); + } +} diff --git a/aten/src/TH/generic/THTensorMath.cpp b/aten/src/TH/generic/THTensorMath.cpp index c2908e7a103..38cc7fae570 100644 --- a/aten/src/TH/generic/THTensorMath.cpp +++ b/aten/src/TH/generic/THTensorMath.cpp @@ -298,29 +298,20 @@ static void THTensor_(addmmImpl)(THTensor *r_, THTensor *t, THTensor *m1, THTens int64_t ldm1_ = (transpose_m1 == 'n' ? m1_->stride((transpose_r == 'n' ? 1 : 0)) : m1_->stride((transpose_r == 'n' ? 0 : 1))); int64_t ldm2_ = (transpose_m2 == 'n' ? m2_->stride((transpose_r == 'n' ? 1 : 0)) : m2_->stride((transpose_r == 'n' ? 0 : 1))); - // Don't go through GEMM if result is empty matrix, since this is not - // supported by BLAS. - if (m != 0 && n != 0) { - if (k == 0) { - THTensor_(mul)(r__, r__, beta); - } else { - /* do the operation */ - THBlas_(gemm)(transpose_m1, - transpose_m2, - m, - n, - k, - alpha, - m1_->data(), - ldm1_, - m2_->data(), - ldm2_, - beta, - r__->data(), - ldr__); - } - } - + /* do the operation */ + THBlas_(gemm)(transpose_m1, + transpose_m2, + m, + n, + k, + alpha, + m1_->data(), + ldm1_, + m2_->data(), + ldm2_, + beta, + r__->data(), + ldr__); /* free intermediate variables */ if(free_m1) diff --git a/aten/src/THC/generic/THCTensorMathBlas.cu b/aten/src/THC/generic/THCTensorMathBlas.cu index a51115bb6df..e1a5418b87b 100644 --- a/aten/src/THC/generic/THCTensorMathBlas.cu +++ b/aten/src/THC/generic/THCTensorMathBlas.cu @@ -305,20 +305,6 @@ static void THCTensor_(addmmImpl)(THCState *state, THCTensor *r_, THCTensor *t, } } - // Special casing for empty matrices - if (r_->size(0) == 0 || r_->size(1) == 0) { - // No multiplication needed for case of empty result matrix. - return; - } else if (m1->size(1) == 0) { - // k == 0 - if (ScalarConvert::to(beta) != 0.0) { - THCTensor_(mul)(state, r_, r_, beta); - } else { - THCTensor_(zero)(state, r_); - } - return; - } - /* r_ */ if(r_->stride(0) == 1 && r_->stride(1) != 0) diff --git a/benchmarks/operator_benchmark/benchmark_core.py b/benchmarks/operator_benchmark/benchmark_core.py index 4bf8295c53b..aa368e23d37 100644 --- a/benchmarks/operator_benchmark/benchmark_core.py +++ b/benchmarks/operator_benchmark/benchmark_core.py @@ -12,6 +12,7 @@ import torch # needs to be imported after torch import cpp_extension # noqa +import cpp_extension # noqa import benchmark_utils from collections import namedtuple diff --git a/benchmarks/operator_benchmark/benchmark_pytorch.py b/benchmarks/operator_benchmark/benchmark_pytorch.py index 0371c74b7f2..fe558dc8890 100644 --- a/benchmarks/operator_benchmark/benchmark_pytorch.py +++ b/benchmarks/operator_benchmark/benchmark_pytorch.py @@ -94,11 +94,17 @@ class TorchBenchmarkBase(object): """ this is a globally unique name which can be used to label a specific test """ + + # This is a list of attributes which will not be included + # in the test name. + skip_key_list = ['device'] + test_name_str = [] for key in kargs: value = kargs[key] test_name_str.append( - key + str(value if type(value) != bool else int(value))) + ('' if key in skip_key_list else key) + + str(value if type(value) != bool else int(value))) name = (self.module_name() + '_' + '_'.join(test_name_str)).replace(" ", "") return name diff --git a/benchmarks/operator_benchmark/benchmark_utils.py b/benchmarks/operator_benchmark/benchmark_utils.py index 22dd8b80baf..ff75eb2b90b 100644 --- a/benchmarks/operator_benchmark/benchmark_utils.py +++ b/benchmarks/operator_benchmark/benchmark_utils.py @@ -7,6 +7,7 @@ import numpy as np import itertools import random import os +import bisect """Performance microbenchmarks's utils. @@ -14,6 +15,8 @@ import os This module contains utilities for writing microbenchmark tests. """ +# Here are the reserved keywords in the benchmark suite +_reserved_keywords = {"probs", "total_samples", "tags"} def shape_to_string(shape): return ', '.join([str(x) for x in shape]) @@ -109,32 +112,159 @@ def cross_product_configs(**configs): def config_list(**configs): - """ - Take specific inputs from users - For example, given + """ Generate configs based on the list of input shapes. + This function will take input shapes specified in a list from user. Besides + that, all other parameters will be cross producted first and each of the + generated list will be merged with the input shapes list. + + Reserved Args: + attr_names(reserved): a list of names for input shapes. + attrs(reserved): a list of values for each input shape. + corss_product: a dictionary of attributes which will be + cross producted with the input shapes. + tags(reserved): a tag used to filter inputs. + + Here is an example: attrs = [ [1, 2], [4, 5], - ] - attr_names = ["M", "N"] - we will generate (({'M': 1}, {'N' : 2}), - ({'M': 4}, {'N' : 5})) + ], + attr_names = ['M', 'N'], + cross_product_configs={ + 'device': ['cpu', 'cuda'], + }, + + we will generate [[{'M': 1}, {'N' : 2}, {'device' : 'cpu'}], + [{'M': 1}, {'N' : 2}, {'device' : 'cuda'}], + [{'M': 4}, {'N' : 5}, {'device' : 'cpu'}], + [{'M': 4}, {'N' : 5}, {'device' : 'cuda'}]] """ generated_configs = [] - if "attrs" not in configs: + reserved_names = ['attrs', 'attr_names', 'tags'] + if any(attr not in configs for attr in reserved_names): raise ValueError("Missing attrs in configs") - for inputs in configs["attrs"]: - tmp_result = [{configs["attr_names"][i] : input_value} + + cross_configs = None + if 'cross_product_configs' in configs: + cross_configs = cross_product_configs(**configs['cross_product_configs']) + + for inputs in configs['attrs']: + tmp_result = [{configs['attr_names'][i] : input_value} for i, input_value in enumerate(inputs)] # TODO(mingzhe0908): - # If multiple "tags" were provided, do they get concat? - # If a config has both ["short", "medium"], it should match - # both "short" and "medium" tag-filter? - tmp_result.append({"tags" : '_'.join(configs["tags"])}) - generated_configs.append(tmp_result) + # If multiple 'tags' were provided, do they get concat? + # If a config has both ['short', 'medium'], it should match + # both 'short' and 'medium' tag-filter? + tmp_result.append({'tags' : '_'.join(configs['tags'])}) + if cross_configs: + generated_configs += [tmp_result + list(config) for config in cross_configs] + else: + generated_configs.append(tmp_result) + return generated_configs +def attr_probs(**probs): + """ return the inputs in a dictionary + """ + return probs + + +class RandomSample(object): + + def __init__(self, configs): + self.saved_cum_distribution = {} + self.configs = configs + + def _distribution_func(self, key, weights): + """ this is a cumulative distribution function used for random sampling inputs + """ + if key in self.saved_cum_distribution: + return self.saved_cum_distribution[key] + + total = sum(weights) + result = [] + cumsum = 0 + for w in weights: + cumsum += w + result.append(cumsum / total) + self.saved_cum_distribution[key] = result + return result + + def _random_sample(self, key, values, weights): + """ given values and weights, this function randomly sample values based their weights + """ + # TODO(mingzhe09088): cache the results to avoid recalculation overhead + assert len(values) == len(weights) + _distribution_func_vals = self._distribution_func(key, weights) + x = random.random() + idx = bisect.bisect(_distribution_func_vals, x) + + assert idx <= len(values), "Wrong index value is returned" + # Due to numerical property, the last value in cumsum could be slightly + # smaller than 1, and lead to the (index == len(values)). + if idx == len(values): + idx -= 1 + return values[idx] + + def get_one_set_of_inputs(self): + tmp_attr_list = [] + for key, values in self.configs.items(): + if key in _reserved_keywords: + continue + value = self._random_sample(key, values, self.configs["probs"][str(key)]) + tmp_results = {key : value} + tmp_attr_list.append(tmp_results) + return (tmp_attr_list) + + +def random_sample_configs(**configs): + """ + This function randomly sample values from the given inputs based on + their weights. + Here is an example showing what are the expected inputs and outpus from this function: + M = [1, 2], + N = [4, 5], + K = [7, 8], + probs = attr_probs( + M = [0.7, 0.2], + N = [0.5, 0.2], + K = [0.6, 0.2], + ), + total_samples=10, + this function will generate + [ + [{'K': 7}, {'M': 1}, {'N': 4}], + [{'K': 7}, {'M': 2}, {'N': 5}], + [{'K': 8}, {'M': 2}, {'N': 4}], + ... + ] + Note: + The probs is optional. Without them, it implies everything is 1. The probs doesn't + have to reflect the actual normalized probability, the implementation will + normalize it. + TODO (mingzhe09088): + (1): a lambda that accepts or rejects a config as a sample. For example: for matmul + with M, N, and K, this function could get rid of (M * N * K > 1e8) to filter out + very slow benchmarks. + (2): Make sure each sample is unique. If the number of samples are larger than the + total combinations, just return the cross product. Otherwise, if the number of samples + is close to the number of cross-products, it is numerical safer to generate the list + that you don't want, and remove them. + """ + if "probs" not in configs: + raise ValueError("probs is missing. Consider adding probs or" + "using other config functions") + + configs_attrs_list = [] + randomsample = RandomSample(configs) + for i in range(configs["total_samples"]): + tmp_attr_list = randomsample.get_one_set_of_inputs() + tmp_attr_list.append({"tags" : '_'.join(configs["tags"])}) + configs_attrs_list.append(tmp_attr_list) + return configs_attrs_list + + def op_list(**configs): """Generate a list of ops organized in a specific format. It takes two parameters which are "attr_names" and "attr". diff --git a/benchmarks/operator_benchmark/common/tests/pt_configs_list_test.py b/benchmarks/operator_benchmark/common/tests/pt_configs_list_test.py new file mode 100644 index 00000000000..b48a4b95091 --- /dev/null +++ b/benchmarks/operator_benchmark/common/tests/pt_configs_list_test.py @@ -0,0 +1,40 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import operator_benchmark as op_bench +import torch + +"""Microbenchmarks for element-wise Add operator. Supports both Caffe2/PyTorch.""" + +add_short_configs = op_bench.config_list( + attr_names=['M', 'N', 'K'], + attrs=[ + [8, 16, 32], + [16, 16, 64], + [64, 64, 128], + ], + cross_product_configs={ + 'device': ['cpu', 'cuda'], + 'dtype': [torch.float, torch.float64], + }, + tags=['short'], +) + + +class AddBenchmark(op_bench.TorchBenchmarkBase): + def init(self, M, N, K, device, dtype): + self.input_one = torch.rand(M, N, K, device=device, dtype=dtype, requires_grad=True) + self.input_two = torch.rand(M, N, K, device=device, dtype=dtype) + self.set_module_name('add') + + def forward(self): + return torch.add(self.input_one, self.input_two) + + +op_bench.generate_pt_test(add_short_configs, AddBenchmark) + + +if __name__ == "__main__": + op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/common/tests/random_sample_test.py b/benchmarks/operator_benchmark/common/tests/random_sample_test.py new file mode 100644 index 00000000000..d25ea2b9247 --- /dev/null +++ b/benchmarks/operator_benchmark/common/tests/random_sample_test.py @@ -0,0 +1,36 @@ +from __future__ import absolute_import, division, print_function, unicode_literals +import operator_benchmark as op_bench +import torch + + +configs = op_bench.random_sample_configs( + M=[1, 2, 3, 4, 5, 6], + N=[7, 8, 9, 10, 11, 12], + K=[13, 14, 15, 16, 17, 18], + # probs saves the weights of each value + probs=op_bench.attr_probs( + M=[0.5, 0.2, 0.1, 0.05, 0.03, 0.1], + N=[0.1, 0.3, 0.4, 0.02, 0.03, 0.04], + K=[0.03, 0.6, 0.04, 0.02, 0.03, 0.01], + ), + # this is the number of returned inputs + total_samples=10, + tags=["short"], +) + + +class AddBenchmark(op_bench.TorchBenchmarkBase): + def init(self, M, N, K): + self.input_one = torch.rand(M, N, K) + self.input_two = torch.rand(M, N, K) + self.set_module_name("add") + + def forward(self): + return torch.add(self.input_one, self.input_two) + + +op_bench.generate_pt_test(configs, AddBenchmark) + + +if __name__ == "__main__": + op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/pt/add_test.py b/benchmarks/operator_benchmark/pt/add_test.py index 3be7fb27749..1e4c5222cb9 100644 --- a/benchmarks/operator_benchmark/pt/add_test.py +++ b/benchmarks/operator_benchmark/pt/add_test.py @@ -11,33 +11,49 @@ import torch # Configs for PT add operator add_long_configs = op_bench.cross_product_configs( M=[8, 64, 128], - N=range(2, 10, 3), - K=[2 ** x for x in range(0, 3)], + N=range(2, 128, 64), + K=[8 ** x for x in range(0, 3)], + device=['cpu', 'cuda'], tags=["long"] ) add_short_configs = op_bench.config_list( + attr_names=["M", "N", "K"], attrs=[ [64, 64, 64], [64, 64, 128], ], - attr_names=["M", "N", "K"], - tags=["short"], + cross_product_configs={ + 'device': ['cpu', 'cuda'], + }, + tags=["short"], ) class AddBenchmark(op_bench.TorchBenchmarkBase): - def init(self, M, N, K): - self.input_one = torch.rand(M, N, K) - self.input_two = torch.rand(M, N, K) + def init(self, M, N, K, device): + self.input_one = torch.rand(M, N, K, device=device, requires_grad=self.auto_set()) + self.input_two = torch.rand(M, N, K, device=device, requires_grad=self.auto_set()) self.set_module_name("add") def forward(self): return torch.add(self.input_one, self.input_two) +# The generated test names based on add_short_configs will be in the following pattern: +# add_M8_N16_K32_devicecpu +# add_M8_N16_K32_devicecuda +# add_M8_N16_K32_devicecpu_bwdall +# add_M8_N16_K32_devicecpu_bwd1 +# add_M8_N16_K32_devicecpu_bwd2 +# add_M8_N16_K32_devicecuda_bwdall +# add_M8_N16_K32_devicecuda_bwd1 +# add_M8_N16_K32_devicecuda_bwd2 +# ... +# Those names can be used to filter tests. op_bench.generate_pt_test(add_long_configs + add_short_configs, AddBenchmark) +op_bench.generate_pt_gradient_test(add_long_configs + add_short_configs, AddBenchmark) if __name__ == "__main__": diff --git a/benchmarks/operator_benchmark/pt/as_strided_test.py b/benchmarks/operator_benchmark/pt/as_strided_test.py new file mode 100644 index 00000000000..12bdd7a973c --- /dev/null +++ b/benchmarks/operator_benchmark/pt/as_strided_test.py @@ -0,0 +1,41 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import operator_benchmark as op_bench +import torch + + +"""Microbenchmarks for as_strided operator""" + + +# Configs for PT as_strided operator +split_short_configs = op_bench.cross_product_configs( + M=[256, 512], + N=[256, 512], + size=[(32, 32), (64, 64)], + stride=[(1, 1), (2, 2)], + storage_offset=[0, 1], + tags=['short'] +) + + +class As_stridedBenchmark(op_bench.TorchBenchmarkBase): + def init(self, M, N, size, stride, storage_offset): + self.input_one = torch.rand(M, N) + self.size = size + self.stride = stride + self.storage_offset = storage_offset + self.set_module_name('as_strided') + + def forward(self): + return torch.as_strided( + self.input_one, self.size, self.stride, self.storage_offset) + + +op_bench.generate_pt_test(split_short_configs, As_stridedBenchmark) + + +if __name__ == "__main__": + op_bench.benchmark_runner.main() diff --git a/binaries/bench_gen/bench_gen.py b/binaries/bench_gen/bench_gen.py old mode 100644 new mode 100755 diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index 746cdef0e71..128ce4fe872 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -365,10 +365,6 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) { if (a == ud || b == ud) { return ScalarType::Undefined; } - if (isComplexType(a) || isComplexType(b)) { - AT_ERROR( - "promoteTypes with complex numbers is not handled yet; figure out what the correct rules should be for ", toString(a), " and ", toString(b)); - } // For QInt types, we only allow exact match if (isQIntType(a) && a == b) { diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index ac90791ccc4..69cd7147ab7 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -224,12 +224,4 @@ at::DataPtr PlacementDeleteContext::makeDataPtr( AutogradMetaInterface::~AutogradMetaInterface() {} -bool NonVariableTypeMode::is_enabled() { - return !impl::tls_variable_is_enabled(); -} - -void NonVariableTypeMode::set_enabled(bool enabled) { - impl::tls_variable_set_enabled(!enabled); -} - } // namespace c10 diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 4e8dfe8bfc7..393266aeac0 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -138,11 +139,6 @@ struct C10_API AutogradMetaInterface { virtual ~AutogradMetaInterface(); }; -struct C10_API NonVariableTypeMode { - static bool is_enabled(); - static void set_enabled(bool enabled); -}; - struct C10_API NamedTensorMetaInterface { virtual ~NamedTensorMetaInterface() {}; virtual std::unique_ptr clone() const { @@ -808,7 +804,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * True if a tensor is a variable. See Note [Tensor versus Variable in C++] */ bool is_variable() const { - return autograd_meta_ != nullptr && !at::NonVariableTypeMode::is_enabled(); + return autograd_meta_ != nullptr && !impl::tls_local_tensor_type_set().excluded_.has(TensorTypeId::VariableTensorId); } /** diff --git a/c10/core/TensorTypeId.cpp b/c10/core/TensorTypeId.cpp index 1dd4476ae6c..9f63c3829ba 100644 --- a/c10/core/TensorTypeId.cpp +++ b/c10/core/TensorTypeId.cpp @@ -40,6 +40,10 @@ const char* toString(TensorTypeId t) { return "ComplexCUDATensorId"; case TensorTypeId::VariableTensorId: return "VariableTensorId"; + case TensorTypeId::TESTING_ONLY_GenericModeTensorId: + return "TESTING_ONLY_GenericModeTensorId"; + case TensorTypeId::TESTING_ONLY_GenericWrapperTensorId: + return "TESTING_ONLY_GenericWrapperTensorId"; default: return "UNKNOWN_TENSOR_TYPE_ID"; } diff --git a/c10/core/TensorTypeId.h b/c10/core/TensorTypeId.h index d01ee9d5f3e..e1bcbc2d26c 100644 --- a/c10/core/TensorTypeId.h +++ b/c10/core/TensorTypeId.h @@ -49,6 +49,19 @@ enum class TensorTypeId : uint8_t { VariableTensorId, + // TESTING: This is intended to be a generic testing tensor type id. + // Don't use it for anything real; its only acceptible use is within a single + // process test. Use it by creating a TensorImpl with this TensorTypeId, and + // then registering operators to operate on this type id. + TESTING_ONLY_GenericWrapperTensorId, + + // TESTING: This is intended to be a generic testing tensor type id. + // Don't use it for anything real; its only acceptible use is within a ingle + // process test. Use it by toggling the mode on and off via + // TESTING_ONLY_tls_generic_mode_set_enabled and then registering operators + // to operate on this type id. + TESTING_ONLY_GenericModeTensorId, + NumTensorIds, // Sentinel }; diff --git a/c10/core/impl/LocalTensorTypeSet.cpp b/c10/core/impl/LocalTensorTypeSet.cpp index 9f4de688009..dfda4e0ce75 100644 --- a/c10/core/impl/LocalTensorTypeSet.cpp +++ b/c10/core/impl/LocalTensorTypeSet.cpp @@ -8,34 +8,55 @@ namespace impl { namespace { /// In the CAFFE2_FB_LIMITED_MOBILE_CAPABILITY build setting, -/// thread_local is not supported. In that case, we don't provide -/// `at::NonVariableTypeMode`. +/// thread_local is not supported. #ifndef CAFFE2_FB_LIMITED_MOBILE_CAPABILITY -// NB: Zero initialized! -thread_local uint64_t raw_excluded; +// NB: POD, zero initialized! +thread_local PODLocalTensorTypeSet raw_local_tensor_type_set; #else // defined(CAFFE2_FB_LIMITED_MOBILE_CAPABILITY) -uint64_t raw_excluded = 0; +static PODLocalTensorTypeSet raw_local_tensor_type_set; #endif +} // anonymous namespace + +LocalTensorTypeSet tls_local_tensor_type_set() { + return raw_local_tensor_type_set; } -TensorTypeSet tls_excluded_tensor_type_set() { - return TensorTypeSet(TensorTypeSet::RAW, raw_excluded); +// We could have also just snapshotted the entire state. I'm not sure which is +// better; but right now only the guard API is allowed so the two cases are +// not distinguishable. + +IncludeTensorTypeIdGuard::IncludeTensorTypeIdGuard(TensorTypeId x) + : tls_(&raw_local_tensor_type_set) + , id_(x) + , prev_state_(tls_->included().has(x)) { + if (!prev_state_) { + tls_->set_included(tls_->included().add(x)); + } } -bool tls_variable_is_enabled() { - return !tls_excluded_tensor_type_set().has(TensorTypeId::VariableTensorId); +IncludeTensorTypeIdGuard::~IncludeTensorTypeIdGuard() { + if (!prev_state_) { + tls_->set_included(tls_->included().remove(id_)); + } } -void tls_variable_set_enabled(bool enabled) { - if (enabled) { - raw_excluded = tls_excluded_tensor_type_set().remove(TensorTypeId::VariableTensorId).raw_repr(); - } else { - raw_excluded = tls_excluded_tensor_type_set().add(TensorTypeId::VariableTensorId).raw_repr(); +ExcludeTensorTypeIdGuard::ExcludeTensorTypeIdGuard(TensorTypeId x) + : tls_(&raw_local_tensor_type_set) + , id_(x) + , prev_state_(tls_->excluded().has(x)) { + if (!prev_state_) { + tls_->set_excluded(tls_->excluded().add(x)); + } +} + +ExcludeTensorTypeIdGuard::~ExcludeTensorTypeIdGuard() { + if (!prev_state_) { + tls_->set_excluded(tls_->excluded().remove(id_)); } } diff --git a/c10/core/impl/LocalTensorTypeSet.h b/c10/core/impl/LocalTensorTypeSet.h index b049dbaa868..1517c5beb9e 100644 --- a/c10/core/impl/LocalTensorTypeSet.h +++ b/c10/core/impl/LocalTensorTypeSet.h @@ -1,22 +1,80 @@ +#pragma once + #include -// TLS management for TensorTypeSet +// TLS management for TensorTypeSet (the "local" TensorTypeSet(s)) // -// This manages thread-local TensorTypeSet of excluded keys which disqualify -// tensor types from dispatch. Keys which are in this set, even if they appear -// in a list of potential valid keys on a tensor, are not considered for -// dispatch. This is used to, for example, turn off autograd after we have -// handled autograd for a top-level element. +// This manages two thread-local TensorTypeSets: // -// Originally, I implemented this as storing the inverted set, but -// TLS is defined to be zero-initialized, so this doesn't actually work -// (you want the set to be -1 initialized). +// - The included type set, which adds a tensor type for consideration +// in dispatch. (For example, you might add ProfilingTensorId to +// the included type set to turn on profiling on all tensor operations.) +// +// - The excluded type set, which disqualifies a tensor type from dispatch. +// (For example, after redispatching on variable, we disqualify +// VariableTensorId so we don't attempt to handle variable again.) +// (Exclusion wins over inclusion.) +// +// NB: Originally, I implemented the excluded type set as storing the inverted +// set, but TLS is defined to be zero-initialized, so this doesn't actually work +// (if it's inverted, you want the set to be -1 initialized). namespace c10 { namespace impl { -C10_API bool tls_variable_is_enabled(); -C10_API void tls_variable_set_enabled(bool enabled); -C10_API TensorTypeSet tls_excluded_tensor_type_set(); +// POD version of LocalTensorTypeSet. Declared here just so that +// we can put it in the guards. +struct C10_API PODLocalTensorTypeSet { + uint64_t included_; + uint64_t excluded_; + + TensorTypeSet included() const { + return TensorTypeSet(TensorTypeSet::RAW, included_); + } + TensorTypeSet excluded() const { + return TensorTypeSet(TensorTypeSet::RAW, excluded_); + } + + void set_included(TensorTypeSet x) { + included_ = x.raw_repr(); + } + void set_excluded(TensorTypeSet x) { + excluded_ = x.raw_repr(); + } +}; +static_assert(std::is_pod::value, "PODLocalTensorTypeSet must be a POD type."); + +struct C10_API LocalTensorTypeSet { + /* implicit */ LocalTensorTypeSet(PODLocalTensorTypeSet x) + : included_(x.included()), excluded_(x.excluded()) {} + TensorTypeSet included_; + TensorTypeSet excluded_; +}; + +C10_API LocalTensorTypeSet tls_local_tensor_type_set(); + +class C10_API IncludeTensorTypeIdGuard { +public: + IncludeTensorTypeIdGuard(TensorTypeId); + ~IncludeTensorTypeIdGuard(); +private: + // A little micro-optimization to save us from tls_get_addr call + // on destruction + PODLocalTensorTypeSet* tls_; + TensorTypeId id_; + bool prev_state_; +}; + +class C10_API ExcludeTensorTypeIdGuard { +public: + ExcludeTensorTypeIdGuard(TensorTypeId); + ~ExcludeTensorTypeIdGuard(); +private: + // A little micro-optimization to save us from tls_get_addr call + // on destruction + PODLocalTensorTypeSet* tls_; + TensorTypeId id_; + bool prev_state_; +}; }} // namespace c10::impl diff --git a/c10/util/Complex.h b/c10/util/Complex.h new file mode 100644 index 00000000000..6813529454d --- /dev/null +++ b/c10/util/Complex.h @@ -0,0 +1,17 @@ +#pragma once + +#include + + +namespace std { + +template struct is_complex_t : public std::false_type {}; +template struct is_complex_t> : public std::true_type {}; + +template <> +class numeric_limits> : public numeric_limits {}; + +template <> +class numeric_limits> : public numeric_limits {}; + +} // namespace std diff --git a/c10/util/llvmMathExtras.h b/c10/util/llvmMathExtras.h index 76ae3b26a29..4cda3cc49a0 100644 --- a/c10/util/llvmMathExtras.h +++ b/c10/util/llvmMathExtras.h @@ -388,13 +388,23 @@ return UINT64_MAX >> (64 - N); } + // Ignore the false warning "Arithmetic overflow" for MSVC + #ifdef _MSC_VER + # pragma warning(push) + # pragma warning(disable : 4146) + #endif + /// Gets the minimum value for a N-bit signed integer. inline int64_t minIntN(int64_t N) { assert(N > 0 && N <= 64 && "integer width out of range"); - return -(UINT64_C(1)<<(N-1)); + return -(UINT64_C(1) << (N - 1)); } + #ifdef _MSC_VER + # pragma warning(pop) + #endif + /// Gets the maximum value for a N-bit signed integer. inline int64_t maxIntN(int64_t N) { assert(N > 0 && N <= 64 && "integer width out of range"); diff --git a/c10/util/variant.h b/c10/util/variant.h new file mode 100644 index 00000000000..e7daff06b14 --- /dev/null +++ b/c10/util/variant.h @@ -0,0 +1,2854 @@ +// MPark.Variant +// +// Copyright Michael Park, 2015-2017 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE.md or copy at http://boost.org/LICENSE_1_0.txt) +// +// From https://github.com/mpark/variant +// +// C10 +// - Move to `c10` namespace. +// - Rename namespace `detail` to `detail_`, to not conflict with existing +// c10 implementations in `detail` namespace. +// - `struct in_place_t` is renamed to `struct variant_in_place_t`, to not +// conflict with `struct in_place_t` in c10/util/Optional.h. +// - In two functions, the template name reference `I` is changed to +// `detail_::best_match::value` to work around gcc 7.3.1 bug. +// However, this workaround also limits the use cases of `c10::variant`. +// Please see NOTE [gcc 7.3.1 bug workaround] for details. + +#ifndef C10_UTIL_VARIANT_H_ +#define C10_UTIL_VARIANT_H_ + +/* + variant synopsis + +namespace std { + + // 20.7.2, class template variant + template + class variant { + public: + + // 20.7.2.1, constructors + constexpr variant() noexcept(see below); + variant(const variant&); + variant(variant&&) noexcept(see below); + + template constexpr variant(T&&) noexcept(see below); + + template + constexpr explicit variant(in_place_type_t, Args&&...); + + template + constexpr explicit variant( + in_place_type_t, initializer_list, Args&&...); + + template + constexpr explicit variant(in_place_index_t, Args&&...); + + template + constexpr explicit variant( + in_place_index_t, initializer_list, Args&&...); + + // 20.7.2.2, destructor + ~variant(); + + // 20.7.2.3, assignment + variant& operator=(const variant&); + variant& operator=(variant&&) noexcept(see below); + + template variant& operator=(T&&) noexcept(see below); + + // 20.7.2.4, modifiers + template + T& emplace(Args&&...); + + template + T& emplace(initializer_list, Args&&...); + + template + variant_alternative& emplace(Args&&...); + + template + variant_alternative& emplace(initializer_list, Args&&...); + + // 20.7.2.5, value status + constexpr bool valueless_by_exception() const noexcept; + constexpr size_t index() const noexcept; + + // 20.7.2.6, swap + void swap(variant&) noexcept(see below); + }; + + // 20.7.3, variant helper classes + template struct variant_size; // undefined + + template + constexpr size_t variant_size_v = variant_size::value; + + template struct variant_size; + template struct variant_size; + template struct variant_size; + + template + struct variant_size>; + + template struct variant_alternative; // undefined + + template + using variant_alternative_t = typename variant_alternative::type; + + template struct variant_alternative; + template struct variant_alternative; + template struct variant_alternative; + + template + struct variant_alternative>; + + constexpr size_t variant_npos = -1; + + // 20.7.4, value access + template + constexpr bool holds_alternative(const variant&) noexcept; + + template + constexpr variant_alternative_t>& + get(variant&); + + template + constexpr variant_alternative_t>&& + get(variant&&); + + template + constexpr variant_alternative_t> const& + get(const variant&); + + template + constexpr variant_alternative_t> const&& + get(const variant&&); + + template + constexpr T& get(variant&); + + template + constexpr T&& get(variant&&); + + template + constexpr const T& get(const variant&); + + template + constexpr const T&& get(const variant&&); + + template + constexpr add_pointer_t>> + get_if(variant*) noexcept; + + template + constexpr add_pointer_t>> + get_if(const variant*) noexcept; + + template + constexpr add_pointer_t + get_if(variant*) noexcept; + + template + constexpr add_pointer_t + get_if(const variant*) noexcept; + + // 20.7.5, relational operators + template + constexpr bool operator==(const variant&, const variant&); + + template + constexpr bool operator!=(const variant&, const variant&); + + template + constexpr bool operator<(const variant&, const variant&); + + template + constexpr bool operator>(const variant&, const variant&); + + template + constexpr bool operator<=(const variant&, const variant&); + + template + constexpr bool operator>=(const variant&, const variant&); + + // 20.7.6, visitation + template + constexpr see below visit(Visitor&&, Variants&&...); + + // 20.7.7, class monostate + struct monostate; + + // 20.7.8, monostate relational operators + constexpr bool operator<(monostate, monostate) noexcept; + constexpr bool operator>(monostate, monostate) noexcept; + constexpr bool operator<=(monostate, monostate) noexcept; + constexpr bool operator>=(monostate, monostate) noexcept; + constexpr bool operator==(monostate, monostate) noexcept; + constexpr bool operator!=(monostate, monostate) noexcept; + + // 20.7.9, specialized algorithms + template + void swap(variant&, variant&) noexcept(see below); + + // 20.7.10, class bad_variant_access + class bad_variant_access; + + // 20.7.11, hash support + template struct hash; + template struct hash>; + template <> struct hash; + +} // namespace std + +*/ + +#include +#include +#include +#include +#include +#include +#include + +// MPark.Variant +// +// Copyright Michael Park, 2015-2017 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE.md or copy at http://boost.org/LICENSE_1_0.txt) + +#ifndef MPARK_CONFIG_HPP +#define MPARK_CONFIG_HPP + +// MSVC 2015 Update 3. +#if __cplusplus < 201103L && (!defined(_MSC_VER) || _MSC_FULL_VER < 190024210) +#error "MPark.Variant requires C++11 support." +#endif + +#ifndef __has_attribute +#define __has_attribute(x) 0 +#endif + +#ifndef __has_builtin +#define __has_builtin(x) 0 +#endif + +#ifndef __has_include +#define __has_include(x) 0 +#endif + +#ifndef __has_feature +#define __has_feature(x) 0 +#endif + +#if __has_attribute(always_inline) || defined(__GNUC__) +#define MPARK_ALWAYS_INLINE __attribute__((__always_inline__)) inline +#elif defined(_MSC_VER) +#define MPARK_ALWAYS_INLINE __forceinline +#else +#define MPARK_ALWAYS_INLINE inline +#endif + +#if __has_builtin(__builtin_addressof) || \ + (defined(__GNUC__) && __GNUC__ >= 7) || defined(_MSC_VER) +#define MPARK_BUILTIN_ADDRESSOF +#endif + +#if __has_builtin(__builtin_unreachable) || defined(__GNUC__) +#define MPARK_BUILTIN_UNREACHABLE __builtin_unreachable() +#elif defined(_MSC_VER) +#define MPARK_BUILTIN_UNREACHABLE __assume(false) +#else +#define MPARK_BUILTIN_UNREACHABLE +#endif + +#if __has_builtin(__type_pack_element) +#define MPARK_TYPE_PACK_ELEMENT +#endif + +#if defined(__cpp_constexpr) && __cpp_constexpr >= 200704 && \ + !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 9) +#define MPARK_CPP11_CONSTEXPR +#endif + +#if defined(__cpp_constexpr) && __cpp_constexpr >= 201304 +#define MPARK_CPP14_CONSTEXPR +#endif + +#if __has_feature(cxx_exceptions) || defined(__cpp_exceptions) || \ + (defined(_MSC_VER) && defined(_CPPUNWIND)) +#define MPARK_EXCEPTIONS +#endif + +#if defined(__cpp_generic_lambdas) || defined(_MSC_VER) +#define MPARK_GENERIC_LAMBDAS +#endif + +#if defined(__cpp_lib_integer_sequence) +#define MPARK_INTEGER_SEQUENCE +#endif + +#if defined(__cpp_return_type_deduction) || defined(_MSC_VER) +#define MPARK_RETURN_TYPE_DEDUCTION +#endif + +#if defined(__cpp_lib_transparent_operators) || defined(_MSC_VER) +#define MPARK_TRANSPARENT_OPERATORS +#endif + +#if defined(__cpp_variable_templates) || defined(_MSC_VER) +#define MPARK_VARIABLE_TEMPLATES +#endif + +#if !defined(__GLIBCXX__) || __has_include() // >= libstdc++-5 +#define MPARK_TRIVIALITY_TYPE_TRAITS +#define MPARK_INCOMPLETE_TYPE_TRAITS +#endif + +#endif // MPARK_CONFIG_HPP + +// MPark.Variant +// +// Copyright Michael Park, 2015-2017 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE.md or copy at http://boost.org/LICENSE_1_0.txt) + +#ifndef MPARK_IN_PLACE_HPP +#define MPARK_IN_PLACE_HPP + +#include + + +namespace c10 { + + struct variant_in_place_t { explicit variant_in_place_t() = default; }; + + template + struct in_place_index_t { explicit in_place_index_t() = default; }; + + template + struct in_place_type_t { explicit in_place_type_t() = default; }; + +#ifdef MPARK_VARIABLE_TEMPLATES + constexpr variant_in_place_t in_place{}; + + template constexpr in_place_index_t in_place_index{}; + + template constexpr in_place_type_t in_place_type{}; +#endif + +} // namespace c10 + +#endif // MPARK_IN_PLACE_HPP + +// MPark.Variant +// +// Copyright Michael Park, 2015-2017 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE.md or copy at http://boost.org/LICENSE_1_0.txt) + +#ifndef MPARK_LIB_HPP +#define MPARK_LIB_HPP + +#include +#include +#include +#include + + +#define MPARK_RETURN(...) \ + noexcept(noexcept(__VA_ARGS__)) -> decltype(__VA_ARGS__) { return __VA_ARGS__; } + +namespace c10 { + namespace lib { + template + struct identity { using type = T; }; + + inline namespace cpp14 { + template + struct array { + constexpr const T &operator[](std::size_t index) const { + return data[index]; + } + + T data[N == 0 ? 1 : N]; + }; + + template + using add_pointer_t = typename std::add_pointer::type; + + template + using common_type_t = typename std::common_type::type; + + template + using decay_t = typename std::decay::type; + + template + using enable_if_t = typename std::enable_if::type; + + template + using remove_const_t = typename std::remove_const::type; + + template + using remove_reference_t = typename std::remove_reference::type; + + template + inline constexpr T &&forward(remove_reference_t &t) noexcept { + return static_cast(t); + } + + template + inline constexpr T &&forward(remove_reference_t &&t) noexcept { + static_assert(!std::is_lvalue_reference::value, + "can not forward an rvalue as an lvalue"); + return static_cast(t); + } + + template + inline constexpr remove_reference_t &&move(T &&t) noexcept { + return static_cast &&>(t); + } + +#ifdef MPARK_INTEGER_SEQUENCE + using std::integer_sequence; + using std::index_sequence; + using std::make_index_sequence; + using std::index_sequence_for; +#else + template + struct integer_sequence { + using value_type = T; + static constexpr std::size_t size() noexcept { return sizeof...(Is); } + }; + + template + using index_sequence = integer_sequence; + + template + struct make_index_sequence_concat; + + template + struct make_index_sequence_concat, + index_sequence> + : identity> {}; + + template + struct make_index_sequence_impl; + + template + using make_index_sequence = typename make_index_sequence_impl::type; + + template + struct make_index_sequence_impl + : make_index_sequence_concat, + make_index_sequence> {}; + + template <> + struct make_index_sequence_impl<0> : identity> {}; + + template <> + struct make_index_sequence_impl<1> : identity> {}; + + template + using index_sequence_for = make_index_sequence; +#endif + + // +#ifdef MPARK_TRANSPARENT_OPERATORS + using equal_to = std::equal_to<>; +#else + struct equal_to { + template + inline constexpr auto operator()(Lhs &&lhs, Rhs &&rhs) const + MPARK_RETURN(lib::forward(lhs) == lib::forward(rhs)) + }; +#endif + +#ifdef MPARK_TRANSPARENT_OPERATORS + using not_equal_to = std::not_equal_to<>; +#else + struct not_equal_to { + template + inline constexpr auto operator()(Lhs &&lhs, Rhs &&rhs) const + MPARK_RETURN(lib::forward(lhs) != lib::forward(rhs)) + }; +#endif + +#ifdef MPARK_TRANSPARENT_OPERATORS + using less = std::less<>; +#else + struct less { + template + inline constexpr auto operator()(Lhs &&lhs, Rhs &&rhs) const + MPARK_RETURN(lib::forward(lhs) < lib::forward(rhs)) + }; +#endif + +#ifdef MPARK_TRANSPARENT_OPERATORS + using greater = std::greater<>; +#else + struct greater { + template + inline constexpr auto operator()(Lhs &&lhs, Rhs &&rhs) const + MPARK_RETURN(lib::forward(lhs) > lib::forward(rhs)) + }; +#endif + +#ifdef MPARK_TRANSPARENT_OPERATORS + using less_equal = std::less_equal<>; +#else + struct less_equal { + template + inline constexpr auto operator()(Lhs &&lhs, Rhs &&rhs) const + MPARK_RETURN(lib::forward(lhs) <= lib::forward(rhs)) + }; +#endif + +#ifdef MPARK_TRANSPARENT_OPERATORS + using greater_equal = std::greater_equal<>; +#else + struct greater_equal { + template + inline constexpr auto operator()(Lhs &&lhs, Rhs &&rhs) const + MPARK_RETURN(lib::forward(lhs) >= lib::forward(rhs)) + }; +#endif + } // namespace cpp14 + + inline namespace cpp17 { + + // + template + using bool_constant = std::integral_constant; + + template + struct voider : identity {}; + + template + using void_t = typename voider::type; + + namespace detail_ { + namespace swappable { + + using std::swap; + + template + struct is_swappable { + private: + template (), + std::declval()))> + inline static std::true_type test(int); + + template + inline static std::false_type test(...); + + public: + static constexpr bool value = decltype(test(0))::value; + }; + + template + struct is_nothrow_swappable { + static constexpr bool value = + noexcept(swap(std::declval(), std::declval())); + }; + + template + struct is_nothrow_swappable : std::false_type {}; + + } // namespace swappable + } // namespace detail_ + + using detail_::swappable::is_swappable; + + template + using is_nothrow_swappable = + detail_::swappable::is_nothrow_swappable::value, T>; + + // + namespace detail_ { + + template + struct is_reference_wrapper : std::false_type {}; + + template + struct is_reference_wrapper> + : std::true_type {}; + + template + struct Invoke; + + template <> + struct Invoke { + template + inline static constexpr auto invoke(R T::*pmf, Arg &&arg, Args &&... args) + MPARK_RETURN((lib::forward(arg).*pmf)(lib::forward(args)...)) + }; + + template <> + struct Invoke { + template + inline static constexpr auto invoke(R T::*pmf, Arg &&arg, Args &&... args) + MPARK_RETURN((lib::forward(arg).get().*pmf)(lib::forward(args)...)) + }; + + template <> + struct Invoke { + template + inline static constexpr auto invoke(R T::*pmf, Arg &&arg, Args &&... args) + MPARK_RETURN(((*lib::forward(arg)).*pmf)(lib::forward(args)...)) + }; + + template <> + struct Invoke { + template + inline static constexpr auto invoke(R T::*pmo, Arg &&arg) + MPARK_RETURN(lib::forward(arg).*pmo) + }; + + template <> + struct Invoke { + template + inline static constexpr auto invoke(R T::*pmo, Arg &&arg) + MPARK_RETURN(lib::forward(arg).get().*pmo) + }; + + template <> + struct Invoke { + template + inline static constexpr auto invoke(R T::*pmo, Arg &&arg) + MPARK_RETURN((*lib::forward(arg)).*pmo) + }; + + template + inline constexpr auto invoke(R T::*f, Arg &&arg, Args &&... args) + MPARK_RETURN( + Invoke::value, + (std::is_base_of>::value + ? 0 + : is_reference_wrapper>::value + ? 1 + : 2)>::invoke(f, + lib::forward(arg), + lib::forward(args)...)) + +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4100) +#endif + template + inline constexpr auto invoke(F &&f, Args &&... args) + MPARK_RETURN(lib::forward(f)(lib::forward(args)...)) +#ifdef _MSC_VER +#pragma warning(pop) +#endif + } // namespace detail_ + + template + inline constexpr auto invoke(F &&f, Args &&... args) + MPARK_RETURN(detail_::invoke(lib::forward(f), + lib::forward(args)...)) + + namespace detail_ { + + template + struct invoke_result {}; + + template + struct invoke_result(), std::declval()...))>, + F, + Args...> + : identity(), std::declval()...))> {}; + + } // namespace detail_ + + template + using invoke_result = detail_::invoke_result; + + template + using invoke_result_t = typename invoke_result::type; + + namespace detail_ { + + template + struct is_invocable : std::false_type {}; + + template + struct is_invocable>, F, Args...> + : std::true_type {}; + + template + struct is_invocable_r : std::false_type {}; + + template + struct is_invocable_r>, + R, + F, + Args...> + : std::is_convertible, R> {}; + + } // namespace detail_ + + template + using is_invocable = detail_::is_invocable; + + template + using is_invocable_r = detail_::is_invocable_r; + + namespace detail_ { + + template + struct is_nothrow_invocable { + static constexpr bool value = + noexcept(lib::invoke(std::declval(), std::declval()...)); + }; + + template + struct is_nothrow_invocable : std::false_type {}; + + template + struct is_nothrow_invocable_r { + private: + inline static R impl() { + return lib::invoke(std::declval(), std::declval()...); + } + + public: + static constexpr bool value = noexcept(impl()); + }; + + template + struct is_nothrow_invocable_r : std::false_type {}; + + } // namespace detail_ + + template + using is_nothrow_invocable = detail_:: + is_nothrow_invocable::value, F, Args...>; + + template + using is_nothrow_invocable_r = + detail_::is_nothrow_invocable_r::value, + R, + F, + Args...>; + + // +#ifdef MPARK_BUILTIN_ADDRESSOF + template + inline constexpr T *addressof(T &arg) noexcept { + return __builtin_addressof(arg); + } +#else + namespace detail_ { + + namespace has_addressof_impl { + + struct fail; + + template + inline fail operator&(T &&); + + template + inline static constexpr bool impl() { + return (std::is_class::value || std::is_union::value) && + !std::is_same()), fail>::value; + } + + } // namespace has_addressof_impl + + template + using has_addressof = bool_constant()>; + + template + inline constexpr T *addressof(T &arg, std::true_type) noexcept { + return std::addressof(arg); + } + + template + inline constexpr T *addressof(T &arg, std::false_type) noexcept { + return &arg; + } + + } // namespace detail_ + + template + inline constexpr T *addressof(T &arg) noexcept { + return detail_::addressof(arg, detail_::has_addressof{}); + } +#endif + + template + inline constexpr T *addressof(const T &&) = delete; + + } // namespace cpp17 + + template + struct remove_all_extents : identity {}; + + template + struct remove_all_extents> : remove_all_extents {}; + + template + using remove_all_extents_t = typename remove_all_extents::type; + + template + using size_constant = std::integral_constant; + + template + struct indexed_type : size_constant { using type = T; }; + + template + using all = std::is_same, + integer_sequence>; + +#ifdef MPARK_TYPE_PACK_ELEMENT + template + using type_pack_element_t = __type_pack_element; +#else + template + struct type_pack_element_impl { + private: + template + struct set; + + template + struct set> : indexed_type... {}; + + template + inline static std::enable_if impl(indexed_type); + + inline static std::enable_if impl(...); + + public: + using type = decltype(impl(set>{})); + }; + + template + using type_pack_element = typename type_pack_element_impl::type; + + template + using type_pack_element_t = typename type_pack_element::type; +#endif + +#ifdef MPARK_TRIVIALITY_TYPE_TRAITS + using std::is_trivially_copy_constructible; + using std::is_trivially_move_constructible; + using std::is_trivially_copy_assignable; + using std::is_trivially_move_assignable; +#else + template + struct is_trivially_copy_constructible + : bool_constant< + std::is_copy_constructible::value && __has_trivial_copy(T)> {}; + + template + struct is_trivially_move_constructible : bool_constant<__is_trivial(T)> {}; + + template + struct is_trivially_copy_assignable + : bool_constant< + std::is_copy_assignable::value && __has_trivial_assign(T)> {}; + + template + struct is_trivially_move_assignable : bool_constant<__is_trivial(T)> {}; +#endif + + template + struct dependent_type : T {}; + + template + struct push_back; + + template + using push_back_t = typename push_back::type; + + template + struct push_back, J> { + using type = index_sequence; + }; + + } // namespace lib +} // namespace c10 + +#undef MPARK_RETURN + +#endif // MPARK_LIB_HPP + + +namespace c10 { + +#ifdef MPARK_RETURN_TYPE_DEDUCTION + +#define AUTO auto +#define AUTO_RETURN(...) { return __VA_ARGS__; } + +#define AUTO_REFREF auto && +#define AUTO_REFREF_RETURN(...) { return __VA_ARGS__; } + +#define DECLTYPE_AUTO decltype(auto) +#define DECLTYPE_AUTO_RETURN(...) { return __VA_ARGS__; } + +#else + +#define AUTO auto +#define AUTO_RETURN(...) \ + -> lib::decay_t { return __VA_ARGS__; } + +#define AUTO_REFREF auto +#define AUTO_REFREF_RETURN(...) \ + -> decltype((__VA_ARGS__)) { \ + static_assert(std::is_reference::value, ""); \ + return __VA_ARGS__; \ + } + +#define DECLTYPE_AUTO auto +#define DECLTYPE_AUTO_RETURN(...) \ + -> decltype(__VA_ARGS__) { return __VA_ARGS__; } + +#endif + + class bad_variant_access : public std::exception { + public: + virtual const char *what() const noexcept override { return "bad_variant_access"; } + }; + + [[noreturn]] inline void throw_bad_variant_access() { +#ifdef MPARK_EXCEPTIONS + throw bad_variant_access{}; +#else + std::terminate(); + MPARK_BUILTIN_UNREACHABLE; +#endif + } + + template + class variant; + + template + struct variant_size; + +#ifdef MPARK_VARIABLE_TEMPLATES + template + constexpr std::size_t variant_size_v = variant_size::value; +#endif + + template + struct variant_size : variant_size {}; + + template + struct variant_size : variant_size {}; + + template + struct variant_size : variant_size {}; + + template + struct variant_size> : lib::size_constant {}; + + template + struct variant_alternative; + + template + using variant_alternative_t = typename variant_alternative::type; + + template + struct variant_alternative + : std::add_const> {}; + + template + struct variant_alternative + : std::add_volatile> {}; + + template + struct variant_alternative + : std::add_cv> {}; + + template + struct variant_alternative> { + static_assert(I < sizeof...(Ts), + "index out of bounds in `std::variant_alternative<>`"); + using type = lib::type_pack_element_t; + }; + + constexpr std::size_t variant_npos = static_cast(-1); + + namespace detail_ { + + constexpr std::size_t not_found = static_cast(-1); + constexpr std::size_t ambiguous = static_cast(-2); + +#ifdef MPARK_CPP14_CONSTEXPR + template + inline constexpr std::size_t find_index() { + constexpr lib::array matches = { + {std::is_same::value...} + }; + std::size_t result = not_found; + for (std::size_t i = 0; i < sizeof...(Ts); ++i) { + if (matches[i]) { + if (result != not_found) { + return ambiguous; + } + result = i; + } + } + return result; + } +#else + inline constexpr std::size_t find_index_impl(std::size_t result, + std::size_t) { + return result; + } + + template + inline constexpr std::size_t find_index_impl(std::size_t result, + std::size_t idx, + bool b, + Bs... bs) { + return b ? (result != not_found ? ambiguous + : find_index_impl(idx, idx + 1, bs...)) + : find_index_impl(result, idx + 1, bs...); + } + + template + inline constexpr std::size_t find_index() { + return find_index_impl(not_found, 0, std::is_same::value...); + } +#endif + + template + using find_index_sfinae_impl = + lib::enable_if_t>; + + template + using find_index_sfinae = find_index_sfinae_impl()>; + + template + struct find_index_checked_impl : lib::size_constant { + static_assert(I != not_found, "the specified type is not found."); + static_assert(I != ambiguous, "the specified type is ambiguous."); + }; + + template + using find_index_checked = find_index_checked_impl()>; + + struct valueless_t {}; + + enum class Trait { TriviallyAvailable, Available, Unavailable }; + + template class IsTriviallyAvailable, + template class IsAvailable> + inline constexpr Trait trait() { + return IsTriviallyAvailable::value + ? Trait::TriviallyAvailable + : IsAvailable::value ? Trait::Available + : Trait::Unavailable; + } + +#ifdef MPARK_CPP14_CONSTEXPR + template + inline constexpr Trait common_trait(Traits... traits_) { + Trait result = Trait::TriviallyAvailable; + lib::array traits = {{traits_...}}; + for (std::size_t i = 0; i < sizeof...(Traits); ++i) { + Trait t = traits[i]; + if (static_cast(t) > static_cast(result)) { + result = t; + } + } + return result; + } +#else + inline constexpr Trait common_trait_impl(Trait result) { return result; } + + template + inline constexpr Trait common_trait_impl(Trait result, + Trait t, + Traits... ts) { + return static_cast(t) > static_cast(result) + ? common_trait_impl(t, ts...) + : common_trait_impl(result, ts...); + } + + template + inline constexpr Trait common_trait(Traits... ts) { + return common_trait_impl(Trait::TriviallyAvailable, ts...); + } +#endif + + template + struct traits { + static constexpr Trait copy_constructible_trait = + common_trait(trait()...); + + static constexpr Trait move_constructible_trait = + common_trait(trait()...); + + static constexpr Trait copy_assignable_trait = + common_trait(copy_constructible_trait, + trait()...); + + static constexpr Trait move_assignable_trait = + common_trait(move_constructible_trait, + trait()...); + + static constexpr Trait destructible_trait = + common_trait(trait()...); + }; + + namespace access { + + struct recursive_union { +#ifdef MPARK_RETURN_TYPE_DEDUCTION + template + inline static constexpr auto &&get_alt(V &&v, in_place_index_t<0>) { + return lib::forward(v).head_; + } + + template + inline static constexpr auto &&get_alt(V &&v, in_place_index_t) { + return get_alt(lib::forward(v).tail_, in_place_index_t{}); + } +#else + template + struct get_alt_impl { + template + inline constexpr AUTO_REFREF operator()(V &&v) const + AUTO_REFREF_RETURN(get_alt_impl{}(lib::forward(v).tail_)) + }; + + template + struct get_alt_impl<0, Dummy> { + template + inline constexpr AUTO_REFREF operator()(V &&v) const + AUTO_REFREF_RETURN(lib::forward(v).head_) + }; + + template + inline static constexpr AUTO_REFREF get_alt(V &&v, in_place_index_t) + AUTO_REFREF_RETURN(get_alt_impl{}(lib::forward(v))) +#endif + }; + + struct base { + template + inline static constexpr AUTO_REFREF get_alt(V &&v) +#ifdef _MSC_VER + AUTO_REFREF_RETURN(recursive_union::get_alt( + lib::forward(v).data_, in_place_index_t{})) +#else + AUTO_REFREF_RETURN(recursive_union::get_alt( + data(lib::forward(v)), in_place_index_t{})) +#endif + }; + + struct variant { + template + inline static constexpr AUTO_REFREF get_alt(V &&v) + AUTO_REFREF_RETURN(base::get_alt(lib::forward(v).impl_)) + }; + + } // namespace access + + namespace visitation { + +#if defined(MPARK_CPP14_CONSTEXPR) && !defined(_MSC_VER) +#define MPARK_VARIANT_SWITCH_VISIT +#endif + + struct base { + template + using dispatch_result_t = decltype( + lib::invoke(std::declval(), + access::base::get_alt<0>(std::declval())...)); + + template + struct expected { + template + inline static constexpr bool but_got() { + return std::is_same::value; + } + }; + + template + struct visit_return_type_check { + static_assert( + expected::template but_got(), + "`visit` requires the visitor to have a single return type"); + + template + inline static constexpr DECLTYPE_AUTO invoke(Visitor &&visitor, + Alts &&... alts) + DECLTYPE_AUTO_RETURN(lib::invoke(lib::forward(visitor), + lib::forward(alts)...)) + }; + +#ifdef MPARK_VARIANT_SWITCH_VISIT + template + struct dispatcher; + + template + struct dispatcher { + template + MPARK_ALWAYS_INLINE static constexpr R dispatch( + F &&, typename ITs::type &&..., Vs &&...) { + MPARK_BUILTIN_UNREACHABLE; + } + + template + MPARK_ALWAYS_INLINE static constexpr R dispatch_case(F &&, Vs &&...) { + MPARK_BUILTIN_UNREACHABLE; + } + + template + MPARK_ALWAYS_INLINE static constexpr R dispatch_at(std::size_t, + F &&, + Vs &&...) { + MPARK_BUILTIN_UNREACHABLE; + } + }; + + template + struct dispatcher { + template + MPARK_ALWAYS_INLINE static constexpr R dispatch( + F &&f, typename ITs::type &&... visited_vs) { + using Expected = R; + using Actual = decltype(lib::invoke( + lib::forward(f), + access::base::get_alt( + lib::forward(visited_vs))...)); + return visit_return_type_check::invoke( + lib::forward(f), + access::base::get_alt( + lib::forward(visited_vs))...); + } + + template + MPARK_ALWAYS_INLINE static constexpr R dispatch( + F &&f, typename ITs::type &&... visited_vs, V &&v, Vs &&... vs) { +#define MPARK_DISPATCH(I) \ + dispatcher<(I < lib::decay_t::size()), \ + R, \ + ITs..., \ + lib::indexed_type>:: \ + template dispatch<0>(lib::forward(f), \ + lib::forward(visited_vs)..., \ + lib::forward(v), \ + lib::forward(vs)...) + +#define MPARK_DEFAULT(I) \ + dispatcher<(I < lib::decay_t::size()), R, ITs...>::template dispatch( \ + lib::forward(f), \ + lib::forward(visited_vs)..., \ + lib::forward(v), \ + lib::forward(vs)...) + + switch (v.index()) { + case B + 0: return MPARK_DISPATCH(B + 0); + case B + 1: return MPARK_DISPATCH(B + 1); + case B + 2: return MPARK_DISPATCH(B + 2); + case B + 3: return MPARK_DISPATCH(B + 3); + case B + 4: return MPARK_DISPATCH(B + 4); + case B + 5: return MPARK_DISPATCH(B + 5); + case B + 6: return MPARK_DISPATCH(B + 6); + case B + 7: return MPARK_DISPATCH(B + 7); + case B + 8: return MPARK_DISPATCH(B + 8); + case B + 9: return MPARK_DISPATCH(B + 9); + case B + 10: return MPARK_DISPATCH(B + 10); + case B + 11: return MPARK_DISPATCH(B + 11); + case B + 12: return MPARK_DISPATCH(B + 12); + case B + 13: return MPARK_DISPATCH(B + 13); + case B + 14: return MPARK_DISPATCH(B + 14); + case B + 15: return MPARK_DISPATCH(B + 15); + case B + 16: return MPARK_DISPATCH(B + 16); + case B + 17: return MPARK_DISPATCH(B + 17); + case B + 18: return MPARK_DISPATCH(B + 18); + case B + 19: return MPARK_DISPATCH(B + 19); + case B + 20: return MPARK_DISPATCH(B + 20); + case B + 21: return MPARK_DISPATCH(B + 21); + case B + 22: return MPARK_DISPATCH(B + 22); + case B + 23: return MPARK_DISPATCH(B + 23); + case B + 24: return MPARK_DISPATCH(B + 24); + case B + 25: return MPARK_DISPATCH(B + 25); + case B + 26: return MPARK_DISPATCH(B + 26); + case B + 27: return MPARK_DISPATCH(B + 27); + case B + 28: return MPARK_DISPATCH(B + 28); + case B + 29: return MPARK_DISPATCH(B + 29); + case B + 30: return MPARK_DISPATCH(B + 30); + case B + 31: return MPARK_DISPATCH(B + 31); + default: return MPARK_DEFAULT(B + 32); + } + +#undef MPARK_DEFAULT +#undef MPARK_DISPATCH + } + + template + MPARK_ALWAYS_INLINE static constexpr R dispatch_case(F &&f, + Vs &&... vs) { + using Expected = R; + using Actual = decltype( + lib::invoke(lib::forward(f), + access::base::get_alt(lib::forward(vs))...)); + return visit_return_type_check::invoke( + lib::forward(f), + access::base::get_alt(lib::forward(vs))...); + } + + template + MPARK_ALWAYS_INLINE static constexpr R dispatch_at(std::size_t index, + F &&f, + V &&v, + Vs &&... vs) { + static_assert(lib::all<(lib::decay_t::size() == + lib::decay_t::size())...>::value, + "all of the variants must be the same size."); +#define MPARK_DISPATCH_AT(I) \ + dispatcher<(I < lib::decay_t::size()), R>::template dispatch_case( \ + lib::forward(f), lib::forward(v), lib::forward(vs)...) + +#define MPARK_DEFAULT(I) \ + dispatcher<(I < lib::decay_t::size()), R>::template dispatch_at( \ + index, lib::forward(f), lib::forward(v), lib::forward(vs)...) + + switch (index) { + case B + 0: return MPARK_DISPATCH_AT(B + 0); + case B + 1: return MPARK_DISPATCH_AT(B + 1); + case B + 2: return MPARK_DISPATCH_AT(B + 2); + case B + 3: return MPARK_DISPATCH_AT(B + 3); + case B + 4: return MPARK_DISPATCH_AT(B + 4); + case B + 5: return MPARK_DISPATCH_AT(B + 5); + case B + 6: return MPARK_DISPATCH_AT(B + 6); + case B + 7: return MPARK_DISPATCH_AT(B + 7); + case B + 8: return MPARK_DISPATCH_AT(B + 8); + case B + 9: return MPARK_DISPATCH_AT(B + 9); + case B + 10: return MPARK_DISPATCH_AT(B + 10); + case B + 11: return MPARK_DISPATCH_AT(B + 11); + case B + 12: return MPARK_DISPATCH_AT(B + 12); + case B + 13: return MPARK_DISPATCH_AT(B + 13); + case B + 14: return MPARK_DISPATCH_AT(B + 14); + case B + 15: return MPARK_DISPATCH_AT(B + 15); + case B + 16: return MPARK_DISPATCH_AT(B + 16); + case B + 17: return MPARK_DISPATCH_AT(B + 17); + case B + 18: return MPARK_DISPATCH_AT(B + 18); + case B + 19: return MPARK_DISPATCH_AT(B + 19); + case B + 20: return MPARK_DISPATCH_AT(B + 20); + case B + 21: return MPARK_DISPATCH_AT(B + 21); + case B + 22: return MPARK_DISPATCH_AT(B + 22); + case B + 23: return MPARK_DISPATCH_AT(B + 23); + case B + 24: return MPARK_DISPATCH_AT(B + 24); + case B + 25: return MPARK_DISPATCH_AT(B + 25); + case B + 26: return MPARK_DISPATCH_AT(B + 26); + case B + 27: return MPARK_DISPATCH_AT(B + 27); + case B + 28: return MPARK_DISPATCH_AT(B + 28); + case B + 29: return MPARK_DISPATCH_AT(B + 29); + case B + 30: return MPARK_DISPATCH_AT(B + 30); + case B + 31: return MPARK_DISPATCH_AT(B + 31); + default: return MPARK_DEFAULT(B + 32); + } + +#undef MPARK_DEFAULT +#undef MPARK_DISPATCH_AT + } + }; +#else + template + inline static constexpr const T &at(const T &elem) noexcept { + return elem; + } + + template + inline static constexpr const lib::remove_all_extents_t &at( + const lib::array &elems, std::size_t i, Is... is) noexcept { + return at(elems[i], is...); + } + + template + inline static constexpr lib::array, sizeof...(Fs) + 1> + make_farray(F &&f, Fs &&... fs) { + return {{lib::forward(f), lib::forward(fs)...}}; + } + + template + struct make_fmatrix_impl { + + template + inline static constexpr dispatch_result_t dispatch( + F &&f, Vs &&... vs) { + using Expected = dispatch_result_t; + using Actual = decltype(lib::invoke( + lib::forward(f), + access::base::get_alt(lib::forward(vs))...)); + return visit_return_type_check::invoke( + lib::forward(f), + access::base::get_alt(lib::forward(vs))...); + } + +#ifdef MPARK_RETURN_TYPE_DEDUCTION + template + inline static constexpr auto impl(lib::index_sequence) { + return &dispatch; + } + + template + inline static constexpr auto impl(Is, + lib::index_sequence, + Ls... ls) { + return make_farray(impl(lib::push_back_t{}, ls...)...); + } +#else + template + struct impl; + + template + struct impl> { + inline constexpr AUTO operator()() const + AUTO_RETURN(&dispatch) + }; + + template + struct impl, Ls...> { + inline constexpr AUTO operator()() const + AUTO_RETURN( + make_farray(impl, Ls...>{}()...)) + }; +#endif + }; + +#ifdef MPARK_RETURN_TYPE_DEDUCTION + template + inline static constexpr auto make_fmatrix() { + return make_fmatrix_impl::impl( + lib::index_sequence<>{}, + lib::make_index_sequence::size()>{}...); + } +#else + template + inline static constexpr AUTO make_fmatrix() + AUTO_RETURN( + typename make_fmatrix_impl::template impl< + lib::index_sequence<>, + lib::make_index_sequence::size()>...>{}()) +#endif + + template + struct make_fdiagonal_impl { + template + inline static constexpr dispatch_result_t dispatch( + F &&f, Vs &&... vs) { + using Expected = dispatch_result_t; + using Actual = decltype( + lib::invoke(lib::forward(f), + access::base::get_alt(lib::forward(vs))...)); + return visit_return_type_check::invoke( + lib::forward(f), + access::base::get_alt(lib::forward(vs))...); + } + + template + inline static constexpr AUTO impl(lib::index_sequence) + AUTO_RETURN(make_farray(&dispatch...)) + }; + + template + inline static constexpr auto make_fdiagonal() + -> decltype(make_fdiagonal_impl::impl( + lib::make_index_sequence::size()>{})) { + static_assert(lib::all<(lib::decay_t::size() == + lib::decay_t::size())...>::value, + "all of the variants must be the same size."); + return make_fdiagonal_impl::impl( + lib::make_index_sequence::size()>{}); + } +#endif + }; + +#if !defined(MPARK_VARIANT_SWITCH_VISIT) && \ + (!defined(_MSC_VER) || _MSC_VER >= 1910) + template + using fmatrix_t = decltype(base::make_fmatrix()); + + template + struct fmatrix { + static constexpr fmatrix_t value = + base::make_fmatrix(); + }; + + template + constexpr fmatrix_t fmatrix::value; + + template + using fdiagonal_t = decltype(base::make_fdiagonal()); + + template + struct fdiagonal { + static constexpr fdiagonal_t value = + base::make_fdiagonal(); + }; + + template + constexpr fdiagonal_t fdiagonal::value; +#endif + + struct alt { + template + inline static constexpr DECLTYPE_AUTO visit_alt(Visitor &&visitor, + Vs &&... vs) +#ifdef MPARK_VARIANT_SWITCH_VISIT + DECLTYPE_AUTO_RETURN( + base::dispatcher< + true, + base::dispatch_result_t(vs)))...>>:: + template dispatch<0>(lib::forward(visitor), + as_base(lib::forward(vs))...)) +#elif !defined(_MSC_VER) || _MSC_VER >= 1910 + DECLTYPE_AUTO_RETURN(base::at( + fmatrix(vs)))...>::value, + vs.index()...)(lib::forward(visitor), + as_base(lib::forward(vs))...)) +#else + DECLTYPE_AUTO_RETURN(base::at( + base::make_fmatrix(vs)))...>(), + vs.index()...)(lib::forward(visitor), + as_base(lib::forward(vs))...)) +#endif + + template + inline static constexpr DECLTYPE_AUTO visit_alt_at(std::size_t index, + Visitor &&visitor, + Vs &&... vs) +#ifdef MPARK_VARIANT_SWITCH_VISIT + DECLTYPE_AUTO_RETURN( + base::dispatcher< + true, + base::dispatch_result_t(vs)))...>>:: + template dispatch_at<0>(index, + lib::forward(visitor), + as_base(lib::forward(vs))...)) +#elif !defined(_MSC_VER) || _MSC_VER >= 1910 + DECLTYPE_AUTO_RETURN(base::at( + fdiagonal(vs)))...>::value, + index)(lib::forward(visitor), + as_base(lib::forward(vs))...)) +#else + DECLTYPE_AUTO_RETURN(base::at( + base::make_fdiagonal(vs)))...>(), + index)(lib::forward(visitor), + as_base(lib::forward(vs))...)) +#endif + }; + + struct variant { + private: + template + struct visitor { + template + inline static constexpr bool does_not_handle() { + return lib::is_invocable::value; + } + }; + + template + struct visit_exhaustiveness_check { + static_assert(visitor::template does_not_handle(), + "`visit` requires the visitor to be exhaustive."); + + inline static constexpr DECLTYPE_AUTO invoke(Visitor &&visitor, + Values &&... values) + DECLTYPE_AUTO_RETURN(lib::invoke(lib::forward(visitor), + lib::forward(values)...)) + }; + + template + struct value_visitor { + Visitor &&visitor_; + + template + inline constexpr DECLTYPE_AUTO operator()(Alts &&... alts) const + DECLTYPE_AUTO_RETURN( + visit_exhaustiveness_check< + Visitor, + decltype((lib::forward(alts).value))...>:: + invoke(lib::forward(visitor_), + lib::forward(alts).value...)) + }; + + template + inline static constexpr AUTO make_value_visitor(Visitor &&visitor) + AUTO_RETURN(value_visitor{lib::forward(visitor)}) + + public: + template + inline static constexpr DECLTYPE_AUTO visit_alt(Visitor &&visitor, + Vs &&... vs) + DECLTYPE_AUTO_RETURN(alt::visit_alt(lib::forward(visitor), + lib::forward(vs).impl_...)) + + template + inline static constexpr DECLTYPE_AUTO visit_alt_at(std::size_t index, + Visitor &&visitor, + Vs &&... vs) + DECLTYPE_AUTO_RETURN( + alt::visit_alt_at(index, + lib::forward(visitor), + lib::forward(vs).impl_...)) + + template + inline static constexpr DECLTYPE_AUTO visit_value(Visitor &&visitor, + Vs &&... vs) + DECLTYPE_AUTO_RETURN( + visit_alt(make_value_visitor(lib::forward(visitor)), + lib::forward(vs)...)) + + template + inline static constexpr DECLTYPE_AUTO visit_value_at(std::size_t index, + Visitor &&visitor, + Vs &&... vs) + DECLTYPE_AUTO_RETURN( + visit_alt_at(index, + make_value_visitor(lib::forward(visitor)), + lib::forward(vs)...)) + }; + + } // namespace visitation + + template + struct alt { + using value_type = T; + +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4244) +#endif + template + inline explicit constexpr alt(variant_in_place_t, Args &&... args) + : value(lib::forward(args)...) {} +#ifdef _MSC_VER +#pragma warning(pop) +#endif + + T value; + }; + + template + union recursive_union; + + template + union recursive_union {}; + +#define MPARK_VARIANT_RECURSIVE_UNION(destructible_trait, destructor) \ + template \ + union recursive_union { \ + public: \ + inline explicit constexpr recursive_union(valueless_t) noexcept \ + : dummy_{} {} \ + \ + template \ + inline explicit constexpr recursive_union(in_place_index_t<0>, \ + Args &&... args) \ + : head_(variant_in_place_t{}, lib::forward(args)...) {} \ + \ + template \ + inline explicit constexpr recursive_union(in_place_index_t, \ + Args &&... args) \ + : tail_(in_place_index_t{}, lib::forward(args)...) {} \ + \ + recursive_union(const recursive_union &) = default; \ + recursive_union(recursive_union &&) = default; \ + \ + destructor \ + \ + recursive_union &operator=(const recursive_union &) = default; \ + recursive_union &operator=(recursive_union &&) = default; \ + \ + private: \ + char dummy_; \ + alt head_; \ + recursive_union tail_; \ + \ + friend struct access::recursive_union; \ + } + + MPARK_VARIANT_RECURSIVE_UNION(Trait::TriviallyAvailable, + ~recursive_union() = default;); + MPARK_VARIANT_RECURSIVE_UNION(Trait::Available, + ~recursive_union() {}); + MPARK_VARIANT_RECURSIVE_UNION(Trait::Unavailable, + ~recursive_union() = delete;); + +#undef MPARK_VARIANT_RECURSIVE_UNION + + using index_t = unsigned int; + + template + class base { + public: + inline explicit constexpr base(valueless_t tag) noexcept + : data_(tag), index_(static_cast(-1)) {} + + template + inline explicit constexpr base(in_place_index_t, Args &&... args) + : data_(in_place_index_t{}, lib::forward(args)...), + index_(I) {} + + inline constexpr bool valueless_by_exception() const noexcept { + return index_ == static_cast(-1); + } + + inline constexpr std::size_t index() const noexcept { + return valueless_by_exception() ? variant_npos : index_; + } + + protected: + using data_t = recursive_union; + + friend inline constexpr base &as_base(base &b) { return b; } + friend inline constexpr const base &as_base(const base &b) { return b; } + friend inline constexpr base &&as_base(base &&b) { return lib::move(b); } + friend inline constexpr const base &&as_base(const base &&b) { return lib::move(b); } + + friend inline constexpr data_t &data(base &b) { return b.data_; } + friend inline constexpr const data_t &data(const base &b) { return b.data_; } + friend inline constexpr data_t &&data(base &&b) { return lib::move(b).data_; } + friend inline constexpr const data_t &&data(const base &&b) { return lib::move(b).data_; } + + inline static constexpr std::size_t size() { return sizeof...(Ts); } + + data_t data_; + index_t index_; + + friend struct access::base; + friend struct visitation::base; + }; + + struct dtor { +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4100) +#endif + template + inline void operator()(Alt &alt) const noexcept { alt.~Alt(); } +#ifdef _MSC_VER +#pragma warning(pop) +#endif + }; + +#if !defined(_MSC_VER) || _MSC_VER >= 1910 +#define MPARK_INHERITING_CTOR(type, base) using base::base; +#else +#define MPARK_INHERITING_CTOR(type, base) \ + template \ + inline explicit constexpr type(Args &&... args) \ + : base(lib::forward(args)...) {} +#endif + + template + class destructor; + +#define MPARK_VARIANT_DESTRUCTOR(destructible_trait, definition, destroy) \ + template \ + class destructor, destructible_trait> \ + : public base { \ + using super = base; \ + \ + public: \ + MPARK_INHERITING_CTOR(destructor, super) \ + using super::operator=; \ + \ + destructor(const destructor &) = default; \ + destructor(destructor &&) = default; \ + definition \ + destructor &operator=(const destructor &) = default; \ + destructor &operator=(destructor &&) = default; \ + \ + protected: \ + destroy \ + } + + MPARK_VARIANT_DESTRUCTOR( + Trait::TriviallyAvailable, + ~destructor() = default;, + inline void destroy() noexcept { + this->index_ = static_cast(-1); + }); + + MPARK_VARIANT_DESTRUCTOR( + Trait::Available, + ~destructor() { destroy(); }, + inline void destroy() noexcept { + if (!this->valueless_by_exception()) { + visitation::alt::visit_alt(dtor{}, *this); + } + this->index_ = static_cast(-1); + }); + + MPARK_VARIANT_DESTRUCTOR( + Trait::Unavailable, + ~destructor() = delete;, + inline void destroy() noexcept = delete;); + +#undef MPARK_VARIANT_DESTRUCTOR + + template + class constructor : public destructor { + using super = destructor; + + public: + MPARK_INHERITING_CTOR(constructor, super) + using super::operator=; + + protected: +#ifndef MPARK_GENERIC_LAMBDAS + struct ctor { + template + inline void operator()(LhsAlt &lhs_alt, RhsAlt &&rhs_alt) const { + constructor::construct_alt(lhs_alt, + lib::forward(rhs_alt).value); + } + }; +#endif + + template + inline static T &construct_alt(alt &a, Args &&... args) { + auto *result = ::new (static_cast(lib::addressof(a))) + alt(variant_in_place_t{}, lib::forward(args)...); + return result->value; + } + + template + inline static void generic_construct(constructor &lhs, Rhs &&rhs) { + lhs.destroy(); + if (!rhs.valueless_by_exception()) { + visitation::alt::visit_alt_at( + rhs.index(), +#ifdef MPARK_GENERIC_LAMBDAS + [](auto &lhs_alt, auto &&rhs_alt) { + constructor::construct_alt( + lhs_alt, lib::forward(rhs_alt).value); + } +#else + ctor{} +#endif + , + lhs, + lib::forward(rhs)); + lhs.index_ = rhs.index_; + } + } + }; + + template + class move_constructor; + +#define MPARK_VARIANT_MOVE_CONSTRUCTOR(move_constructible_trait, definition) \ + template \ + class move_constructor, move_constructible_trait> \ + : public constructor> { \ + using super = constructor>; \ + \ + public: \ + MPARK_INHERITING_CTOR(move_constructor, super) \ + using super::operator=; \ + \ + move_constructor(const move_constructor &) = default; \ + definition \ + ~move_constructor() = default; \ + move_constructor &operator=(const move_constructor &) = default; \ + move_constructor &operator=(move_constructor &&) = default; \ + } + + MPARK_VARIANT_MOVE_CONSTRUCTOR( + Trait::TriviallyAvailable, + move_constructor(move_constructor &&that) = default;); + + MPARK_VARIANT_MOVE_CONSTRUCTOR( + Trait::Available, + move_constructor(move_constructor &&that) noexcept( + lib::all::value...>::value) + : move_constructor(valueless_t{}) { + this->generic_construct(*this, lib::move(that)); + }); + + MPARK_VARIANT_MOVE_CONSTRUCTOR( + Trait::Unavailable, + move_constructor(move_constructor &&) = delete;); + +#undef MPARK_VARIANT_MOVE_CONSTRUCTOR + + template + class copy_constructor; + +#define MPARK_VARIANT_COPY_CONSTRUCTOR(copy_constructible_trait, definition) \ + template \ + class copy_constructor, copy_constructible_trait> \ + : public move_constructor> { \ + using super = move_constructor>; \ + \ + public: \ + MPARK_INHERITING_CTOR(copy_constructor, super) \ + using super::operator=; \ + \ + definition \ + copy_constructor(copy_constructor &&) = default; \ + ~copy_constructor() = default; \ + copy_constructor &operator=(const copy_constructor &) = default; \ + copy_constructor &operator=(copy_constructor &&) = default; \ + } + + MPARK_VARIANT_COPY_CONSTRUCTOR( + Trait::TriviallyAvailable, + copy_constructor(const copy_constructor &that) = default;); + + MPARK_VARIANT_COPY_CONSTRUCTOR( + Trait::Available, + copy_constructor(const copy_constructor &that) + : copy_constructor(valueless_t{}) { + this->generic_construct(*this, that); + }); + + MPARK_VARIANT_COPY_CONSTRUCTOR( + Trait::Unavailable, + copy_constructor(const copy_constructor &) = delete;); + +#undef MPARK_VARIANT_COPY_CONSTRUCTOR + + template + class assignment : public copy_constructor { + using super = copy_constructor; + + public: + MPARK_INHERITING_CTOR(assignment, super) + using super::operator=; + + template + inline /* auto & */ auto emplace(Args &&... args) + -> decltype(this->construct_alt(access::base::get_alt(*this), + lib::forward(args)...)) { + this->destroy(); + auto &result = this->construct_alt(access::base::get_alt(*this), + lib::forward(args)...); + this->index_ = I; + return result; + } + + protected: +#ifndef MPARK_GENERIC_LAMBDAS + template + struct assigner { + template + inline void operator()(ThisAlt &this_alt, ThatAlt &&that_alt) const { + self->assign_alt(this_alt, lib::forward(that_alt).value); + } + assignment *self; + }; +#endif + + template + inline void assign_alt(alt &a, Arg &&arg) { + if (this->index() == I) { +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4244) +#endif + a.value = lib::forward(arg); +#ifdef _MSC_VER +#pragma warning(pop) +#endif + } else { + struct { + void operator()(std::true_type) const { + this_->emplace(lib::forward(arg_)); + } + void operator()(std::false_type) const { + this_->emplace(T(lib::forward(arg_))); + } + assignment *this_; + Arg &&arg_; + } impl{this, lib::forward(arg)}; + impl(lib::bool_constant< + std::is_nothrow_constructible::value || + !std::is_nothrow_move_constructible::value>{}); + } + } + + template + inline void generic_assign(That &&that) { + if (this->valueless_by_exception() && that.valueless_by_exception()) { + // do nothing. + } else if (that.valueless_by_exception()) { + this->destroy(); + } else { + visitation::alt::visit_alt_at( + that.index(), +#ifdef MPARK_GENERIC_LAMBDAS + [this](auto &this_alt, auto &&that_alt) { + this->assign_alt( + this_alt, lib::forward(that_alt).value); + } +#else + assigner{this} +#endif + , + *this, + lib::forward(that)); + } + } + }; + + template + class move_assignment; + +#define MPARK_VARIANT_MOVE_ASSIGNMENT(move_assignable_trait, definition) \ + template \ + class move_assignment, move_assignable_trait> \ + : public assignment> { \ + using super = assignment>; \ + \ + public: \ + MPARK_INHERITING_CTOR(move_assignment, super) \ + using super::operator=; \ + \ + move_assignment(const move_assignment &) = default; \ + move_assignment(move_assignment &&) = default; \ + ~move_assignment() = default; \ + move_assignment &operator=(const move_assignment &) = default; \ + definition \ + } + + MPARK_VARIANT_MOVE_ASSIGNMENT( + Trait::TriviallyAvailable, + move_assignment &operator=(move_assignment &&that) = default;); + + MPARK_VARIANT_MOVE_ASSIGNMENT( + Trait::Available, + move_assignment & + operator=(move_assignment &&that) noexcept( + lib::all<(std::is_nothrow_move_constructible::value && + std::is_nothrow_move_assignable::value)...>::value) { + this->generic_assign(lib::move(that)); + return *this; + }); + + MPARK_VARIANT_MOVE_ASSIGNMENT( + Trait::Unavailable, + move_assignment &operator=(move_assignment &&) = delete;); + +#undef MPARK_VARIANT_MOVE_ASSIGNMENT + + template + class copy_assignment; + +#define MPARK_VARIANT_COPY_ASSIGNMENT(copy_assignable_trait, definition) \ + template \ + class copy_assignment, copy_assignable_trait> \ + : public move_assignment> { \ + using super = move_assignment>; \ + \ + public: \ + MPARK_INHERITING_CTOR(copy_assignment, super) \ + using super::operator=; \ + \ + copy_assignment(const copy_assignment &) = default; \ + copy_assignment(copy_assignment &&) = default; \ + ~copy_assignment() = default; \ + definition \ + copy_assignment &operator=(copy_assignment &&) = default; \ + } + + MPARK_VARIANT_COPY_ASSIGNMENT( + Trait::TriviallyAvailable, + copy_assignment &operator=(const copy_assignment &that) = default;); + + MPARK_VARIANT_COPY_ASSIGNMENT( + Trait::Available, + copy_assignment &operator=(const copy_assignment &that) { + this->generic_assign(that); + return *this; + }); + + MPARK_VARIANT_COPY_ASSIGNMENT( + Trait::Unavailable, + copy_assignment &operator=(const copy_assignment &) = delete;); + +#undef MPARK_VARIANT_COPY_ASSIGNMENT + + template + class impl : public copy_assignment> { + using super = copy_assignment>; + + public: + MPARK_INHERITING_CTOR(impl, super) + using super::operator=; + + template + inline void assign(Arg &&arg) { + this->assign_alt(access::base::get_alt(*this), + lib::forward(arg)); + } + + inline void swap(impl &that) { + if (this->valueless_by_exception() && that.valueless_by_exception()) { + // do nothing. + } else if (this->index() == that.index()) { + visitation::alt::visit_alt_at(this->index(), +#ifdef MPARK_GENERIC_LAMBDAS + [](auto &this_alt, auto &that_alt) { + using std::swap; + swap(this_alt.value, + that_alt.value); + } +#else + swapper{} +#endif + , + *this, + that); + } else { + impl *lhs = this; + impl *rhs = lib::addressof(that); + if (lhs->move_nothrow() && !rhs->move_nothrow()) { + std::swap(lhs, rhs); + } + impl tmp(lib::move(*rhs)); +#ifdef MPARK_EXCEPTIONS + // EXTENSION: When the move construction of `lhs` into `rhs` throws + // and `tmp` is nothrow move constructible then we move `tmp` back + // into `rhs` and provide the strong exception safety guarantee. + try { + this->generic_construct(*rhs, lib::move(*lhs)); + } catch (...) { + if (tmp.move_nothrow()) { + this->generic_construct(*rhs, lib::move(tmp)); + } + throw; + } +#else + this->generic_construct(*rhs, lib::move(*lhs)); +#endif + this->generic_construct(*lhs, lib::move(tmp)); + } + } + + private: +#ifndef MPARK_GENERIC_LAMBDAS + struct swapper { + template + inline void operator()(ThisAlt &this_alt, ThatAlt &that_alt) const { + using std::swap; + swap(this_alt.value, that_alt.value); + } + }; +#endif + + inline constexpr bool move_nothrow() const { + return this->valueless_by_exception() || + lib::array{ + {std::is_nothrow_move_constructible::value...} + }[this->index()]; + } + }; + +#undef MPARK_INHERITING_CTOR + + template + struct overload_leaf { + using F = lib::size_constant (*)(T); + operator F() const { return nullptr; } + }; + + template + struct overload_impl { + private: + template + struct impl; + + template + struct impl> : overload_leaf... {}; + + public: + using type = impl>; + }; + + template + using overload = typename overload_impl::type; + + template + using best_match = lib::invoke_result_t, T &&>; + + template + struct is_in_place_index : std::false_type {}; + + template + struct is_in_place_index> : std::true_type {}; + + template + struct is_in_place_type : std::false_type {}; + + template + struct is_in_place_type> : std::true_type {}; + + } // detail_ + + template + class variant { + static_assert(0 < sizeof...(Ts), + "variant must consist of at least one alternative."); + + static_assert(lib::all::value...>::value, + "variant can not have an array type as an alternative."); + + static_assert(lib::all::value...>::value, + "variant can not have a reference type as an alternative."); + + static_assert(lib::all::value...>::value, + "variant can not have a void type as an alternative."); + + public: + template < + typename Front = lib::type_pack_element_t<0, Ts...>, + lib::enable_if_t::value, int> = 0> + inline constexpr variant() noexcept( + std::is_nothrow_default_constructible::value) + : impl_(in_place_index_t<0>{}) {} + + variant(const variant &) = default; + variant(variant &&) = default; + + // NOTE [gcc 7.3.1 bug workaround] + // + // The original line `typename T = lib::type_pack_element_t` + // throws the following compiler error on gcc 7.3.1: + // ``` + // ../c10/util/variant.h:2250:9: internal compiler error: + // unexpected expression ‘I’ of kind template_parm_index + // typename T = lib::type_pack_element_t, + // ^~~~~~~~ + // ``` + // As a workaround, `I` is changed to `detail_::best_match::value`, + // which is the default value for `I` in this template. Note that this workaround + // effectively disallows setting `I` to any other non-default value, and we add a + // `static_assert` in the function body to check for this. + // + // See the following issues for more context: + // - https://github.com/mpark/variant/issues/43 + // - https://github.com/eggs-cpp/variant/issues/31 + template < + typename Arg, + typename Decayed = lib::decay_t, + lib::enable_if_t::value, int> = 0, + lib::enable_if_t::value, int> = 0, + lib::enable_if_t::value, int> = 0, + std::size_t I = detail_::best_match::value, + typename T = lib::type_pack_element_t::value, Ts...>, + lib::enable_if_t::value, int> = 0> + inline constexpr variant(Arg &&arg) noexcept( + std::is_nothrow_constructible::value) + : impl_(in_place_index_t{}, lib::forward(arg)) { + static_assert( + I == detail_::best_match::value, + "Setting template parameter `I` to a custom non-default value is not supported. " + "Please file a feature request if you see this."); + } + + template < + std::size_t I, + typename... Args, + typename T = lib::type_pack_element_t, + lib::enable_if_t::value, int> = 0> + inline explicit constexpr variant( + in_place_index_t, + Args &&... args) noexcept(std::is_nothrow_constructible::value) + : impl_(in_place_index_t{}, lib::forward(args)...) {} + + template < + std::size_t I, + typename Up, + typename... Args, + typename T = lib::type_pack_element_t, + lib::enable_if_t &, + Args...>::value, + int> = 0> + inline explicit constexpr variant( + in_place_index_t, + std::initializer_list il, + Args &&... args) noexcept(std:: + is_nothrow_constructible< + T, + std::initializer_list &, + Args...>::value) + : impl_(in_place_index_t{}, il, lib::forward(args)...) {} + + template < + typename T, + typename... Args, + std::size_t I = detail_::find_index_sfinae::value, + lib::enable_if_t::value, int> = 0> + inline explicit constexpr variant( + in_place_type_t, + Args &&... args) noexcept(std::is_nothrow_constructible::value) + : impl_(in_place_index_t{}, lib::forward(args)...) {} + + template < + typename T, + typename Up, + typename... Args, + std::size_t I = detail_::find_index_sfinae::value, + lib::enable_if_t &, + Args...>::value, + int> = 0> + inline explicit constexpr variant( + in_place_type_t, + std::initializer_list il, + Args &&... args) noexcept(std:: + is_nothrow_constructible< + T, + std::initializer_list &, + Args...>::value) + : impl_(in_place_index_t{}, il, lib::forward(args)...) {} + + ~variant() = default; + + variant &operator=(const variant &) = default; + variant &operator=(variant &&) = default; + + // NOTE: See NOTE [gcc 7.3.1 bug workaround] for the changes made to this function. + template , variant>::value, + int> = 0, + std::size_t I = detail_::best_match::value, + typename T = lib::type_pack_element_t::value, Ts...>, + lib::enable_if_t<(std::is_assignable::value && + std::is_constructible::value), + int> = 0> + inline variant &operator=(Arg &&arg) noexcept( + (std::is_nothrow_assignable::value && + std::is_nothrow_constructible::value)) { + static_assert( + I == detail_::best_match::value, + "Setting template parameter `I` to a custom non-default value is not supported. " + "Please file a feature request if you see this."); + impl_.template assign(lib::forward(arg)); + return *this; + } + + template < + std::size_t I, + typename... Args, + typename T = lib::type_pack_element_t, + lib::enable_if_t::value, int> = 0> + inline T &emplace(Args &&... args) { + return impl_.template emplace(lib::forward(args)...); + } + + template < + std::size_t I, + typename Up, + typename... Args, + typename T = lib::type_pack_element_t, + lib::enable_if_t &, + Args...>::value, + int> = 0> + inline T &emplace(std::initializer_list il, Args &&... args) { + return impl_.template emplace(il, lib::forward(args)...); + } + + template < + typename T, + typename... Args, + std::size_t I = detail_::find_index_sfinae::value, + lib::enable_if_t::value, int> = 0> + inline T &emplace(Args &&... args) { + return impl_.template emplace(lib::forward(args)...); + } + + template < + typename T, + typename Up, + typename... Args, + std::size_t I = detail_::find_index_sfinae::value, + lib::enable_if_t &, + Args...>::value, + int> = 0> + inline T &emplace(std::initializer_list il, Args &&... args) { + return impl_.template emplace(il, lib::forward(args)...); + } + + inline constexpr bool valueless_by_exception() const noexcept { + return impl_.valueless_by_exception(); + } + + inline constexpr std::size_t index() const noexcept { + return impl_.index(); + } + + template , + Dummy>::value && + lib::dependent_type, + Dummy>::value)...>::value, + int> = 0> + inline void swap(variant &that) noexcept( + lib::all<(std::is_nothrow_move_constructible::value && + lib::is_nothrow_swappable::value)...>::value) { + impl_.swap(that.impl_); + } + + private: + detail_::impl impl_; + + friend struct detail_::access::variant; + friend struct detail_::visitation::variant; + }; + + template + inline constexpr bool holds_alternative(const variant &v) noexcept { + return v.index() == I; + } + + template + inline constexpr bool holds_alternative(const variant &v) noexcept { + return holds_alternative::value>(v); + } + + namespace detail_ { + template + struct generic_get_impl { + constexpr generic_get_impl(int) noexcept {} + + constexpr AUTO_REFREF operator()(V &&v) const + AUTO_REFREF_RETURN( + access::variant::get_alt(lib::forward(v)).value) + }; + + template + inline constexpr AUTO_REFREF generic_get(V &&v) + AUTO_REFREF_RETURN(generic_get_impl( + holds_alternative(v) ? 0 : (throw_bad_variant_access(), 0))( + lib::forward(v))) + } // namespace detail_ + + template + inline constexpr variant_alternative_t> &get( + variant &v) { + return detail_::generic_get(v); + } + + template + inline constexpr variant_alternative_t> &&get( + variant &&v) { + return detail_::generic_get(lib::move(v)); + } + + template + inline constexpr const variant_alternative_t> &get( + const variant &v) { + return detail_::generic_get(v); + } + + template + inline constexpr const variant_alternative_t> &&get( + const variant &&v) { + return detail_::generic_get(lib::move(v)); + } + + template + inline constexpr T &get(variant &v) { + return get::value>(v); + } + + template + inline constexpr T &&get(variant &&v) { + return get::value>(lib::move(v)); + } + + template + inline constexpr const T &get(const variant &v) { + return get::value>(v); + } + + template + inline constexpr const T &&get(const variant &&v) { + return get::value>(lib::move(v)); + } + + namespace detail_ { + + template + inline constexpr /* auto * */ AUTO generic_get_if(V *v) noexcept + AUTO_RETURN(v && holds_alternative(*v) + ? lib::addressof(access::variant::get_alt(*v).value) + : nullptr) + + } // namespace detail_ + + template + inline constexpr lib::add_pointer_t>> + get_if(variant *v) noexcept { + return detail_::generic_get_if(v); + } + + template + inline constexpr lib::add_pointer_t< + const variant_alternative_t>> + get_if(const variant *v) noexcept { + return detail_::generic_get_if(v); + } + + template + inline constexpr lib::add_pointer_t + get_if(variant *v) noexcept { + return get_if::value>(v); + } + + template + inline constexpr lib::add_pointer_t + get_if(const variant *v) noexcept { + return get_if::value>(v); + } + + namespace detail_ { + template + struct convert_to_bool { + template + inline constexpr bool operator()(Lhs &&lhs, Rhs &&rhs) const { + static_assert(std::is_convertible, + bool>::value, + "relational operators must return a type" + " implicitly convertible to bool"); + return lib::invoke( + RelOp{}, lib::forward(lhs), lib::forward(rhs)); + } + }; + } // namespace detail_ + + template + inline constexpr bool operator==(const variant &lhs, + const variant &rhs) { + using detail_::visitation::variant; + using equal_to = detail_::convert_to_bool; +#ifdef MPARK_CPP14_CONSTEXPR + if (lhs.index() != rhs.index()) return false; + if (lhs.valueless_by_exception()) return true; + return variant::visit_value_at(lhs.index(), equal_to{}, lhs, rhs); +#else + return lhs.index() == rhs.index() && + (lhs.valueless_by_exception() || + variant::visit_value_at(lhs.index(), equal_to{}, lhs, rhs)); +#endif + } + + template + inline constexpr bool operator!=(const variant &lhs, + const variant &rhs) { + using detail_::visitation::variant; + using not_equal_to = detail_::convert_to_bool; +#ifdef MPARK_CPP14_CONSTEXPR + if (lhs.index() != rhs.index()) return true; + if (lhs.valueless_by_exception()) return false; + return variant::visit_value_at(lhs.index(), not_equal_to{}, lhs, rhs); +#else + return lhs.index() != rhs.index() || + (!lhs.valueless_by_exception() && + variant::visit_value_at(lhs.index(), not_equal_to{}, lhs, rhs)); +#endif + } + + template + inline constexpr bool operator<(const variant &lhs, + const variant &rhs) { + using detail_::visitation::variant; + using less = detail_::convert_to_bool; +#ifdef MPARK_CPP14_CONSTEXPR + if (rhs.valueless_by_exception()) return false; + if (lhs.valueless_by_exception()) return true; + if (lhs.index() < rhs.index()) return true; + if (lhs.index() > rhs.index()) return false; + return variant::visit_value_at(lhs.index(), less{}, lhs, rhs); +#else + return !rhs.valueless_by_exception() && + (lhs.valueless_by_exception() || lhs.index() < rhs.index() || + (lhs.index() == rhs.index() && + variant::visit_value_at(lhs.index(), less{}, lhs, rhs))); +#endif + } + + template + inline constexpr bool operator>(const variant &lhs, + const variant &rhs) { + using detail_::visitation::variant; + using greater = detail_::convert_to_bool; +#ifdef MPARK_CPP14_CONSTEXPR + if (lhs.valueless_by_exception()) return false; + if (rhs.valueless_by_exception()) return true; + if (lhs.index() > rhs.index()) return true; + if (lhs.index() < rhs.index()) return false; + return variant::visit_value_at(lhs.index(), greater{}, lhs, rhs); +#else + return !lhs.valueless_by_exception() && + (rhs.valueless_by_exception() || lhs.index() > rhs.index() || + (lhs.index() == rhs.index() && + variant::visit_value_at(lhs.index(), greater{}, lhs, rhs))); +#endif + } + + template + inline constexpr bool operator<=(const variant &lhs, + const variant &rhs) { + using detail_::visitation::variant; + using less_equal = detail_::convert_to_bool; +#ifdef MPARK_CPP14_CONSTEXPR + if (lhs.valueless_by_exception()) return true; + if (rhs.valueless_by_exception()) return false; + if (lhs.index() < rhs.index()) return true; + if (lhs.index() > rhs.index()) return false; + return variant::visit_value_at(lhs.index(), less_equal{}, lhs, rhs); +#else + return lhs.valueless_by_exception() || + (!rhs.valueless_by_exception() && + (lhs.index() < rhs.index() || + (lhs.index() == rhs.index() && + variant::visit_value_at(lhs.index(), less_equal{}, lhs, rhs)))); +#endif + } + + template + inline constexpr bool operator>=(const variant &lhs, + const variant &rhs) { + using detail_::visitation::variant; + using greater_equal = detail_::convert_to_bool; +#ifdef MPARK_CPP14_CONSTEXPR + if (rhs.valueless_by_exception()) return true; + if (lhs.valueless_by_exception()) return false; + if (lhs.index() > rhs.index()) return true; + if (lhs.index() < rhs.index()) return false; + return variant::visit_value_at(lhs.index(), greater_equal{}, lhs, rhs); +#else + return rhs.valueless_by_exception() || + (!lhs.valueless_by_exception() && + (lhs.index() > rhs.index() || + (lhs.index() == rhs.index() && + variant::visit_value_at( + lhs.index(), greater_equal{}, lhs, rhs)))); +#endif + } + + struct monostate {}; + + inline constexpr bool operator<(monostate, monostate) noexcept { + return false; + } + + inline constexpr bool operator>(monostate, monostate) noexcept { + return false; + } + + inline constexpr bool operator<=(monostate, monostate) noexcept { + return true; + } + + inline constexpr bool operator>=(monostate, monostate) noexcept { + return true; + } + + inline constexpr bool operator==(monostate, monostate) noexcept { + return true; + } + + inline constexpr bool operator!=(monostate, monostate) noexcept { + return false; + } + +#ifdef MPARK_CPP14_CONSTEXPR + namespace detail_ { + + inline constexpr bool all(std::initializer_list bs) { + for (bool b : bs) { + if (!b) { + return false; + } + } + return true; + } + + } // namespace detail_ + + template + inline constexpr decltype(auto) visit(Visitor &&visitor, Vs &&... vs) { + return (detail_::all({!vs.valueless_by_exception()...}) + ? (void)0 + : throw_bad_variant_access()), + detail_::visitation::variant::visit_value( + lib::forward(visitor), lib::forward(vs)...); + } +#else + namespace detail_ { + + template + inline constexpr bool all_impl(const lib::array &bs, + std::size_t idx) { + return idx >= N || (bs[idx] && all_impl(bs, idx + 1)); + } + + template + inline constexpr bool all(const lib::array &bs) { + return all_impl(bs, 0); + } + + } // namespace detail_ + + template + inline constexpr DECLTYPE_AUTO visit(Visitor &&visitor, Vs &&... vs) + DECLTYPE_AUTO_RETURN( + (detail_::all( + lib::array{{!vs.valueless_by_exception()...}}) + ? (void)0 + : throw_bad_variant_access()), + detail_::visitation::variant::visit_value(lib::forward(visitor), + lib::forward(vs)...)) +#endif + + template + inline auto swap(variant &lhs, + variant &rhs) noexcept(noexcept(lhs.swap(rhs))) + -> decltype(lhs.swap(rhs)) { + lhs.swap(rhs); + } + + namespace detail_ { + + template + using enabled_type = T; + + namespace hash { + + template + constexpr bool meets_requirements() noexcept { + return std::is_copy_constructible::value && + std::is_move_constructible::value && + lib::is_invocable_r::value; + } + + template + constexpr bool is_enabled() noexcept { + using H = std::hash; + return meets_requirements() && + std::is_default_constructible::value && + std::is_copy_assignable::value && + std::is_move_assignable::value; + } + + } // namespace hash + + } // namespace detail_ + +#undef AUTO +#undef AUTO_RETURN + +#undef AUTO_REFREF +#undef AUTO_REFREF_RETURN + +#undef DECLTYPE_AUTO +#undef DECLTYPE_AUTO_RETURN + +} // namespace c10 + +namespace std { + + template + struct hash, + c10::lib::enable_if_t>()...>::value>>> { + using argument_type = c10::variant; + using result_type = std::size_t; + + inline result_type operator()(const argument_type &v) const { + using c10::detail_::visitation::variant; + std::size_t result = + v.valueless_by_exception() + ? 299792458 // Random value chosen by the universe upon creation + : variant::visit_alt( +#ifdef MPARK_GENERIC_LAMBDAS + [](const auto &alt) { + using alt_type = c10::lib::decay_t; + using value_type = c10::lib::remove_const_t< + typename alt_type::value_type>; + return hash{}(alt.value); + } +#else + hasher{} +#endif + , + v); + return hash_combine(result, hash{}(v.index())); + } + + private: +#ifndef MPARK_GENERIC_LAMBDAS + struct hasher { + template + inline std::size_t operator()(const Alt &alt) const { + using alt_type = c10::lib::decay_t; + using value_type = + c10::lib::remove_const_t; + return hash{}(alt.value); + } + }; +#endif + + static std::size_t hash_combine(std::size_t lhs, std::size_t rhs) { + return lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); + } + }; + + template <> + struct hash { + using argument_type = c10::monostate; + using result_type = std::size_t; + + inline result_type operator()(const argument_type &) const noexcept { + return 66740831; // return a fundamentally attractive random value. + } + }; + +} // namespace std + +#endif // C10_UTIL_VARIANT_H_ diff --git a/docs/cpp/source/conf.py b/docs/cpp/source/conf.py index d78da06f9da..47383665306 100644 --- a/docs/cpp/source/conf.py +++ b/docs/cpp/source/conf.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # -*- coding: utf-8 -*- # # PyTorch documentation build configuration file, created by diff --git a/docs/source/conf.py b/docs/source/conf.py index a0bc0ddfeea..b6ea07f75f1 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # -*- coding: utf-8 -*- # # PyTorch documentation build configuration file, created by diff --git a/docs/source/index.rst b/docs/source/index.rst index 0779ba5dae1..91d9d0e0689 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -53,6 +53,8 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs. torch.utils.model_zoo torch.utils.tensorboard type_info + named_tensor + name_inference torch.__config__ <__config__> .. toctree:: diff --git a/docs/source/jit.rst b/docs/source/jit.rst index f962b83c0e8..3a4625135bc 100644 --- a/docs/source/jit.rst +++ b/docs/source/jit.rst @@ -20,7 +20,7 @@ process and loaded in a process where there is no Python dependency. We provide tools to incrementally transition a model from a pure Python program to a TorchScript program that can be run independently from Python, such as in a standalone C++ program. This makes it possible to train models in PyTorch using familiar tools in Python and then export -the model via TorchScript to a production environment where Python programs may be disadvantageous. +the model via TorchScript to a production environment where Python programs may be disadvantageous for performance and multi-threading reasons. For a gentle introduction to TorchScript, see the `Introduction to TorchScript `_ tutorial. @@ -34,6 +34,9 @@ Creating TorchScript Code .. autoclass:: ScriptModule() :members: + +.. autoclass:: ScriptFunction() + .. autofunction:: script(obj) .. autofunction:: trace(func, example_inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-5) @@ -154,9 +157,9 @@ methods, and classes that it encounters. Once you call ``torch.jit.script``, compilation is "opt-out", rather than "opt-in". 2. ``torch.jit.script(nn_module_instance)`` is now the preferred way to create -``ScriptModule``\s, instead of inheriting from ``torch.jit.ScriptModule``. +:class:`ScriptModule`\s, instead of inheriting from ``torch.jit.ScriptModule``. These changes combine to provide a simpler, easier-to-use API for converting -your ``nn.Module``\s into ``ScriptModule``\s, ready to be optimized and executed in a +your ``nn.Module``\s into :class:`ScriptModule`\s, ready to be optimized and executed in a non-Python environment. The new usage looks like this: @@ -207,7 +210,7 @@ Modules and :func:`@torch.jit.unused` for details. When passed to the :func:`torch.jit.script ` function, a ``torch.nn.Module``\'s data is -copied to a ``ScriptModule`` and the TorchScript compiler compiles the module. +copied to a :class:`ScriptModule` and the TorchScript compiler compiles the module. The module's ``forward`` is compiled by default. Methods called from ``forward`` are lazily compiled in the order they are used in ``forward``, as well as any ``@torch.jit.export`` methods. @@ -248,6 +251,9 @@ Attributes The TorchScript compiler needs to know the types of `module attributes`_. Most types can be inferred from the value of the member. Empty lists and dicts cannot have their types inferred and must have their types annotated with `PEP 526-style `_ class annotations. +If a type cannot be inferred and is not explicilty annotated, it will not be added as an attribute +to the resulting :class:`ScriptModule` + Old API: @@ -304,7 +310,7 @@ If you are stuck on Python 2 and cannot use the class annotation syntax, you can Constants ~~~~~~~~~ -The ``Final`` type constructor can be used to mark members as `constant`_. If members are not marked constant, they will be copied to the resulting ``ScriptModule`` as an attribute. Using ``Final`` opens opportunities for optimization if the value is known to be fixed and gives additional type safety. +The ``Final`` type constructor can be used to mark members as `constant`_. If members are not marked constant, they will be copied to the resulting :class:`ScriptModule` as an attribute. Using ``Final`` opens opportunities for optimization if the value is known to be fixed and gives additional type safety. Old API: @@ -1187,9 +1193,11 @@ The ``torch.nn.Parameter`` wrapper and ``register_buffer`` can be used to assign tensors to a module. Other values assigned to a module that is compiled will be added to the compiled module if their types can be inferred. All `types`_ available in TorchScript can be used as module attributes. Tensor attributes are -semantically the same as buffers. The type of empty containers and ``None`` +semantically the same as buffers. The type of empty lists and dictionaries and ``None`` values cannot be inferred and must be specified via `PEP 526-style `_ class annotations. +If a type cannot be inferred and is not explicilty annotated, it will not be added as an attribute +to the resulting :class:`ScriptModule`. Example: @@ -1198,7 +1206,7 @@ Example: from typing import List, Dict class Foo(nn.Module): - # `words` is initialzed as an empty list, so its type must be specified + # `words` is initialized as an empty list, so its type must be specified words: List[str] # The type could potentially be inferred if `a_dict` (below) was not @@ -1284,13 +1292,13 @@ Disable JIT for Debugging traced_fn(torch.rand(3, 4)) Debugging this script with ``pdb`` works except for when we invoke the :func:`@torch.jit.script ` - function. We can globally disable JIT, so that we can call the ``@torch.jit.script`` + function. We can globally disable JIT, so that we can call the :func:`@torch.jit.script ` function as a normal Python function and not compile it. If the above script is called ``disable_jit_example.py``, we can invoke it like so:: $ PYTORCH_JIT=0 python disable_jit_example.py - and we will be able to step into the ``@torch.jit.script`` function as a normal Python + and we will be able to step into the :func:`@torch.jit.script ` function as a normal Python function. To disable the TorchScript compiler for a specific function, see :func:`@torch.jit.ignore `. @@ -1298,7 +1306,7 @@ Disable JIT for Debugging Inspecting Code ^^^^^^^^^^^^^^^ -TorchScript provides a code pretty-printer for all ``ScriptModule`` instances. This +TorchScript provides a code pretty-printer for all :class:`ScriptModule` instances. This pretty-printer gives an interpretation of the script method's code as valid Python syntax. For example: @@ -1322,11 +1330,11 @@ Python syntax. For example: ... -A ``ScriptModule`` with a single ``forward`` method will have an attribute -``code``, which you can use to inspect the ``ScriptModule``'s code. -If the ``ScriptModule`` has more than one method, you will need to access +A :class:`ScriptModule` with a single ``forward`` method will have an attribute +``code``, which you can use to inspect the :class:`ScriptModule`'s code. +If the :class:`ScriptModule` has more than one method, you will need to access ``.code`` on the method itself and not the module. We can inspect the -code of a method named ``bar`` on a ScriptModule by accessing ``.bar.code``. +code of a method named ``foo`` on a ScriptModule by accessing ``.foo.code``. The example above produces this output: :: def foo(len: int) -> Tensor: @@ -1419,7 +1427,7 @@ operators are formatted to reflect their equivalent source code forms to facilitate easy debugging. Graphs can be inspected as shown to confirm that the computation described -by a ``ScriptModule`` is correct, in both automated and manual fashion, as +by a :class:`ScriptModule` is correct, in both automated and manual fashion, as described below. @@ -1638,7 +1646,7 @@ best practices? the correct device information. -Q: How do I store attributes on a ``ScriptModule``? +Q: How do I store attributes on a :class:`ScriptModule`? Say we have a model like: @@ -1658,7 +1666,7 @@ Q: How do I store attributes on a ``ScriptModule``? If ``Model`` is instantiated it will result in a compilation error since the compiler doesn't know about ``x``. There are 4 ways to inform the - compiler of attributes on ``ScriptModule``: + compiler of attributes on :class:`ScriptModule`: 1. ``nn.Parameter`` - Values wrapped in ``nn.Parameter`` will work as they do on ``nn.Module``\s diff --git a/docs/source/name_inference.rst b/docs/source/name_inference.rst new file mode 100644 index 00000000000..5588a863344 --- /dev/null +++ b/docs/source/name_inference.rst @@ -0,0 +1,468 @@ +.. currentmodule:: torch + +.. _name_inference_reference-doc: + +Named Tensors operator coverage +=============================== + +Please read :ref:`named_tensors-doc` first for an introduction to named tensors. + +This document is a reference for *name inference*, a process that defines how +named tensors: + +1. use names to provide additional automatic runtime correctness checks +2. propagate names from input tensors to output tensors + +Below is a list of all operations that are supported with named tensors +and their associated name inference rules. + +If you don't see an operation listed here, but it would help your use case, please +`search if an issue has already been filed `_ and if not, `file one `_. + +.. warning:: + The named tensor API is experimental and subject to change. + +.. csv-table:: Supported Operations + :header: API, Name inference rule + :widths: 20, 20 + + ":meth:`Tensor.abs`, :func:`torch.abs`",:ref:`keeps_input_names-doc` + :meth:`Tensor.abs_`,:ref:`keeps_input_names-doc` + ":meth:`Tensor.acos`, :func:`torch.acos`",:ref:`keeps_input_names-doc` + :meth:`Tensor.acos_`,:ref:`keeps_input_names-doc` + ":meth:`Tensor.add`, :func:`torch.add`",:ref:`unifies_names_from_inputs-doc` + :meth:`Tensor.add_`,:ref:`unifies_names_from_inputs-doc` + ":meth:`Tensor.addmm`, :func:`torch.addmm`",:ref:`contracts_away_dims-doc` + :meth:`Tensor.addmm_`,:ref:`contracts_away_dims-doc` + ":meth:`Tensor.addmv`, :func:`torch.addmv`",:ref:`contracts_away_dims-doc` + :meth:`Tensor.addmv_`,:ref:`contracts_away_dims-doc` + :meth:`Tensor.align_as`,See documentation + :meth:`Tensor.align_to`,See documentation + ":meth:`Tensor.all`, :func:`torch.all`",None + ":meth:`Tensor.any`, :func:`torch.any`",None + ":meth:`Tensor.asin`, :func:`torch.asin`",:ref:`keeps_input_names-doc` + :meth:`Tensor.asin_`,:ref:`keeps_input_names-doc` + ":meth:`Tensor.atan`, :func:`torch.atan`",:ref:`keeps_input_names-doc` + ":meth:`Tensor.atan2`, :func:`torch.atan2`",:ref:`unifies_names_from_inputs-doc` + :meth:`Tensor.atan2_`,:ref:`unifies_names_from_inputs-doc` + :meth:`Tensor.atan_`,:ref:`keeps_input_names-doc` + ":meth:`Tensor.bernoulli`, :func:`torch.bernoulli`",:ref:`keeps_input_names-doc` + :meth:`Tensor.bernoulli_`,None + :meth:`Tensor.bfloat16`,:ref:`keeps_input_names-doc` + ":meth:`Tensor.bitwise_not`, :func:`torch.bitwise_not`",:ref:`keeps_input_names-doc` + :meth:`Tensor.bitwise_not_`,None + ":meth:`Tensor.bmm`, :func:`torch.bmm`",:ref:`contracts_away_dims-doc` + :meth:`Tensor.bool`,:ref:`keeps_input_names-doc` + :meth:`Tensor.byte`,:ref:`keeps_input_names-doc` + :func:`torch.cat`,:ref:`unifies_names_from_inputs-doc` + :meth:`Tensor.cauchy_`,None + ":meth:`Tensor.ceil`, :func:`torch.ceil`",:ref:`keeps_input_names-doc` + :meth:`Tensor.ceil_`,None + :meth:`Tensor.char`,:ref:`keeps_input_names-doc` + ":meth:`Tensor.chunk`, :func:`torch.chunk`",:ref:`keeps_input_names-doc` + ":meth:`Tensor.clamp`, :func:`torch.clamp`",:ref:`keeps_input_names-doc` + :meth:`Tensor.clamp_`,None + :meth:`Tensor.copy_`,:ref:`out_function_semantics-doc` + ":meth:`Tensor.cos`, :func:`torch.cos`",:ref:`keeps_input_names-doc` + :meth:`Tensor.cos_`,None + ":meth:`Tensor.cosh`, :func:`torch.cosh`",:ref:`keeps_input_names-doc` + :meth:`Tensor.cosh_`,None + :meth:`Tensor.cpu`,:ref:`keeps_input_names-doc` + :meth:`Tensor.cuda`,:ref:`keeps_input_names-doc` + ":meth:`Tensor.cumprod`, :func:`torch.cumprod`",:ref:`removes_dimensions-doc` + ":meth:`Tensor.cumsum`, :func:`torch.cumsum`",:ref:`removes_dimensions-doc` + :meth:`Tensor.data_ptr`,None + ":meth:`Tensor.detach`, :func:`torch.detach`",:ref:`keeps_input_names-doc` + :meth:`Tensor.detach_`,None + ":attr:`Tensor.device`, :func:`torch.device`",None + ":meth:`Tensor.digamma`, :func:`torch.digamma`",:ref:`keeps_input_names-doc` + :meth:`Tensor.digamma_`,None + :meth:`Tensor.dim`,None + ":meth:`Tensor.div`, :func:`torch.div`",:ref:`unifies_names_from_inputs-doc` + :meth:`Tensor.div_`,:ref:`unifies_names_from_inputs-doc` + ":meth:`Tensor.dot`, :func:`torch.dot`",None + :meth:`Tensor.double`,:ref:`keeps_input_names-doc` + :meth:`Tensor.element_size`,None + :func:`torch.empty`,:ref:`factory-doc` + :func:`torch.empty_like`,:ref:`factory-doc` + ":meth:`Tensor.eq`, :func:`torch.eq`",:ref:`unifies_names_from_inputs-doc` + ":meth:`Tensor.erf`, :func:`torch.erf`",:ref:`keeps_input_names-doc` + :meth:`Tensor.erf_`,None + ":meth:`Tensor.erfc`, :func:`torch.erfc`",:ref:`keeps_input_names-doc` + :meth:`Tensor.erfc_`,None + ":meth:`Tensor.erfinv`, :func:`torch.erfinv`",:ref:`keeps_input_names-doc` + :meth:`Tensor.erfinv_`,None + ":meth:`Tensor.exp`, :func:`torch.exp`",:ref:`keeps_input_names-doc` + :meth:`Tensor.exp_`,None + :meth:`Tensor.expand`,:ref:`keeps_input_names-doc` + ":meth:`Tensor.expm1`, :func:`torch.expm1`",:ref:`keeps_input_names-doc` + :meth:`Tensor.expm1_`,None + :meth:`Tensor.exponential_`,None + :meth:`Tensor.fill_`,None + ":meth:`Tensor.flatten`, :func:`torch.flatten`",See documentation + :meth:`Tensor.float`,:ref:`keeps_input_names-doc` + ":meth:`Tensor.floor`, :func:`torch.floor`",:ref:`keeps_input_names-doc` + :meth:`Tensor.floor_`,None + ":meth:`Tensor.frac`, :func:`torch.frac`",:ref:`keeps_input_names-doc` + :meth:`Tensor.frac_`,None + ":meth:`Tensor.ge`, :func:`torch.ge`",:ref:`unifies_names_from_inputs-doc` + ":meth:`Tensor.get_device`, :func:`torch.get_device`",None + :attr:`Tensor.grad`,None + ":meth:`Tensor.gt`, :func:`torch.gt`",:ref:`unifies_names_from_inputs-doc` + :meth:`Tensor.half`,:ref:`keeps_input_names-doc` + :meth:`Tensor.has_names`,See documentation + ":meth:`Tensor.index_fill`, :func:`torch.index_fill`",:ref:`keeps_input_names-doc` + :meth:`Tensor.index_fill_`,None + :meth:`Tensor.int`,:ref:`keeps_input_names-doc` + :meth:`Tensor.is_contiguous`,None + :attr:`Tensor.is_cuda`,None + ":meth:`Tensor.is_floating_point`, :func:`torch.is_floating_point`",None + :attr:`Tensor.is_leaf`,None + :meth:`Tensor.is_pinned`,None + :meth:`Tensor.is_shared`,None + ":meth:`Tensor.is_signed`, :func:`torch.is_signed`",None + :attr:`Tensor.is_sparse`,None + :func:`torch.is_tensor`,None + :meth:`Tensor.item`,None + ":meth:`Tensor.kthvalue`, :func:`torch.kthvalue`",:ref:`removes_dimensions-doc` + ":meth:`Tensor.le`, :func:`torch.le`",:ref:`unifies_names_from_inputs-doc` + ":meth:`Tensor.log`, :func:`torch.log`",:ref:`keeps_input_names-doc` + ":meth:`Tensor.log10`, :func:`torch.log10`",:ref:`keeps_input_names-doc` + :meth:`Tensor.log10_`,None + ":meth:`Tensor.log1p`, :func:`torch.log1p`",:ref:`keeps_input_names-doc` + :meth:`Tensor.log1p_`,None + ":meth:`Tensor.log2`, :func:`torch.log2`",:ref:`keeps_input_names-doc` + :meth:`Tensor.log2_`,None + :meth:`Tensor.log_`,None + :meth:`Tensor.log_normal_`,None + ":meth:`Tensor.logical_not`, :func:`torch.logical_not`",:ref:`keeps_input_names-doc` + :meth:`Tensor.logical_not_`,None + ":meth:`Tensor.logsumexp`, :func:`torch.logsumexp`",:ref:`removes_dimensions-doc` + :meth:`Tensor.long`,:ref:`keeps_input_names-doc` + ":meth:`Tensor.lt`, :func:`torch.lt`",:ref:`unifies_names_from_inputs-doc` + :func:`torch.manual_seed`,None + ":meth:`Tensor.masked_fill`, :func:`torch.masked_fill`",:ref:`keeps_input_names-doc` + :meth:`Tensor.masked_fill_`,None + ":meth:`Tensor.masked_select`, :func:`torch.masked_select`",Aligns mask up to input and then unifies_names_from_input_tensors + ":meth:`Tensor.matmul`, :func:`torch.matmul`",:ref:`contracts_away_dims-doc` + ":meth:`Tensor.mean`, :func:`torch.mean`",:ref:`removes_dimensions-doc` + ":meth:`Tensor.median`, :func:`torch.median`",:ref:`removes_dimensions-doc` + ":meth:`Tensor.mm`, :func:`torch.mm`",:ref:`contracts_away_dims-doc` + ":meth:`Tensor.mode`, :func:`torch.mode`",:ref:`removes_dimensions-doc` + ":meth:`Tensor.mul`, :func:`torch.mul`",:ref:`unifies_names_from_inputs-doc` + :meth:`Tensor.mul_`,:ref:`unifies_names_from_inputs-doc` + ":meth:`Tensor.mv`, :func:`torch.mv`",:ref:`contracts_away_dims-doc` + :attr:`Tensor.names`,See documentation + ":meth:`Tensor.narrow`, :func:`torch.narrow`",:ref:`keeps_input_names-doc` + :attr:`Tensor.ndim`,None + :meth:`Tensor.ndimension`,None + ":meth:`Tensor.ne`, :func:`torch.ne`",:ref:`unifies_names_from_inputs-doc` + ":meth:`Tensor.neg`, :func:`torch.neg`",:ref:`keeps_input_names-doc` + :meth:`Tensor.neg_`,None + :func:`torch.normal`,:ref:`keeps_input_names-doc` + :meth:`Tensor.normal_`,None + ":meth:`Tensor.numel`, :func:`torch.numel`",None + :func:`torch.ones`,:ref:`factory-doc` + ":meth:`Tensor.pow`, :func:`torch.pow`",:ref:`unifies_names_from_inputs-doc` + :meth:`Tensor.pow_`,None + ":meth:`Tensor.prod`, :func:`torch.prod`",:ref:`removes_dimensions-doc` + :func:`torch.rand`,:ref:`factory-doc` + :func:`torch.rand`,:ref:`factory-doc` + :func:`torch.randn`,:ref:`factory-doc` + :func:`torch.randn`,:ref:`factory-doc` + :meth:`Tensor.random_`,None + ":meth:`Tensor.reciprocal`, :func:`torch.reciprocal`",:ref:`keeps_input_names-doc` + :meth:`Tensor.reciprocal_`,None + :meth:`Tensor.refine_names`,See documentation + :meth:`Tensor.register_hook`,None + :meth:`Tensor.rename`,See documentation + :meth:`Tensor.rename_`,See documentation + :attr:`Tensor.requires_grad`,None + :meth:`Tensor.requires_grad_`,None + :meth:`Tensor.resize_`,Only allow resizes that do not change shape + :meth:`Tensor.resize_as_`,Only allow resizes that do not change shape + ":meth:`Tensor.round`, :func:`torch.round`",:ref:`keeps_input_names-doc` + :meth:`Tensor.round_`,None + ":meth:`Tensor.rsqrt`, :func:`torch.rsqrt`",:ref:`keeps_input_names-doc` + :meth:`Tensor.rsqrt_`,None + ":meth:`Tensor.select`, :func:`torch.select`",:ref:`removes_dimensions-doc` + :meth:`Tensor.short`,:ref:`keeps_input_names-doc` + ":meth:`Tensor.sigmoid`, :func:`torch.sigmoid`",:ref:`keeps_input_names-doc` + :meth:`Tensor.sigmoid_`,None + ":meth:`Tensor.sign`, :func:`torch.sign`",:ref:`keeps_input_names-doc` + :meth:`Tensor.sign_`,None + ":meth:`Tensor.sin`, :func:`torch.sin`",:ref:`keeps_input_names-doc` + :meth:`Tensor.sin_`,None + ":meth:`Tensor.sinh`, :func:`torch.sinh`",:ref:`keeps_input_names-doc` + :meth:`Tensor.sinh_`,None + :meth:`Tensor.size`,None + ":meth:`Tensor.split`, :func:`torch.split`",:ref:`keeps_input_names-doc` + ":meth:`Tensor.sqrt`, :func:`torch.sqrt`",:ref:`keeps_input_names-doc` + :meth:`Tensor.sqrt_`,None + ":meth:`Tensor.squeeze`, :func:`torch.squeeze`",:ref:`removes_dimensions-doc` + ":meth:`Tensor.std`, :func:`torch.std`",:ref:`removes_dimensions-doc` + :func:`torch.std_mean`,:ref:`removes_dimensions-doc` + :meth:`Tensor.stride`,None + ":meth:`Tensor.sub`, :func:`torch.sub`",:ref:`unifies_names_from_inputs-doc` + :meth:`Tensor.sub_`,:ref:`unifies_names_from_inputs-doc` + ":meth:`Tensor.sum`, :func:`torch.sum`",:ref:`removes_dimensions-doc` + ":meth:`Tensor.tan`, :func:`torch.tan`",:ref:`keeps_input_names-doc` + :meth:`Tensor.tan_`,None + ":meth:`Tensor.tanh`, :func:`torch.tanh`",:ref:`keeps_input_names-doc` + :meth:`Tensor.tanh_`,None + :func:`torch.tensor`,:ref:`factory-doc` + :meth:`Tensor.to`,:ref:`keeps_input_names-doc` + ":meth:`Tensor.topk`, :func:`torch.topk`",:ref:`removes_dimensions-doc` + ":meth:`Tensor.transpose`, :func:`torch.transpose`",:ref:`permutes_dimensions-doc` + ":meth:`Tensor.trunc`, :func:`torch.trunc`",:ref:`keeps_input_names-doc` + :meth:`Tensor.trunc_`,None + :meth:`Tensor.type`,None + :meth:`Tensor.type_as`,:ref:`keeps_input_names-doc` + ":meth:`Tensor.unbind`, :func:`torch.unbind`",:ref:`removes_dimensions-doc` + :meth:`Tensor.unflatten`,See documentation + :meth:`Tensor.uniform_`,None + ":meth:`Tensor.var`, :func:`torch.var`",:ref:`removes_dimensions-doc` + :func:`torch.var_mean`,:ref:`removes_dimensions-doc` + :meth:`Tensor.zero_`,None + :func:`torch.zeros`,:ref:`factory-doc` + + +.. _keeps_input_names-doc: + +Keeps input names +^^^^^^^^^^^^^^^^^ + +All pointwise unary functions follow this rule as well as some other unary functions. + +- Check names: None +- Propagate names: input tensor's names are propagated to the output. + +:: + + >>> x = torch.randn(3, 3, names=('N', 'C')) + >>> x.abs().names + ('N', 'C') + +.. _removes_dimensions-doc: + +Removes dimensions +^^^^^^^^^^^^^^^^^^ + +All reduction ops like :meth:`~Tensor.sum` remove dimensions by reducing +over the desired dimensions. Other operations like :meth:`~Tensor.select` and +:meth:`~Tensor.squeeze` remove dimensions. + +Wherever one can pass an integer dimension index to an operator, one can also pass +a dimension name. Functions that take lists of dimension indices can also take in a +list of dimension names. + +- Check names: If :attr:`dim` or :attr:`dims` is passed in as a list of names, + check that those names exist in :attr:`self`. +- Propagate names: If the dimensions of the input tensor specified by :attr:`dim` + or :attr:`dims` are not present in the output tensor, then the corresponding names + of those dimensions do not appear in ``output.names``. + +:: + + >>> x = torch.randn(1, 3, 3, 3, names=('N', 'C', 'H', 'W')) + >>> x.squeeze('N').names + ('C', 'H', 'W') + + >>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W')) + >>> x.sum(['N', 'C']).names + ('H', 'W') + + # Reduction ops with keepdim=True don't actually remove dimensions. + >>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W')) + >>> x.sum(['N', 'C'], keepdim=True).names + ('N', 'C', 'H', 'W') + + +.. _unifies_names_from_inputs-doc: + +Unifies names from inputs +^^^^^^^^^^^^^^^^^^^^^^^^^ + +All binary arithmetic ops follow this rule. Operations that broadcast still +broadcast positionally from the right to preserve compatibility with unnamed +tensors. To perform explicit broadcasting by names, use :meth:`Tensor.align_as`. + +- Check names: All names must match positionally from the right. i.e., in + ``tensor + other``, ``match(tensor.names[i], other.names[i])`` must be true for all + ``i`` in ``(-min(tensor.dim(), other.dim()) + 1, -1]``. +- Check names: Furthermore, all named dimensions must be aligned from the right. + During matching, if we match a named dimension ``A`` with an unnamed dimension + ``None``, then ``A`` must not appear in the tensor with the unnamed dimension. +- Propagate names: unify pairs of names from the right from both tensors to + produce output names. + +For example, + +:: + + # tensor: Tensor[ N, None] + # other: Tensor[None, C] + >>> tensor = torch.randn(3, 3, names=('N', None)) + >>> other = torch.randn(3, 3, names=(None, 'C')) + >>> (tensor + other).names + ('N', 'C') + +Check names: + +- ``match(tensor.names[-1], other.names[-1])`` is ``True`` +- ``match(tensor.names[-2], tensor.names[-2])`` is ``True`` +- Because we matched ``None`` in :attr:`tensor` with ``'C'``, + check to make sure ``'C'`` doesn't exist in :attr:`tensor` (it does not). +- Check to make sure ``'N'`` doesn't exists in :attr:`other` (it does not). + +Finally, the output names are computed with +``[unify('N', None), unify(None, 'C')] = ['N', 'C']`` + +More examples:: + + # Dimensions don't match from the right: + # tensor: Tensor[N, C] + # other: Tensor[ N] + >>> tensor = torch.randn(3, 3, names=('N', 'C')) + >>> other = torch.randn(3, names=('N',)) + >>> (tensor + other).names + RuntimeError: Error when attempting to broadcast dims ['N', 'C'] and dims + ['N']: dim 'C' and dim 'N' are at the same position from the right but do + not match. + + # Dimensions aren't aligned when matching tensor.names[-1] and other.names[-1]: + # tensor: Tensor[N, None] + # other: Tensor[ N] + >>> tensor = torch.randn(3, 3, names=('N', None)) + >>> other = torch.randn(3, names=('N',)) + >>> (tensor + other).names + RuntimeError: Misaligned dims when attempting to broadcast dims ['N'] and + dims ['N', None]: dim 'N' appears in a different position from the right + across both lists. + +.. note:: + + In both of the last examples, it is possible to align the tensors by names + and then perform the addition. Use :meth:`Tensor.align_as` to align + tensors by name or :meth:`Tensor.align_to` to align tensors to a custom + dimension ordering. + +.. _permutes_dimensions-doc: + +Permutes dimensions +^^^^^^^^^^^^^^^^^^^ + +Some operations, like :meth:`Tensor.t()`, permute the order of dimensions. Dimension names +are attached to individual dimensions so they get permuted as well. + +If the operator takes in positional index :attr:`dim`, it is also able to take a dimension +name as :attr:`dim`. + +- Check names: If :attr:`dim` is passed as a name, check that it exists in the tensor. +- Propagate names: Permute dimension names in the same way as the dimensions that are + being permuted. + +:: + + >>> x = torch.randn(3, 3, names=('N', 'C')) + >>> x.transpose('N', 'C').names + ('C', 'N') + +.. _contracts_away_dims-doc: + +Contracts away dims +^^^^^^^^^^^^^^^^^^^ + +Matrix multiply functions follow some variant of this. Let's go through +:func:`torch.mm` first and then generalize the rule for batch matrix multiplication. + +For ``torch.mm(tensor, other)``: + +- Check names: None +- Propagate names: result names are ``(tensor.names[-2], other.names[-1])``. + +:: + + >>> x = torch.randn(3, 3, names=('N', 'D')) + >>> y = torch.randn(3, 3, names=('in', 'out')) + >>> x.mm(y).names + ('N', 'out') + +Inherently, a matrix multiplication performs a dot product over two dimensions, +collapsing them. When two tensors are matrix-multipled, the contracted dimensions +disappear and do not show up in the output tensor. + +:func:`torch.mv`, :func:`torch.dot` work in a similar way: name inference does not +check input names and removes the dimensions that are involved in the dot product: + +:: + + >>> x = torch.randn(3, 3, names=('N', 'D')) + >>> y = torch.randn(3, names=('something',)) + >>> x.mv(y).names + ('N',) + +Now, let's take a look at ``torch.matmul(tensor, other)``. Assume that ``tensor.dim() >= 2`` +and ``other.dim() >= 2``. + +- Check names: Check that the batch dimensions of the inputs are aligned and broadcastable. + See :ref:`unifies_names_from_inputs-doc` for what it means for the inputs to be aligned. +- Propagate names: result names are obtained by unifying the batch dimensions and removing + the contracted dimensions: + ``unify(tensor.names[:-2], other.names[:-2]) + (tensor.names[-2], other.names[-1])``. + +Examples:: + + # Batch matrix multiply of matrices Tensor['C', 'D'] and Tensor['E', 'F']. + # 'A', 'B' are batch dimensions. + >>> x = torch.randn(3, 3, 3, 3, names=('A', 'B', 'C', 'D)) + >>> y = torch.randn(3, 3, 3, names=('B', 'E', 'F)) + >>> torch.matmul(x, y).names + ('A', 'B', 'C', 'F') + + +Finally, there are fused ``add`` versions of many matmul functions. i.e., :func:`addmm` +and :func:`addmv`. These are treated as composing name inference for i.e. :func:`mm` and +name inference for :func:`add`. + +.. _factory-doc: + +Factory functions +^^^^^^^^^^^^^^^^^ + + +Factory functions now take a new :attr:`names` argument that associates a name +with each dimension. + +:: + + >>> torch.zeros(2, 3, names=('N', 'C')) + tensor([[0., 0., 0.], + [0., 0., 0.]], names=('N', 'C')) + +.. _out_function_semantics-doc: + +out function and in-place variants +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +A tensor specified as an ``out=`` tensor has the following behavior: + +- If it has no named dimensions, then the names computed from the operation + get propagated to it. +- If it has any named dimensions, then the names computed from the operation + must be exactly equal to the existing names. Otherwise, the operation errors. + +All in-place methods modify inputs to have names equal to the computed names +from name inference. For example, + +:: + + >>> x = torch.randn(3, 3) + >>> y = torch.randn(3, 3, names=('N', 'C')) + >>> x.names + (None, None) + + >>> x += y + >>> x.names + ('N', 'C') + diff --git a/docs/source/named_tensor.rst b/docs/source/named_tensor.rst new file mode 100644 index 00000000000..eeeaf05daed --- /dev/null +++ b/docs/source/named_tensor.rst @@ -0,0 +1,319 @@ +.. currentmodule:: torch + +.. _named_tensors-doc: + +Named Tensors +============= + +Named Tensors aim to make tensors easier to use by allowing users to associate +explicit names with tensor dimensions. In most cases, operations that take +dimension parameters will accept dimension names, avoiding the need to track +dimensions by position. In addition, named tensors use names to automatically +check that APIs are being used correctly at runtime, providing extra safety. +Names can also be used to rearrange dimensions, for example, to support +"broadcasting by name" rather than "broadcasting by position". + +.. warning:: + The named tensor API is experimental and subject to change. + +Creating named tensors +---------------------- + +Factory functions now take a new :attr:`names` argument that associates a name +with each dimension. + +:: + + >>> torch.zeros(2, 3, names=('N', 'C')) + tensor([[0., 0., 0.], + [0., 0., 0.]], names=('N', 'C')) + +Named dimensions, like regular Tensor dimensions, are ordered. +``tensor.names[i]`` is the name of dimension ``i`` of ``tensor``. + +The following factory functions support named tensors: + +- :func:`torch.empty` +- :func:`torch.rand` +- :func:`torch.randn` +- :func:`torch.ones` +- :func:`torch.tensor` +- :func:`torch.zeros` + +Named dimensions +---------------- + +See :attr:`~Tensor.names` for restrictions on tensor names. + +Use :attr:`~Tensor.names` to access the dimension names of a tensor and +:meth:`~Tensor.rename` to rename named dimensions. + +:: + + >>> imgs = torch.randn(1, 2, 2, 3 , names=('N', 'C', 'H', 'W')) + >>> imgs.names + ('N', 'C', 'H', 'W') + + >>> renamed_imgs = imgs.rename(H='height', W='width') + >>> renamed_imgs.names + ('N', 'C', 'height', 'width) + + +Named tensors can coexist with unnamed tensors; named tensors are instances of +:class:`torch.Tensor`. Unnamed tensors have ``None``-named dimensions. Named +tensors do not require all dimensions to be named. + +:: + + >>> imgs = torch.randn(1, 2, 2, 3 , names=(None, 'C', 'H', 'W')) + >>> imgs.names + (None, 'C', 'H', 'W') + +Name propagation semantics +-------------------------- + +Named tensors use names to automatically check that APIs are being called +correctly at runtime. This occurs in a process called *name inference*. +More formally, name inference consists of the following two steps: + +- **Check names**: an operator may perform automatic checks at runtime that + check that certain dimension names must match. +- **Propagate names**: name inference propagates names to output tensors. + +All operations that support named tensors propagate names. + +:: + + >>> x = torch.randn(3, 3, names=('N', 'C')) + >>> x.abs().names + ('N', 'C') + + +.. _match_semantics-doc: + +match semantics +^^^^^^^^^^^^^^^ + +Two names *match* if they are equal (string equality) or if at least one is ``None``. +Nones are essentially a special "wildcard" name. + +``unify(A, B)`` determines which of the names ``A`` and ``B`` to propagate to the outputs. +It returns the more *specific* of the two names, if they match. If the names do not match, +then it errors. + +.. note:: + In practice, when working with named tensors, one should avoid having unnamed + dimensions because their handling can be complicated. It is recommended to lift + all unnamed dimensions to be named dimensions by using :meth:`~Tensor.refine_names`. + + +Basic name inference rules +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Let's see how ``match`` and ``unify`` are used in name inference in the case of +adding two one-dim tensors with no broadcasting. + +:: + + x = torch.randn(3, names=('X',)) + y = torch.randn(3) + z = torch.randn(3, names=('Z',)) + +**Check names**: check that the names of the two tensors *match*. + +For the following examples: + +:: + + >>> # x + y # match('X', None) is True + >>> # x + z # match('X', 'Z') is False + >>> # x + x # match('X', 'X') is True + + >>> x + z + Error when attempting to broadcast dims ['X'] and dims ['Z']: dim 'X' and dim 'Z' are at the same position from the right but do not match. + +**Propagate names**: *unify* the names to select which one to propagate. +In the case of ``x + y``, ``unify('X', None) = 'X'`` because ``'X'`` is more +specific than ``None``. + +:: + + >>> (x + y).names + ('X',) + >>> (x + x).names + ('X',) + +For a comprehensive list of name inference rules, see :ref:`name_inference_reference-doc`. +Here are two common operations that may be useful to go over: + +- Binary arithmetic ops: :ref:`unifies_names_from_inputs-doc` +- Matrix multiplication ops: :ref:`contracts_away_dims-doc` + +Explicit alignment by names +--------------------------- + +Use :meth:`~Tensor.align_as` or :meth:`~Tensor.align_to` to align tensor dimensions +by name to a specified ordering. This is useful for performing "broadcasting by names". + +:: + + # This function is agnostic to the dimension ordering of `input`, + # as long as it has a `C` dimension somewhere. + def scale_channels(input, scale): + scale = scale.refine_names('C') + return input * scale.align_as(input) + + >>> num_channels = 3 + >>> scale = torch.randn(num_channels, names='C') + >>> imgs = torch.rand(3, 3, 3, num_channels, names=('N', 'H', 'W', 'C')) + >>> more_imgs = torch.rand(3, num_channels, 3, 3, names=('N', 'C', 'H', 'W')) + >>> videos = torch.randn(3, num_channels, 3, 3, 3, names=('N', 'C', 'H', 'W', 'D') + + >>> scale_channels(imgs, scale) + >>> scale_channels(more_imgs, scale) + >>> scale_channels(videos, scale) + +Manipulating dimensions +----------------------- + +Use :meth:`~Tensor.align_to` to permute large amounts of dimensions without +mentioning all of them as in required by :meth:`~Tensor.permute`. + +:: + + >>> tensor = torch.randn(2, 2, 2, 2, 2, 2) + >>> named_tensor = tensor.refine_names('A', 'B', 'C', 'D', 'E', 'F') + + # Move the F (dim 5) and E dimension (dim 4) to the front while keeping + # the rest in the same order + >>> tensor.permute(5, 4, 0, 1, 2, 3) + >>> named_tensor.align_to('F', 'E', ...) # Use '...' instead in Python 2 + +Use :meth:`~Tensor.flatten` and :meth:`~Tensor.unflatten` to flatten and unflatten +dimensions, respectively. These methods are more verbose than :meth:`~Tensor.view` +and :meth:`~Tensor.reshape`, but have more semantic meaning to someone reading the code. + +:: + + >>> imgs = torch.randn(32, 3, 128, 128) + >>> named_imgs = imgs.refine_names('N', 'C', 'H', 'W') + + >>> flat_imgs = imgs.view(32, -1) + >>> named_flat_imgs = named_imgs.flatten(['C', 'H', 'W'], 'features') + >>> named_flat_imgs.names + ('N', 'features') + + >>> unflattened_imgs = imgs.view(32, 3, 128, 128) + >>> unflattened_named_imgs = named_flat_imgs.unflatten( + 'features', [('C', 3), ('H', 128), ('W', 128)]) + +.. _named_tensors_autograd-doc: + +Autograd support +---------------- + +Autograd currently supports named tensors in a limited manner: autograd ignores +names on all tensors. Gradient computation is still correct but we lose the +safety that names give us. + +:: + + >>> x = torch.randn(3, names=('D',)) + >>> weight = torch.randn(3, names=('D',), requires_grad=True) + >>> loss = (x - weight).abs() + >>> grad_loss = torch.randn(3) + >>> loss.backward(grad_loss) + >>> weight.grad # Unnamed for now. Will be named in the future + tensor([-1.8107, -0.6357, 0.0783]) + + >>> weight.grad.zero_() + >>> grad_loss = grad_loss.refine_names('C') + >>> loss = (x - weight).abs() + # Ideally we'd check that the names of loss and grad_loss match but we don't yet. + >>> loss.backward(grad_loss) + >>> weight.grad + tensor([-1.8107, -0.6357, 0.0783]) + +Currently supported operations and subsystems +--------------------------------------------- + +Operators +^^^^^^^^^ + +See :ref:`name_inference_reference-doc` for a full list of the supported torch and +tensor operations. We do not yet support the following that is not covered by the link: + +- indexing, advanced indexing. + +For ``torch.nn.functional`` operators, we support the following: + +- :func:`torch.nn.functional.relu` +- :func:`torch.nn.functional.softmax` +- :func:`torch.nn.functional.log_softmax` +- :func:`torch.nn.functional.tanh` +- :func:`torch.nn.functional.sigmoid` +- :func:`torch.nn.functional.dropout` + +Subsystems +^^^^^^^^^^ + +Autograd is supported, see :ref:`named_tensors_autograd-doc`. +Because gradients are currently unnamed, optimizers may work but are untested. + +NN modules are currently unsupported. This can lead to the following when calling +modules with named tensor inputs: + +- NN module parameters are unnamed, so outputs may be partially named. +- NN module forward passes have code that don't support named tensors and will + error out appropriately. + +We also do not support the following subsystems, though some may work out +of the box: + +- distributions +- serialization (:func:`torch.load`, :func:`torch.save`) +- multiprocessing +- JIT +- distributed +- ONNX + +If any of these would help your use case, please +`search if an issue has already been filed `_ +and if not, `file one `_. + +Named tensor API reference +-------------------------- + +In this section please find the documentation for named tensor specific APIs. +For a comprehensive reference for how names are propagated through other PyTorch +operators, see :ref:`name_inference_reference-doc`. + +.. class:: Tensor() + :noindex: + + .. autoattribute:: names + .. automethod:: rename + .. automethod:: rename_ + .. automethod:: refine_names + + .. automethod:: align_as + .. automethod:: align_to + + .. automethod:: unflatten + .. py:method:: flatten(dims, out_dim) -> Tensor + + Flattens :attr:`dims` into a single dimension with name :attr:`out_dim`. + + All of `dims` must be consecutive in order in the :attr:`self` tensor, + but not necessary contiguous in memory. + + Examples:: + + >>> imgs = torch.randn(32, 3, 128, 128, names=('N', 'C', 'H', 'W')) + >>> flat_imgs = imgs.flatten(['C', 'H', 'W'], 'features') + >>> flat_imgs.names, flat_imgs.shape + (('N', 'features'), torch.Size([32, 49152])) + + .. warning:: + The named tensor API is experimental and subject to change. + diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index 0de8ad313b6..285d647fb9e 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -170,6 +170,7 @@ view of a storage and defines numeric operations on it. .. automethod:: addr .. automethod:: addr_ .. automethod:: allclose + .. automethod:: angle .. automethod:: apply_ .. automethod:: argmax .. automethod:: argmin @@ -206,6 +207,7 @@ view of a storage and defines numeric operations on it. .. automethod:: clone .. automethod:: contiguous .. automethod:: copy_ + .. automethod:: conj .. automethod:: cos .. automethod:: cos_ .. automethod:: cosh @@ -276,6 +278,7 @@ view of a storage and defines numeric operations on it. .. automethod:: hardshrink .. automethod:: histc .. automethod:: ifft + .. automethod:: imag .. automethod:: index_add_ .. automethod:: index_add .. automethod:: index_copy_ @@ -381,6 +384,7 @@ view of a storage and defines numeric operations on it. .. automethod:: register_hook .. automethod:: remainder .. automethod:: remainder_ + .. automethod:: real .. automethod:: renorm .. automethod:: renorm_ .. automethod:: repeat diff --git a/docs/source/torch.rst b/docs/source/torch.rst index 0aee7e4bf76..51607e6e5a7 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -188,12 +188,14 @@ Pointwise Ops .. autofunction:: add .. autofunction:: addcdiv .. autofunction:: addcmul +.. autofunction:: angle .. autofunction:: asin .. autofunction:: atan .. autofunction:: atan2 .. autofunction:: bitwise_not .. autofunction:: ceil .. autofunction:: clamp +.. autofunction:: conj .. autofunction:: cos .. autofunction:: cosh .. autofunction:: div @@ -206,6 +208,7 @@ Pointwise Ops .. autofunction:: floor .. autofunction:: fmod .. autofunction:: frac +.. autofunction:: imag .. autofunction:: lerp .. autofunction:: log .. autofunction:: log10 @@ -217,6 +220,7 @@ Pointwise Ops .. autofunction:: mvlgamma .. autofunction:: neg .. autofunction:: pow +.. autofunction:: real .. autofunction:: reciprocal .. autofunction:: remainder .. autofunction:: round diff --git a/test/common_nn.py b/test/common_nn.py index bc9441886e0..694cf596007 100644 --- a/test/common_nn.py +++ b/test/common_nn.py @@ -56,18 +56,6 @@ module_tests = [ desc='no_bias', reference_fn=lambda i, p, _: torch.mm(i, p[0].t()) ), - dict( - module_name='Linear', - constructor_args=(10, 8), - input_size=(0, 10), - desc='zero_batch', - ), - dict( - module_name='Linear', - constructor_args=(10, 8, False), - input_size=(0, 10), - desc='zero_batch_no_bias', - ), dict( module_name='Threshold', constructor_args=(2., 1.), diff --git a/test/cpp/api/functional.cpp b/test/cpp/api/functional.cpp index 9e8129e8a7d..f2a7d7d0a45 100644 --- a/test/cpp/api/functional.cpp +++ b/test/cpp/api/functional.cpp @@ -161,6 +161,17 @@ TEST_F(FunctionalTest, HingeEmbeddingLoss) { ASSERT_TRUE(output.allclose(expected)); } +TEST_F(FunctionalTest, MultiMarginLoss) { + auto weight = torch::tensor({0.3, 0.3, 0.4}, torch::kFloat); + auto input = torch::tensor({{0.2, 0.2, 0.6}, {0.1, 0.8, 0.1}, {0.9, 0.09, 0.01}}, torch::requires_grad()); + auto target = torch::tensor({2, 1, 0}, torch::kLong); + auto output = F::multi_margin_loss( + input, target, MultiMarginLossOptions().margin(2).weight(weight)); + auto expected = torch::tensor({0.305556}, torch::kFloat); + + ASSERT_TRUE(output.allclose(expected, 1e-04)); +} + TEST_F(FunctionalTest, CosineEmbeddingLoss) { auto input1 = torch::tensor({{2, 3, 4}, {6, 2, 4}}); auto input2 = torch::tensor({{2, 3, 5}, {9, 12, 0}}); @@ -254,6 +265,32 @@ TEST_F(FunctionalTest, ELU) { } } +TEST_F(FunctionalTest, SELU) { + { + const double scale = 1.0507009873554804934193349852946; + const double alpha = 1.6732632423543772848170429916717; + for (const auto inplace : {false, true}) { + auto input = torch::randn({5, 5}); + auto expected = scale * + (torch::max(torch::zeros_like(input), input) + + torch::min( + torch::zeros_like(input), alpha * (torch::exp(input) - 1))); + auto output = F::selu(input, inplace); + + ASSERT_TRUE(output.allclose(expected)); + if (inplace) { + ASSERT_TRUE(input.allclose(expected)); + } + } + } + { + auto input = torch::arange(0, 9, torch::kDouble).view({3, 3}); + auto output = F::selu(input); + auto expected = F::selu(input, false); + ASSERT_TRUE(output.allclose(expected)); + } +} + TEST_F(FunctionalTest, Hardshrink) { const auto size = 3; for (const auto lambda : {-4.2, -1.0, -0.42, 0.0, 0.42, 1.0, 4.2, 42.42}) { @@ -371,3 +408,14 @@ TEST_F(FunctionalTest, LogSigmoid) { auto y_exp = torch::log(torch::ones_like(x)/(torch::ones_like(x) + torch::exp(torch::neg(x)))); ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7)); } + +TEST_F(FunctionalTest, Softmax) { + auto input = torch::arange(10, torch::kFloat).reshape({2, 5}); + auto output = F::softmax(input, /*dim=*/1); + auto sum = torch::sum(torch::exp(input), 1); + + for (int i = 0; i < 2; i++) { + auto expected = torch::exp(input[i]) / sum[i]; + ASSERT_TRUE(torch::allclose(output[i], expected)); + } +} diff --git a/test/cpp/api/modulelist.cpp b/test/cpp/api/modulelist.cpp index f6a56dd1b57..e8ec5c69b5d 100644 --- a/test/cpp/api/modulelist.cpp +++ b/test/cpp/api/modulelist.cpp @@ -282,7 +282,7 @@ TEST_F(ModuleListTest, PrettyPrintModuleList) { " (1): torch::nn::Conv2d(input_channels=1, output_channels=2, kernel_size=[3, 3], stride=[1, 1])\n" " (2): torch::nn::Dropout(rate=0.5)\n" " (3): torch::nn::BatchNorm(features=5, eps=1e-05, momentum=0.1, affine=true, stateful=true)\n" - " (4): torch::nn::Embedding(count=4, dimension=10)\n" + " (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n" " (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n" ")"); } diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp index 74362fb39ba..6235d716d78 100644 --- a/test/cpp/api/modules.cpp +++ b/test/cpp/api/modules.cpp @@ -800,6 +800,21 @@ TEST_F(ModulesTest, EmbeddingList) { ASSERT_EQ(y.size(2), 4); } +TEST_F(ModulesTest, EmbeddingFromPretrained) { + auto weight = torch::tensor({{1., 2.3, 3.}, {4., 5.1, 6.3}}); + Embedding embedding = torch::nn::Embedding::from_pretrained(weight); + auto input = torch::tensor({1}, torch::kLong); + ASSERT_TRUE(torch::allclose(embedding(input), torch::tensor({4.0000, 5.1000, 6.3000}))); +} + +TEST_F(ModulesTest, EmbeddingBagFromPretrained) { + auto weight = torch::tensor({{1., 2.3, 3.}, {4., 5.1, 6.3}}); + EmbeddingBag embeddingbag = torch::nn::EmbeddingBag::from_pretrained(weight); + auto input = torch::zeros({{1, 2}}, torch::kLong); + input[0] = torch::tensor({1, 0}); + ASSERT_TRUE(torch::allclose(embeddingbag(input), torch::tensor({2.5000, 3.7000, 4.6500}))); +} + TEST_F(ModulesTest, Dropout) { Dropout dropout(0.5); torch::Tensor x = torch::ones(100, torch::requires_grad()); @@ -979,6 +994,20 @@ TEST_F(ModulesTest, HingeEmbeddingLoss) { ASSERT_EQ(input.sizes(), input.grad().sizes()); } +TEST_F(ModulesTest, MultiMarginLoss) { + auto weight = torch::tensor({0.3, 0.3, 0.4}, torch::kFloat); + MultiMarginLoss loss(MultiMarginLossOptions().margin(2).weight(weight)); + auto input = torch::tensor({{0.2, 0.2, 0.6}, {0.1, 0.8, 0.1}, {0.9, 0.09, 0.01}}, torch::requires_grad()); + auto target = torch::tensor({2, 1, 0}, torch::kLong); + auto output = loss->forward(input, target); + auto expected = torch::tensor({0.305556}, torch::kFloat); + auto s = output.sum(); + s.backward(); + + ASSERT_TRUE(output.allclose(expected, 1e-04)); + ASSERT_EQ(input.sizes(), input.grad().sizes()); +} + TEST_F(ModulesTest, CosineEmbeddingLoss) { CosineEmbeddingLoss cos(CosineEmbeddingLossOptions().margin(0.5)); auto input1 = torch::tensor({{2, 3, 4}, {6, 2, 4}}, torch::requires_grad()); @@ -1040,6 +1069,23 @@ TEST_F(ModulesTest, ELU) { } } +TEST_F(ModulesTest, SELU) { + SELU model; + auto input = torch::randn({5, 5}, torch::requires_grad()); + auto output = model->forward(input); + const double scale = 1.0507009873554804934193349852946; + const double alpha = 1.6732632423543772848170429916717; + auto zero = torch::zeros_like(input); + auto expected = scale * + (torch::max(zero, input) + + torch::min(zero, alpha * (torch::exp(input) - 1))); + auto s = output.sum(); + s.backward(); + + ASSERT_EQ(s.ndimension(), 0); + ASSERT_TRUE(output.allclose(expected)); +} + TEST_F(ModulesTest, Hardshrink) { const auto size = 3; for (const auto lambda : {-4.2, -1.0, -0.42, 0.0, 0.42, 1.0, 4.2, 42.42}) { @@ -1131,6 +1177,18 @@ TEST_F(ModulesTest, LogSigmoid) { ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7)); } +TEST_F(ModulesTest, Softmax) { + Softmax m(/*dim=*/1); + auto input = torch::arange(10, torch::kFloat).reshape({2, 5}); + auto output = m(input); + auto sum = torch::sum(torch::exp(input), 1); + + for (int i = 0; i < 2; i++) { + auto expected = torch::exp(input[i]) / sum[i]; + ASSERT_TRUE(torch::allclose(output[i], expected)); + } +} + TEST_F(ModulesTest, PrettyPrintIdentity) { ASSERT_EQ(c10::str(Identity()), "torch::nn::Identity()"); } @@ -1290,8 +1348,29 @@ TEST_F(ModulesTest, PrettyPrintBatchNorm) { TEST_F(ModulesTest, PrettyPrintEmbedding) { ASSERT_EQ( - c10::str(Embedding(10, 2)), - "torch::nn::Embedding(count=10, dimension=2)"); + c10::str(Embedding(EmbeddingOptions(10, 2))), + "torch::nn::Embedding(num_embeddings=10, embedding_dim=2)"); + ASSERT_EQ( + c10::str(Embedding(EmbeddingOptions(10, 2).padding_idx(3).max_norm(2))), + "torch::nn::Embedding(num_embeddings=10, embedding_dim=2, padding_idx=3, max_norm=2)"); + ASSERT_EQ( + c10::str(Embedding(EmbeddingOptions(10, 2).padding_idx(3).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true))), + "torch::nn::Embedding(num_embeddings=10, embedding_dim=2, padding_idx=3, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true)"); +} + +TEST_F(ModulesTest, PrettyPrintEmbeddingBag) { + ASSERT_EQ( + c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2))), + "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2)"); + ASSERT_EQ( + c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2).max_norm(2))), + "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2)"); + ASSERT_EQ( + c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true))), + "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true)"); + ASSERT_EQ( + c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true).mode("sum"))), + "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true, mode=sum)"); } TEST_F(ModulesTest, PrettyPrintHingeEmbeddingLoss) { @@ -1339,7 +1418,7 @@ TEST_F(ModulesTest, PrettyPrintNestedModel) { TestModule() : torch::nn::Module("TestModule"), fc(register_module("fc", torch::nn::Linear(4, 5))), - table(register_module("table", torch::nn::Embedding(10, 2))), + table(register_module("table", torch::nn::Embedding(EmbeddingOptions(10, 2)))), inner(register_module("inner", std::make_shared())) { } @@ -1352,10 +1431,10 @@ TEST_F(ModulesTest, PrettyPrintNestedModel) { c10::str(TestModule{}), "TestModule(\n" " (fc): torch::nn::Linear(in=4, out=5, with_bias=true)\n" - " (table): torch::nn::Embedding(count=10, dimension=2)\n" + " (table): torch::nn::Embedding(num_embeddings=10, embedding_dim=2)\n" " (inner): InnerTestModule(\n" " (fc): torch::nn::Linear(in=3, out=4, with_bias=true)\n" - " (table): torch::nn::Embedding(count=10, dimension=2)\n" + " (table): torch::nn::Embedding(num_embeddings=10, embedding_dim=2)\n" " )\n" ")"); } @@ -1366,6 +1445,12 @@ TEST_F(ModulesTest, PrettyPrintELU) { "torch::nn::ELU(alpha=42.42, inplace=true)"); } +TEST_F(ModulesTest, PrettyPrintSELU) { + ASSERT_EQ(c10::str(SELU()), "torch::nn::SELU()"); + ASSERT_EQ(c10::str(SELU(SELUOptions().inplace(true))), + "torch::nn::SELU(inplace=true)"); +} + TEST_F(ModulesTest, PrettyPrintHardshrink) { ASSERT_EQ(c10::str(Hardshrink()), "torch::nn::Hardshrink(0.5)"); ASSERT_EQ(c10::str(Hardshrink(HardshrinkOptions().lambda(42.42))), @@ -1391,3 +1476,7 @@ TEST_F(ModulesTest, PrettyPrintLeakyReLU) { TEST_F(ModulesTest, PrettyPrintLogSigmoid) { ASSERT_EQ(c10::str(LogSigmoid()), "torch::nn::LogSigmoid()"); } + +TEST_F(ModulesTest, PrettyPrintSoftmax) { + ASSERT_EQ(c10::str(Softmax(SoftmaxOptions(1))), "torch::nn::Softmax(dim=1)"); +} diff --git a/test/cpp/api/sequential.cpp b/test/cpp/api/sequential.cpp index 87c6a57860c..54fb2fde802 100644 --- a/test/cpp/api/sequential.cpp +++ b/test/cpp/api/sequential.cpp @@ -393,7 +393,7 @@ TEST_F(SequentialTest, PrettyPrintSequential) { " (1): torch::nn::Conv2d(input_channels=1, output_channels=2, kernel_size=[3, 3], stride=[1, 1])\n" " (2): torch::nn::Dropout(rate=0.5)\n" " (3): torch::nn::BatchNorm(features=5, eps=1e-05, momentum=0.1, affine=true, stateful=true)\n" - " (4): torch::nn::Embedding(count=4, dimension=10)\n" + " (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n" " (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n" ")"); @@ -412,7 +412,7 @@ TEST_F(SequentialTest, PrettyPrintSequential) { " (conv2d): torch::nn::Conv2d(input_channels=1, output_channels=2, kernel_size=[3, 3], stride=[1, 1])\n" " (dropout): torch::nn::Dropout(rate=0.5)\n" " (batchnorm): torch::nn::BatchNorm(features=5, eps=1e-05, momentum=0.1, affine=true, stateful=true)\n" - " (embedding): torch::nn::Embedding(count=4, dimension=10)\n" + " (embedding): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n" " (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n" ")"); } diff --git a/test/rpc_test.py b/test/rpc_test.py index 505a45bdda0..21abd2b41ae 100644 --- a/test/rpc_test.py +++ b/test/rpc_test.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 from __future__ import absolute_import, division, print_function, unicode_literals import concurrent.futures diff --git a/test/test_dist_autograd_fork.py b/test/test_dist_autograd_fork.py old mode 100644 new mode 100755 diff --git a/test/test_dist_autograd_spawn.py b/test/test_dist_autograd_spawn.py old mode 100644 new mode 100755 diff --git a/test/test_docs_coverage.py b/test/test_docs_coverage.py index 8b95091d5a4..77e896b7ee6 100644 --- a/test/test_docs_coverage.py +++ b/test/test_docs_coverage.py @@ -75,8 +75,13 @@ class TestDocCoverage(unittest.TestCase): def test_tensor(self): in_rst = self.parse_rst('tensors.rst', r2) + whitelist = { + 'names', 'unflatten', 'align_as', 'rename_', 'refine_names', 'align_to', + 'has_names', 'rename', + } classes = [torch.FloatTensor, torch.LongTensor, torch.ByteTensor] has_docstring = set(x for c in classes for x in dir(c) if not x.startswith('_') and getattr(c, x).__doc__) + has_docstring -= whitelist self.assertEqual( has_docstring, in_rst, textwrap.dedent(''' diff --git a/test/test_jit.py b/test/test_jit.py index 700755babaa..8badaf0afea 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3400,6 +3400,7 @@ def foo(x): cu.define(full) def test_namedtuple_python(self): + global MyTuple, MyMod # see [local resolution in python] MyTuple = namedtuple('MyTuple', ['a']) @torch.jit.unused @@ -15000,6 +15001,7 @@ a") self.checkScript(fn, ()) def test_named_tuple_redefine(self): + global _1, _2 _1 = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1']) _2 = namedtuple('GoogLeNetOutputs', ['different']) @@ -15010,6 +15012,7 @@ a") return x def test_named_tuple_py2(self): + global _GoogLeNetOutputs # see [local resolution in python] _GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1']) @torch.jit.script @@ -15024,6 +15027,7 @@ a") self.assertEqual(out.aux_logits1, vals[2]) def test_named_tuple_good_error(self): + global _GoogLeNetOutputs # see [local resolution in python] _GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1']) @torch.jit.script @@ -19370,6 +19374,7 @@ class TestClassType(JitTestCase): self.attr = x def test_class_type_as_param(self): + global FooTest # see [local resolution in python] @torch.jit.script # noqa: B903 class FooTest(object): def __init__(self, x): @@ -19512,6 +19517,7 @@ class TestClassType(JitTestCase): self.assertEqual(2 * input, output) def test_python_interop(self): + global Foo # see [local resolution in python] @torch.jit.script # noqa: B903 class Foo(object): def __init__(self, x, y): @@ -19538,6 +19544,7 @@ class TestClassType(JitTestCase): self.assertEqual(y, f2.y) def test_class_specialization(self): + global Foo # see [local resolution in python] @torch.jit.script # noqa: B903 class Foo(object): def __init__(self, x, y): @@ -19562,6 +19569,7 @@ class TestClassType(JitTestCase): FileCheck().check_count("Double(*, *) = prim::GetAttr", 4).run(graphstr) def test_class_sorting(self): + global Foo # see [local resolution in python] @torch.jit.script # noqa: B903 class Foo(object): def __init__(self, x): @@ -19675,6 +19683,7 @@ class TestClassType(JitTestCase): self.assertEqual(3 * input, output) def test_interface(self): + global Foo, Bar, OneTwo, OneTwoThree, OneTwoWrong, NotMember, NotMember2 @torch.jit.script class Foo(object): def __init__(self): @@ -19836,6 +19845,7 @@ class TestClassType(JitTestCase): # NamedTuple inheritance errors def test_overloaded_fn(self): + global Foo, MyClass # see [local resolution in python] @torch.jit.script class Foo(object): def __init__(self, x): @@ -19991,6 +20001,7 @@ class TestClassType(JitTestCase): return Foo(torch.tensor(1)) + Foo(torch.tensor(1)) def test_cast_overloads(self): + global Foo # see [local resolution in python] @torch.jit.script class Foo(object): def __init__(self, val): diff --git a/test/test_jit_py3.py b/test/test_jit_py3.py index b42d68d8fe5..249a89c8739 100644 --- a/test/test_jit_py3.py +++ b/test/test_jit_py3.py @@ -110,6 +110,8 @@ class TestScriptPy3(JitTestCase): FileCheck().check_not('TupleConstruct').run(foo.graph) def test_named_tuple_type_annotation(self): + global MyCoolNamedTuple # see [local resolution in python] + class MyCoolNamedTuple(NamedTuple): a : int b : float diff --git a/test/test_quantized.py b/test/test_quantized.py index 4b8f312f77c..5255cbd6640 100644 --- a/test/test_quantized.py +++ b/test/test_quantized.py @@ -901,8 +901,6 @@ class TestQuantizedOps(TestCase): self.assertEqual(Y, qY.dequantize()) """Tests the correctness of the quantized equal op.""" - @unittest.skip("temporarily disable until failures are fixed. " + - "See https://github.com/pytorch/pytorch/issues/26279") @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5), qparams=hu.qparams()), X2=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5), @@ -949,6 +947,8 @@ class TestQuantizedOps(TestCase): return False if qX.shape != qX2.shape: return False + if qX.dtype != qX2.dtype: + return False if qX.qscheme() == torch.per_tensor_affine: if qX.q_scale() != qX2.q_scale(): return False diff --git a/test/test_rpc_fork.py b/test/test_rpc_fork.py old mode 100644 new mode 100755 diff --git a/test/test_rpc_spawn.py b/test/test_rpc_spawn.py old mode 100644 new mode 100755 diff --git a/test/test_torch.py b/test/test_torch.py index be0a16ef04a..7d6473db989 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -245,14 +245,6 @@ class _TestTorchMixin(object): 'to_dense', 'sparse_resize_', 'sparse_resize_and_clear_', - 'align_to', # BUILD_NAMEDTENSOR only - 'align_as', # BUILD_NAMEDTENSOR only - 'rename', # BUILD_NAMEDTENSOR only - 'rename_', # BUILD_NAMEDTENSOR only - 'has_names', # BUILD_NAMEDTENSOR only - 'rename', # BUILD_NAMEDTENSOR only - 'refine_names', # BUILD_NAMEDTENSOR only - 'unflatten', # BUILD_NAMEDTENSOR only ) test_namespace(torch.nn) test_namespace(torch.nn.functional, 'assert_int_or_pair', 'feature_alpha_dropout') diff --git a/test/test_type_hints.py b/test/test_type_hints.py index c244bcc6121..6f2eadd5309 100644 --- a/test/test_type_hints.py +++ b/test/test_type_hints.py @@ -61,7 +61,15 @@ def get_all_examples(): This function grabs (hopefully all) examples from the torch documentation strings and puts them in one nonsensical module returned as a string. """ - blacklist = {"_np"} + blacklist = { + "_np", + "refine_names", + "rename", + "names", + "align_as", + "align_to", + "unflatten", + } allexamples = "" example_file_lines = [ diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py old mode 100644 new mode 100755 diff --git a/tools/autograd/templates/python_torch_functions.cpp b/tools/autograd/templates/python_torch_functions.cpp index 9ee4c4f476c..299e82fe42e 100644 --- a/tools/autograd/templates/python_torch_functions.cpp +++ b/tools/autograd/templates/python_torch_functions.cpp @@ -430,6 +430,23 @@ static PyObject * THPVariable_get_device(PyObject* self_, PyObject* args, PyObje END_HANDLE_TH_ERRORS } +static PyObject * THPVariable_numel(PyObject* self_, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "numel(Tensor input)", + }, /*traceable=*/false); + + ParsedArgs<1> parsed_args; + auto r = parser.parse(args, kwargs, parsed_args); + + if (r.idx == 0) { + return wrap(r.tensor(0).numel()); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + // generated methods start here ${py_methods} @@ -448,6 +465,7 @@ static PyMethodDef torch_functions[] = { {"spmm", (PyCFunction)(void(*)(void))THPVariable_mm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, {"tensor", (PyCFunction)(void(*)(void))THPVariable_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, {"get_device", (PyCFunction)(void(*)(void))THPVariable_get_device, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"numel", (PyCFunction)(void(*)(void))THPVariable_numel, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, ${py_method_defs} {NULL} }; diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index e04660f833d..c12d1b8bc24 100644 --- a/tools/autograd/templates/python_variable_methods.cpp +++ b/tools/autograd/templates/python_variable_methods.cpp @@ -181,6 +181,14 @@ static PyObject * THPVariable_dim(PyObject* self, PyObject* args) END_HANDLE_TH_ERRORS } +static PyObject * THPVariable_numel(PyObject* self, PyObject* args) +{ + HANDLE_TH_ERRORS + auto& self_ = reinterpret_cast(self)->cdata; + return THPUtils_packInt64(self_.numel()); + END_HANDLE_TH_ERRORS +} + static Tensor dispatch_contiguous(const Tensor & self, at::MemoryFormat memory_format) { AutoNoGIL no_gil; OptionalDeviceGuard device_guard(device_of(self)); @@ -781,6 +789,7 @@ PyMethodDef variable_methods[] = { {"new_ones", (PyCFunction)(void(*)(void))THPVariable_new_ones, METH_VARARGS | METH_KEYWORDS, NULL}, {"new_tensor", (PyCFunction)(void(*)(void))THPVariable_new_tensor, METH_VARARGS | METH_KEYWORDS, NULL}, {"nonzero", (PyCFunction)(void(*)(void))THPVariable_nonzero, METH_VARARGS | METH_KEYWORDS, NULL}, + {"numel", (PyCFunction)THPVariable_numel, METH_NOARGS, NULL}, {"numpy", (PyCFunction)THPVariable_numpy, METH_NOARGS, NULL}, {"record_stream", (PyCFunction)THPVariable_record_stream, METH_O, NULL}, {"requires_grad_", (PyCFunction)(void(*)(void))THPVariable_requires_grad_, METH_VARARGS | METH_KEYWORDS, NULL}, diff --git a/tools/clang_format.py b/tools/clang_format.py old mode 100644 new mode 100755 diff --git a/tools/clang_tidy.py b/tools/clang_tidy.py old mode 100644 new mode 100755 diff --git a/tools/jit/gen_jit_dispatch.py b/tools/jit/gen_jit_dispatch.py index 563c8ec8379..564db0e671d 100644 --- a/tools/jit/gen_jit_dispatch.py +++ b/tools/jit/gen_jit_dispatch.py @@ -375,6 +375,7 @@ def gen_jit_dispatch(declarations, out, template_path, disable_autograd=False): return [sorted(g, key=declkey) for g in grouped_decls] # We need to add methods implemented manually in TensorImpl + # TODO: This seems to claim sizes() returns an int64_t. Really? tensor_impl_methods = [{ 'name': name, 'api_name': name, @@ -382,7 +383,7 @@ def gen_jit_dispatch(declarations, out, template_path, disable_autograd=False): 'method_of': ['Tensor'], 'arguments': [{'name': 'self', 'simple_type': 'Tensor'}], 'returns': [{'name': 'result', 'type': 'int64_t', 'dynamic_type': 'int64_t', 'simple_type': 'int64_t'}], - } for name in ['sizes', 'strides', 'dim']] + } for name in ['sizes', 'strides', 'dim', 'numel']] aten_decls = load_aten_declarations(declarations) + tensor_impl_methods jit_decls = [d for d in aten_decls if is_jit_op(d)] diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index 794f9202f14..6157041a587 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -416,6 +416,7 @@ def gen_pyi(declarations_path, out): 'set_flush_denormal': ['def set_flush_denormal(mode: _bool) -> _bool: ...'], 'get_default_dtype': ['def get_default_dtype() -> _dtype: ...'], 'from_numpy': ['def from_numpy(ndarray) -> Tensor: ...'], + 'numel': ['def numel(self: Tensor) -> _int: ...'], 'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf," " *, out: Optional[Tensor]=None) -> Tensor: ..."], 'as_tensor': ["def as_tensor(data: Any, dtype: _dtype=None, device: Optional[_device]=None) -> Tensor: ..."], @@ -501,6 +502,7 @@ def gen_pyi(declarations_path, out): 'requires_grad_': ['def requires_grad_(self, mode: _bool=True) -> Tensor: ...'], 'element_size': ['def element_size(self) -> _int: ...'], 'dim': ['def dim(self) -> _int: ...'], + 'numel': ['def numel(self) -> _int: ...'], 'ndimension': ['def ndimension(self) -> _int: ...'], 'nelement': ['def nelement(self) -> _int: ...'], 'cuda': ['def cuda(self, device: Optional[_device]=None, non_blocking: _bool=False) -> Tensor: ...'], diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 90359cef759..e431f52404a 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -16,25 +16,25 @@ from torch._utils_internal import get_source_lines_and_file boolean_dispatched = weakref.WeakKeyDictionary() # noqa: T484 -def createResolutionCallback(frames_up=0): +def createResolutionCallbackFromFrame(frames_up=0): """ Creates a function which, given a string variable name, returns the value of the variable in the scope of the caller of - the function which called createResolutionCallback (by default). + the function which called createResolutionCallbackFromFrame (by default). This is used to enable access in-scope Python variables inside TorchScript fragments. frames_up is number of additional frames to go up on the stack. The default value is 0, which correspond to the frame of the caller - of createResolutionCallback. Also for example, if frames_up is set - to 1, then the frame of the caller's caller of createResolutionCallback + of createResolutionCallbackFromFrame. Also for example, if frames_up is set + to 1, then the frame of the caller's caller of createResolutionCallbackFromFrame will be taken. For example, the following program prints 2:: def bar(): - cb = createResolutionCallback(1) + cb = createResolutionCallbackFromFrame(1) print(cb("foo")) def baz(): @@ -75,6 +75,48 @@ def get_closure(fn): return captures +# [local resolution in python] +# Depending on where a variable is defined, and where it is used, we may +# or may not be able to recover its value when recursively compiling a +# script function. Remember in the general case, a module or function is +# first defined and then later scripted. This means we do not have a +# chance to capture the active frames when the function is defined. Hence any +# name resolution has to happen later on the created closure. The way +# python captures type annotations restricts what we can recover. The +# follow example illustrates the different cases: +# +# class MyGlobalClass: +# ... +# def my_local_scope(): +# @torch.jit.script +# class MyClass: +# ... +# @torch.jit.script +# class MyClassUsedAsVar: +# ... +# def eg(x: MyClass, y: MyGlobalClass): +# a_local_capture : Foo +# return MyClassUsedAsVar(x) +# +# MyGlobalClass is defined in the __globals__ dictionary of function +# 'eg', so it is always recoverable. my_local_scope introduces a new local +# variable scope in the function. Classes defined here are only visible as +# local variables. For the case of MyClassUsedAsVar, it is captured +# because it is used as a variable inside the body of the function, and we +# can resolve it using the captures returned from `get_closure`. However, +# the type annotations are not captured by the closure. In Python +# 3.0--3.9, the _value_ of MyClass and MyGlobalClass will be availiable as +# annotations on `eg``, but starting in Python 4.0, they will represented as +# strings and no longer present. Furthermore, since the body of `eg` does +# not reference those names, they do not appear in the list of closed over +# variables. In Python 2.x, type annotations are in comments, leading to a +# similar situation where their definitions are not available. We anticipate +# that most users will not run into this issue because their modules and +# functions will be defined at a global scope like MyGlobalClass. In cases +# where they are not, it is possible to work around issues by declaring the +# values global in the function. + + def createResolutionCallbackFromClosure(fn): """ @@ -178,11 +220,12 @@ class FunctionModifiers(object): def export(fn): """ - This decorator indicates that a method is used as an entry point into a - ``ScriptModule`` and should be compiled. ``forward`` implicitly is assumbed to be an - entry point, so it does not need this decorator. Functions and methods - called from ``forward`` are compiled as they are seen, so they do not need - this decorator either. + This decorator indicates that a method on an ``nn.Module`` is used as an entry point into a + :class:`ScriptModule` and should be compiled. + + ``forward`` implicitly is assumed to be an entry point, so it does not need this decorator. + Functions and methods called from ``forward`` are compiled as they are seen + by the compiler, so they do not need this decorator either. Example (using ``@torch.jit.export`` on a method): diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index e58eb37725b..2b6f6e20f89 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -282,6 +282,53 @@ addr_(beta=1, alpha=1, vec1, vec2) -> Tensor In-place version of :meth:`~Tensor.addr` """) +add_docstr_all('align_as', + r""" +align_as(other) -> Tensor + +Permutes the dimensions of the :attr:`self` tensor to match the dimension order +in the :attr:`other` tensor, adding size-one dims for any new names. + +This operation is useful for explicit broadcasting by names (see examples). + +All of the dims of :attr:`self` must be named in order to use this method. +The resulting tensor is a view on the original tensor. + +All dimension names of :attr:`self` must be present in ``other.names``. +:attr:`other` may contain named dimensions that are not in ``self.names``; +the output tensor has a size-one dimension for each of those new names. + +To align a tensor to a specific order, use :meth:`~Tensor.align_to`. + +Examples:: + + # Example 1: Applying a mask + >>> mask = torch.randint(2, [127, 128], dtype=torch.bool).refine_names('W', 'H') + >>> imgs = torch.randn(32, 128, 127, 3, names=('N', 'H', 'W', 'C')) + >>> imgs.masked_fill_(mask.align_as(imgs), 0) + + + # Example 2: Applying a per-channel-scale + def scale_channels(input, scale): + scale = scale.refine_names('C') + return input * scale.align_as(input) + + >>> num_channels = 3 + >>> scale = torch.randn(num_channels, names='C') + >>> imgs = torch.rand(32, 128, 128, num_channels, names=('N', 'H', 'W', 'C')) + >>> more_imgs = torch.rand(32, num_channels, 128, 128, names=('N', 'C', 'H', 'W')) + >>> videos = torch.randn(3, num_channels, 128, 128, 128, names=('N', 'C', 'H', 'W', 'D')) + + # scale_channels is agnostic to the dimension order of the input + >>> scale_channels(imgs, scale) + >>> scale_channels(more_imgs, scale) + >>> scale_channels(videos, scale) + +.. warning:: + The named tensor API is experimental and subject to change. + +""") + add_docstr_all('all', r""" .. function:: all() -> bool @@ -332,6 +379,13 @@ allclose(other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor See :func:`torch.allclose` """) +add_docstr_all('angle', + r""" +angle() -> Tensor + +See :func:`torch.angle` +""") + add_docstr_all('any', r""" .. function:: any() -> bool @@ -637,6 +691,13 @@ Args: cases, this argument has no effect. """) +add_docstr_all('conj', + r""" +conj() -> Tensor + +See :func:`torch.conj` +""") + add_docstr_all('cos', r""" cos() -> Tensor @@ -999,6 +1060,13 @@ flip(dims) -> Tensor See :func:`torch.flip` """) +add_docstr_all('real', + r""" +real() -> Tensor + +See :func:`torch.real` +""") + add_docstr_all('roll', r""" roll(shifts, dims) -> Tensor @@ -1095,6 +1163,13 @@ ger(vec2) -> Tensor See :func:`torch.ger` """) +add_docstr_all('imag', + r""" +imag() -> Tensor + +See :func:`torch.imag` +""") + add_docstr_all('indices', r""" indices() -> Tensor @@ -1154,6 +1229,11 @@ gt_(other) -> Tensor In-place version of :meth:`~Tensor.gt` """) +add_docstr_all('has_names', + r""" +Is ``True`` if any of this tensor's dimensions are named. Otherwise, is ``False``. +""") + add_docstr_all('hardshrink', r""" hardshrink(lambd=0.5) -> Tensor @@ -3320,6 +3400,24 @@ Example:: """) +add_docstr_all('names', + r""" +Stores names for each of this tensor's dimensions. + +``names[idx]`` corresponds to the name of tensor dimension ``idx``. +Names are either a string if the dimension is named or ``None`` if the +dimension is unnamed. + +Dimension names may contain characters or underscore. Furthermore, a dimension +name must be a valid Python variable name (i.e., does not start with underscore). + +Tensors may not have two named dimensions with the same name. + +.. warning:: + The named tensor API is experimental and subject to change. + +""") + add_docstr_all('is_cuda', r""" Is ``True`` if the Tensor is stored on the GPU, ``False`` otherwise. diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 6cf93b949c8..239b3dea13f 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -473,6 +473,25 @@ Example:: True """) +add_docstr(torch.angle, + r""" +angle(input, out=None) -> Tensor + +Computes the element-wise angle (in radians) of the given :attr:`input` tensor. + +.. math:: + \text{out}_{i} = angle(\text{input}_{i}) +""" + r""" +Args: + {input} + {out} + +Example:: + + >>> torch.angle(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]))*180/3.14159 + tensor([ 135., 135, 325]) +""".format(**common_args)) + add_docstr(torch.as_strided, r""" as_strided(input, size, stride, storage_offset=0) -> Tensor @@ -953,6 +972,25 @@ Example:: tensor([-0., -1., -1., 1.]) """.format(**common_args)) +add_docstr(torch.real, + r""" +real(input, out=None) -> Tensor + +Computes the element-wise real value of the given :attr:`input` tensor. + +.. math:: + \text{out}_{i} = real(\text{input}_{i}) +""" + r""" +Args: + {input} + {out} + +Example:: + + >>> torch.real(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])) + tensor([ -1, -2, 3]) +""".format(**common_args)) + add_docstr(torch.reciprocal, r""" reciprocal(input, out=None) -> Tensor @@ -1205,6 +1243,25 @@ Example:: tensor([ 0.5000, -0.4702, -0.4599, 0.5000]) """.format(**common_args)) +add_docstr(torch.conj, + r""" +conj(input, out=None) -> Tensor + +Computes the element-wise conjugate of the given :attr:`input` tensor. + +.. math:: + \text{out}_{i} = conj(\text{input}_{i}) +""" + r""" +Args: + {input} + {out} + +Example:: + + >>> torch.conj(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])) + tensor([-1 - 1j, -2 - 2j, 3 + 3j]) +""".format(**common_args)) + add_docstr(torch.cos, r""" cos(input, out=None) -> Tensor @@ -2260,6 +2317,25 @@ Example:: tensor([ 0., 2., 1., 0.]) """.format(**common_args)) +add_docstr(torch.imag, + r""" +imag(input, out=None) -> Tensor + +Computes the element-wise imag value of the given :attr:`input` tensor. + +.. math:: + \text{out}_{i} = imag(\text{input}_{i}) +""" + r""" +Args: + {input} + {out} + +Example:: + + >>> torch.imag(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])) + tensor([ 1, 2, -3]) +""".format(**common_args)) + add_docstr(torch.index_select, r""" index_select(input, dim, index, out=None) -> Tensor diff --git a/torch/csrc/TypeInfo.cpp b/torch/csrc/TypeInfo.cpp index a9c47903be8..8fd9a0b49c7 100644 --- a/torch/csrc/TypeInfo.cpp +++ b/torch/csrc/TypeInfo.cpp @@ -123,7 +123,7 @@ static PyObject* THPDTypeInfo_bits(THPDTypeInfo* self, void*) { } static PyObject* THPFInfo_eps(THPFInfo* self, void*) { - return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES( + return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(at::kHalf, self->type, "epsilon", [] { return PyFloat_FromDouble( std::numeric_limits< @@ -132,14 +132,14 @@ static PyObject* THPFInfo_eps(THPFInfo* self, void*) { } static PyObject* THPFInfo_max(THPFInfo* self, void*) { - return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self->type, "max", [] { + return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(at::kHalf, self->type, "max", [] { return PyFloat_FromDouble( std::numeric_limits::type>::max()); }); } static PyObject* THPFInfo_min(THPFInfo* self, void*) { - return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self->type, "min", [] { + return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(at::kHalf, self->type, "min", [] { return PyFloat_FromDouble( std::numeric_limits::type>::lowest()); }); @@ -170,7 +170,7 @@ static PyObject* THPIInfo_min(THPFInfo* self, void*) { } static PyObject* THPFInfo_tiny(THPFInfo* self, void*) { - return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self->type, "min", [] { + return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(at::kHalf, self->type, "min", [] { return PyFloat_FromDouble( std::numeric_limits::type>::min()); }); diff --git a/torch/csrc/api/include/torch/nn/functional/activation.h b/torch/csrc/api/include/torch/nn/functional/activation.h index 4fbebed2bef..eac7a3320e1 100644 --- a/torch/csrc/api/include/torch/nn/functional/activation.h +++ b/torch/csrc/api/include/torch/nn/functional/activation.h @@ -7,7 +7,7 @@ namespace torch { namespace nn{ namespace functional { -inline Tensor elu(Tensor& input, const ELUOptions& options) { +inline Tensor elu(Tensor& input, const ELUOptions& options = {}) { if (options.inplace()) { return torch::elu_(input, options.alpha()); } else { @@ -15,12 +15,20 @@ inline Tensor elu(Tensor& input, const ELUOptions& options) { } } +inline Tensor selu(Tensor& input, const SELUOptions& options = {}) { + if (options.inplace()) { + return torch::selu_(input); + } else { + return torch::selu(input); + } +} + inline Tensor hardshrink(const Tensor& input, - const HardshrinkOptions& options) { + const HardshrinkOptions& options = {}) { return torch::hardshrink(input, options.lambda()); } -inline Tensor hardtanh(Tensor& input, const HardtanhOptions& options) { +inline Tensor hardtanh(Tensor& input, const HardtanhOptions& options = {}) { if (options.inplace()) { return torch::hardtanh_(input, options.min_val(), options.max_val()); } else { @@ -28,7 +36,7 @@ inline Tensor hardtanh(Tensor& input, const HardtanhOptions& options) { } } -inline Tensor leaky_relu(Tensor& input, const LeakyReLUOptions& options) { +inline Tensor leaky_relu(Tensor& input, const LeakyReLUOptions& options = {}) { if (options.inplace()) { return torch::leaky_relu_(input, options.negative_slope()); } else { @@ -40,6 +48,20 @@ inline Tensor logsigmoid(const Tensor& input) { return torch::log_sigmoid(input); } +inline Tensor softmax(const Tensor& input, const SoftmaxOptions& options, + c10::optional dtype = c10::nullopt) { + int64_t dim = options.dim(); + Tensor ret; + + if (dtype == c10::nullopt) { + ret = input.softmax(dim); + } else { + ret = input.softmax(dim, dtype); + } + + return ret; +} + } // namespace functional } // namespace nn } // namespace torch diff --git a/torch/csrc/api/include/torch/nn/functional/loss.h b/torch/csrc/api/include/torch/nn/functional/loss.h index 72ed3f13330..364100c324f 100644 --- a/torch/csrc/api/include/torch/nn/functional/loss.h +++ b/torch/csrc/api/include/torch/nn/functional/loss.h @@ -17,6 +17,25 @@ inline Tensor hinge_embedding_loss( options.reduction()); } +inline Tensor multi_margin_loss( + const Tensor& input, + const Tensor& target, + const MultiMarginLossOptions& options = {}) { + TORCH_CHECK(options.p() == 1 || options.p() == 2, "only p == 1 and p == 2 supported"); + if (options.weight().defined()) { + TORCH_CHECK(options.weight().dim() == 1, "weight must be one-dimensional"); + } + + return torch::multi_margin_loss( + input, + target, + options.p(), + options.margin(), + options.weight(), + options.reduction() + ); +} + inline Tensor cosine_embedding_loss( const Tensor& input1, const Tensor& input2, diff --git a/torch/csrc/api/include/torch/nn/modules/activation.h b/torch/csrc/api/include/torch/nn/modules/activation.h index 748980cbb46..31e0cee45fa 100644 --- a/torch/csrc/api/include/torch/nn/modules/activation.h +++ b/torch/csrc/api/include/torch/nn/modules/activation.h @@ -16,8 +16,7 @@ namespace nn { /// about the exact behavior of this module. class TORCH_API ELUImpl : public torch::nn::Cloneable { public: - ELUImpl() : ELUImpl(ELUOptions()) {} - explicit ELUImpl(const ELUOptions& options_); + explicit ELUImpl(const ELUOptions& options_ = {}); Tensor forward(Tensor& input); @@ -32,6 +31,28 @@ class TORCH_API ELUImpl : public torch::nn::Cloneable { TORCH_MODULE(ELU); +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SELU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the selu function element-wise. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.SELU to learn +/// about the exact behavior of this module. +class TORCH_API SELUImpl : public torch::nn::Cloneable { + public: + explicit SELUImpl(const SELUOptions& options_ = {}); + + Tensor forward(Tensor& input); + + void reset() override; + + /// Pretty prints the `SELU` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + SELUOptions options; +}; + +TORCH_MODULE(SELU); + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Hardshrink ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Applies the hard shrinkage function element-wise. @@ -39,8 +60,7 @@ TORCH_MODULE(ELU); /// about the exact behavior of this module. class TORCH_API HardshrinkImpl : public torch::nn::Cloneable { public: - HardshrinkImpl() : HardshrinkImpl(HardshrinkOptions()) {} - explicit HardshrinkImpl(const HardshrinkOptions& options_); + explicit HardshrinkImpl(const HardshrinkOptions& options_ = {}); Tensor forward(const Tensor& input); @@ -62,8 +82,7 @@ TORCH_MODULE(Hardshrink); /// about the exact behavior of this module. class TORCH_API HardtanhImpl : public torch::nn::Cloneable { public: - HardtanhImpl() : HardtanhImpl(HardtanhOptions()) {} - explicit HardtanhImpl(const HardtanhOptions& options_); + explicit HardtanhImpl(const HardtanhOptions& options_ = {}); Tensor forward(Tensor& input); @@ -85,8 +104,7 @@ TORCH_MODULE(Hardtanh); /// about the exact behavior of this module. class TORCH_API LeakyReLUImpl : public torch::nn::Cloneable { public: - LeakyReLUImpl() : LeakyReLUImpl(LeakyReLUOptions()) {} - explicit LeakyReLUImpl(const LeakyReLUOptions& options_); + explicit LeakyReLUImpl(const LeakyReLUOptions& options_ = {}); Tensor forward(Tensor& input); @@ -108,8 +126,6 @@ TORCH_MODULE(LeakyReLU); /// about the exact behavior of this module. class TORCH_API LogSigmoidImpl : public torch::nn::Cloneable { public: - LogSigmoidImpl() {} - Tensor forward(const Tensor& input); void reset() override; @@ -120,5 +136,28 @@ class TORCH_API LogSigmoidImpl : public torch::nn::Cloneable { TORCH_MODULE(LogSigmoid); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softmax ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the Softmax function. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.Softmax to learn +/// about the exact behavior of this module. +class TORCH_API SoftmaxImpl : public torch::nn::Cloneable { + public: + explicit SoftmaxImpl(int64_t dim) : SoftmaxImpl(SoftmaxOptions(dim)) {} + explicit SoftmaxImpl(const SoftmaxOptions& options_); + + Tensor forward(const Tensor& input); + + void reset() override; + + /// Pretty prints the `Softmax` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + SoftmaxOptions options; +}; + +TORCH_MODULE(Softmax); + } // namespace nn } // namespace torch diff --git a/torch/csrc/api/include/torch/nn/modules/embedding.h b/torch/csrc/api/include/torch/nn/modules/embedding.h index 4c8e60d3f80..7470cf4931d 100644 --- a/torch/csrc/api/include/torch/nn/modules/embedding.h +++ b/torch/csrc/api/include/torch/nn/modules/embedding.h @@ -12,19 +12,57 @@ namespace nn { /// Options for the `Embedding` module. struct TORCH_API EmbeddingOptions { - EmbeddingOptions(int64_t count, int64_t dimension); - /// The number of embeddings (number of rows in the table). - TORCH_ARG(int64_t, count); - /// The size of each embedding vector (number of columns in the table). - TORCH_ARG(int64_t, dimension); + EmbeddingOptions(int64_t num_embeddings, int64_t embedding_dim) : + num_embeddings_(num_embeddings), embedding_dim_(embedding_dim) {}; + /// The size of the dictionary of embeddings. + TORCH_ARG(int64_t, num_embeddings); + /// The size of each embedding vector. + TORCH_ARG(int64_t, embedding_dim); + /// If given, pads the output with the embedding vector at `padding_idx` (initialized to zeros) whenever it encounters the index. + TORCH_ARG(c10::optional, padding_idx) = c10::nullopt; + /// If given, each embedding vector with norm larger than `max_norm` is renormalized to have norm `max_norm`. + TORCH_ARG(c10::optional, max_norm) = c10::nullopt; + /// The p of the p-norm to compute for the `max_norm` option. Default ``2``. + TORCH_ARG(float, norm_type) = 2.; + /// If given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default ``False``. + TORCH_ARG(bool, scale_grad_by_freq) = false; + /// If ``True``, gradient w.r.t. `weight` matrix will be a sparse tensor. + TORCH_ARG(bool, sparse) = false; + /// The learnable weights of the module of shape (num_embeddings, embedding_dim) + TORCH_ARG(torch::Tensor, _weight) = Tensor(); +}; + +/// Options for the `EmbeddingBag` module. +struct TORCH_API EmbeddingBagOptions { + EmbeddingBagOptions(int64_t num_embeddings, int64_t embedding_dim) : + num_embeddings_(num_embeddings), embedding_dim_(embedding_dim) {}; + /// The size of the dictionary of embeddings. + TORCH_ARG(int64_t, num_embeddings); + /// The size of each embedding vector. + TORCH_ARG(int64_t, embedding_dim); + /// If given, each embedding vector with norm larger than `max_norm` is renormalized to have norm `max_norm`. + TORCH_ARG(c10::optional, max_norm) = c10::nullopt; + /// The p of the p-norm to compute for the `max_norm` option. Default ``2``. + TORCH_ARG(float, norm_type) = 2.; + /// If given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default ``False``. + /// Note: this option is not supported when ``mode="max"``. + TORCH_ARG(bool, scale_grad_by_freq) = false; + /// ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag. ``"sum"`` computes the weighted sum, taking `per_sample_weights` + /// into consideration. ``"mean"`` computes the average of the values in the bag, ``"max"`` computes the max value over each bag. + TORCH_ARG(string, mode) = "mean"; + /// If ``True``, gradient w.r.t. `weight` matrix will be a sparse tensor. + /// Note: this option is not supported when ``mode="max"``. + TORCH_ARG(bool, sparse) = false; + /// The learnable weights of the module of shape (num_embeddings, embedding_dim) + TORCH_ARG(torch::Tensor, _weight) = Tensor(); }; /// Performs a lookup in a fixed size embedding table. class TORCH_API EmbeddingImpl : public torch::nn::Cloneable { public: - EmbeddingImpl(int64_t count, int64_t dimension) - : EmbeddingImpl(EmbeddingOptions(count, dimension)) {} - explicit EmbeddingImpl(EmbeddingOptions options); + EmbeddingImpl(int64_t num_embeddings, int64_t embedding_dim) + : EmbeddingImpl(EmbeddingOptions(num_embeddings, embedding_dim)) {} + explicit EmbeddingImpl(const EmbeddingOptions& options_); void reset() override; @@ -47,7 +85,64 @@ class TORCH_API EmbeddingImpl : public torch::nn::Cloneable { /// See the documentation for `EmbeddingImpl` class to learn what methods it /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's /// module storage semantics. -TORCH_MODULE(Embedding); +class Embedding : public torch::nn::ModuleHolder { + public: + using torch::nn::ModuleHolder::ModuleHolder; + static Embedding from_pretrained(const torch::Tensor& embeddings, c10::optional options = c10::nullopt, bool freeze = true) { + TORCH_CHECK(embeddings.dim() == 2, "Embeddings parameter is expected to be 2-dimensional"); + if (options != c10::nullopt) { + TORCH_CHECK((*options).num_embeddings() == embeddings.size(0), "Expects options.num_embeddings to be ", embeddings.size(0) , "but found ", (*options).num_embeddings()); + TORCH_CHECK((*options).embedding_dim() == embeddings.size(1), "Expects options.embeddings_dim to be ", embeddings.size(1) , "but found ", (*options).embedding_dim()); + } else { + options = EmbeddingOptions(embeddings.size(0), embeddings.size(1)); + } + Embedding embedding((*options)._weight(embeddings)); + embedding->weight.set_requires_grad(!freeze); + return embedding; + } +}; + +class TORCH_API EmbeddingBagImpl : public torch::nn::Cloneable { + public: + EmbeddingBagImpl(int64_t num_embeddings, int64_t embedding_dim) + : EmbeddingBagImpl(EmbeddingBagOptions(num_embeddings, embedding_dim)) {} + explicit EmbeddingBagImpl(const EmbeddingBagOptions& options_); + + void reset() override; + + /// Pretty prints the `EmbeddingBag` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + torch::Tensor forward(const Tensor& input, const torch::Tensor& offsets = torch::Tensor(), + const torch::Tensor& per_sample_weights = torch::Tensor()); + + /// The `Options` used to configure this `EmbeddingBag` module. + EmbeddingBagOptions options; + /// The embedding table. + Tensor weight; +}; + +/// A `ModuleHolder` subclass for `EmbeddingBagImpl`. +/// See the documentation for `EmbeddingBagImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +class EmbeddingBag : public torch::nn::ModuleHolder { + public: + using torch::nn::ModuleHolder::ModuleHolder; + + static EmbeddingBag from_pretrained(const torch::Tensor& embeddings, c10::optional options = c10::nullopt, bool freeze = true) { + TORCH_CHECK(embeddings.dim() == 2, "Embeddings parameter is expected to be 2-dimensional"); + if (options != c10::nullopt) { + TORCH_CHECK((*options).num_embeddings() == embeddings.size(0), "Expects options.num_embeddings to be ", embeddings.size(0) , "but found ", (*options).num_embeddings()); + TORCH_CHECK((*options).embedding_dim() == embeddings.size(1), "Expects options.embeddings_dim to be ", embeddings.size(1) , "but found ", (*options).embedding_dim()); + } else { + options = EmbeddingBagOptions(embeddings.size(0), embeddings.size(1)); + } + EmbeddingBag embeddingbag((*options)._weight(embeddings)); + embeddingbag->weight.set_requires_grad(!freeze); + return embeddingbag; + } +}; } // namespace nn } // namespace torch diff --git a/torch/csrc/api/include/torch/nn/modules/loss.h b/torch/csrc/api/include/torch/nn/modules/loss.h index ddd7e94c69b..e5775b621e9 100644 --- a/torch/csrc/api/include/torch/nn/modules/loss.h +++ b/torch/csrc/api/include/torch/nn/modules/loss.h @@ -63,6 +63,33 @@ TORCH_MODULE(HingeEmbeddingLoss); // ============================================================================ +/// Creates a criterion that optimizes a multi-class classification hinge +/// loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) and +/// output :math:`y` (which is a 1D tensor of target class indices, +/// :math:`0 \leq y \leq \text{x.size}(1)-1`): +struct TORCH_API MultiMarginLossImpl : Module { + explicit MultiMarginLossImpl( + const MultiMarginLossOptions& options_ = {}); + + void reset(); + + /// Pretty prints the `MultiMarginLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + MultiMarginLossOptions options; +}; + +/// A `ModuleHolder` subclass for `MultiMarginLossImpl`. +/// See the documentation for `MultiMarginLossImpl` class to learn what +/// methods it provides, or the documentation for `ModuleHolder` to learn about +/// PyTorch's module storage semantics. +TORCH_MODULE(MultiMarginLoss); + +// ============================================================================ + /// Creates a criterion that measures the loss given input tensors /// `input1`, `input2`, and a `Tensor` label `target` with values 1 or /// -1. This is used for measuring whether two inputs are similar or diff --git a/torch/csrc/api/include/torch/nn/options/activation.h b/torch/csrc/api/include/torch/nn/options/activation.h index 592f3c61e1c..7ff1da836fe 100644 --- a/torch/csrc/api/include/torch/nn/options/activation.h +++ b/torch/csrc/api/include/torch/nn/options/activation.h @@ -2,15 +2,14 @@ #include #include +#include namespace torch { namespace nn { /// Options for ELU functional and module. -struct ELUOptions { - ELUOptions() {} - - /// The alpha value for the ELU formulation. Default: 1.0 +struct TORCH_API ELUOptions { + /// The `alpha` value for the ELU formulation. Default: 1.0 TORCH_ARG(double, alpha) = 1.0; /// can optionally do the operation in-place. Default: False @@ -19,20 +18,28 @@ struct ELUOptions { // ============================================================================ +/// Options for SELU functional and module. +struct TORCH_API SELUOptions { + /* implicit */ SELUOptions(bool inplace = false); + + /// can optionally do the operation in-place. Default: False + TORCH_ARG(bool, inplace); +}; + +// ============================================================================ + /// Options for Hardshrink functional and module. struct TORCH_API HardshrinkOptions { /* implicit */ HardshrinkOptions(double lambda = 0.5); - /// the lambda value for the Hardshrink formulation. Default: 0.5 + /// the `lambda` value for the Hardshrink formulation. Default: 0.5 TORCH_ARG(double, lambda); }; // ============================================================================ /// Options for Hardtanh functional and module. -struct HardtanhOptions { - HardtanhOptions() {} - +struct TORCH_API HardtanhOptions { /// minimum value of the linear region range. Default: -1 TORCH_ARG(double, min_val) = -1.0; @@ -46,9 +53,7 @@ struct HardtanhOptions { // ============================================================================ /// Options for LeakyReLU functional and module. -struct LeakyReLUOptions { - LeakyReLUOptions() {} - +struct TORCH_API LeakyReLUOptions { /// Controls the angle of the negative slope. Default: 1e-2 TORCH_ARG(double, negative_slope) = 1e-2; @@ -56,5 +61,15 @@ struct LeakyReLUOptions { TORCH_ARG(bool, inplace) = false; }; +// ============================================================================ + +/// Options for the Softmax functional and module. +struct TORCH_API SoftmaxOptions { + SoftmaxOptions(int64_t dim); + + // Dimension along which Softmax will be computed. + TORCH_ARG(int64_t, dim); +}; + } // namespace nn } // namespace torch diff --git a/torch/csrc/api/include/torch/nn/options/loss.h b/torch/csrc/api/include/torch/nn/options/loss.h index 1cae3316ac6..96e3c7461d5 100644 --- a/torch/csrc/api/include/torch/nn/options/loss.h +++ b/torch/csrc/api/include/torch/nn/options/loss.h @@ -29,6 +29,26 @@ struct TORCH_API HingeEmbeddingLossOptions { // ============================================================================ +/// Options for a multi-margin loss functional and module. +struct TORCH_API MultiMarginLossOptions { + /// Has a default value of :math:`1`. :math:`1` and :math:`2` + /// are the only supported values. + TORCH_ARG(int64_t, p) = 1; + /// Has a default value of :math:`1`. + TORCH_ARG(double, margin) = 1.0; + /// A manual rescaling weight given to each + /// class. If given, it has to be a Tensor of size `C`. Otherwise, it is + /// treated as if having all ones. + TORCH_ARG(Tensor, weight) = Tensor(); + /// Specifies the reduction to apply to the output: + /// ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + /// ``'mean'``: the sum of the output will be divided by the number of + /// elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'`` + TORCH_ARG(Reduction::Reduction, reduction) = Reduction::Mean; +}; + +// ============================================================================ + /// Options for a Hinge Embedding loss functional and module. struct TORCH_API CosineEmbeddingLossOptions { /// Specifies the threshold for which the distance of a negative sample must diff --git a/torch/csrc/api/src/nn/modules/activation.cpp b/torch/csrc/api/src/nn/modules/activation.cpp index ddbb493896d..779c842fb10 100644 --- a/torch/csrc/api/src/nn/modules/activation.cpp +++ b/torch/csrc/api/src/nn/modules/activation.cpp @@ -24,6 +24,24 @@ void ELUImpl::pretty_print(std::ostream& stream) const { // ============================================================================ +SELUImpl::SELUImpl(const SELUOptions& options_) : options(options_) {} + +Tensor SELUImpl::forward(Tensor& input) { + return F::selu(input, options); +} + +void SELUImpl::reset() {} + +void SELUImpl::pretty_print(std::ostream& stream) const { + stream << "torch::nn::SELU("; + if (options.inplace()) { + stream << std::boolalpha << "inplace=" << options.inplace(); + } + stream << ")"; +} + +// ============================================================================ + HardshrinkImpl::HardshrinkImpl(const HardshrinkOptions& options_) : options(options_) {} @@ -96,5 +114,20 @@ void LogSigmoidImpl::pretty_print(std::ostream& stream) const { stream << "torch::nn::LogSigmoid()"; } +// ============================================================================ + +SoftmaxImpl::SoftmaxImpl(const SoftmaxOptions& options_) + : options(options_) {} + +void SoftmaxImpl::reset() {} + +void SoftmaxImpl::pretty_print(std::ostream& stream) const { + stream << "torch::nn::Softmax(dim=" << options.dim() << ")"; +} + +Tensor SoftmaxImpl::forward(const Tensor& input) { + return F::softmax(input, options); +} + } // namespace nn } // namespace torch diff --git a/torch/csrc/api/src/nn/modules/embedding.cpp b/torch/csrc/api/src/nn/modules/embedding.cpp index 16d4c95a854..838f63789cb 100644 --- a/torch/csrc/api/src/nn/modules/embedding.cpp +++ b/torch/csrc/api/src/nn/modules/embedding.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -10,28 +11,177 @@ namespace torch { namespace nn { - -EmbeddingOptions::EmbeddingOptions(int64_t count, int64_t dimension) - : count_(count), dimension_(dimension) {} - -EmbeddingImpl::EmbeddingImpl(EmbeddingOptions options) : options(options) { +EmbeddingImpl::EmbeddingImpl(const EmbeddingOptions& options_) : options(options_) { // NOLINT(modernize-pass-by-value) reset(); } void EmbeddingImpl::reset() { - weight = register_parameter( - "weight", torch::empty({options.count(), options.dimension()})); - NoGradGuard guard; - weight.normal_(0, 1); + if (options.padding_idx() != c10::nullopt) { + if (*options.padding_idx() > 0) { + TORCH_CHECK(*options.padding_idx() < options.num_embeddings(), "Padding_idx must be within num_embeddings"); + } + else if (*options.padding_idx() < 0) { + TORCH_CHECK(*options.padding_idx() >= -(options.num_embeddings()), "Padding_idx must be within num_embedding"); + options.padding_idx(options.num_embeddings() + *options.padding_idx()); + } + } + + if (!options._weight().defined()) { + weight = register_parameter( + "weight", torch::empty({options.num_embeddings(), options.embedding_dim()})); + torch::nn::init::normal_(weight); + if (options.padding_idx() != c10::nullopt) { + torch::NoGradGuard no_grad; + weight[*options.padding_idx()].fill_(0); + } + } else { + TORCH_CHECK(options._weight().sizes() == torch::IntArrayRef({options.num_embeddings(), options.embedding_dim()}), "Shape of _weight does not match num_embeddings and embedding_dim"); + weight = register_parameter("weight", options._weight()); + } } void EmbeddingImpl::pretty_print(std::ostream& stream) const { - stream << "torch::nn::Embedding(count=" << options.count() - << ", dimension=" << options.dimension() << ")"; + stream << "torch::nn::Embedding(num_embeddings=" << options.num_embeddings() + << ", embedding_dim=" << options.embedding_dim(); + if (options.padding_idx() != c10::nullopt) { + stream << ", padding_idx=" << *options.padding_idx(); + } + if (options.max_norm() != c10::nullopt) { + stream << ", max_norm=" << *options.max_norm(); + } + if (options.norm_type() != 2) { + stream << ", norm_type=" << options.norm_type(); + } + if (options.scale_grad_by_freq()) { + stream << ", scale_grad_by_freq=" << std::boolalpha << options.scale_grad_by_freq(); + } + if (options.sparse()) { + stream << ", sparse=" << std::boolalpha << options.sparse(); + } + stream << ")"; } -Tensor EmbeddingImpl::forward(const Tensor& input) { - return torch::embedding(weight, /*indices=*/input); +torch::Tensor EmbeddingImpl::forward(const Tensor& input) { + if (options.padding_idx() != c10::nullopt) { + if (*options.padding_idx() > 0) { + TORCH_CHECK(*options.padding_idx() < weight.size(0), "Padding_idx must be within num_embeddings"); + } + else if (*options.padding_idx() < 0) { + TORCH_CHECK(*options.padding_idx() >= -weight.size(0), "Padding_idx must be within num_embedding"); + options.padding_idx(weight.size(0) + *options.padding_idx()); + } + } else { + options.padding_idx(-1); + } + + if (options.max_norm() != c10::nullopt) { + torch::NoGradGuard no_grad; + torch::embedding_renorm_(weight, input.contiguous(), *options.max_norm(), options.norm_type()); + } + return torch::embedding(weight, input.contiguous(), *options.padding_idx(), options.scale_grad_by_freq(), options.sparse()); +} + +EmbeddingBagImpl::EmbeddingBagImpl(const EmbeddingBagOptions& options_) : options(options_) { // NOLINT(modernize-pass-by-value) + reset(); +} + +void EmbeddingBagImpl::reset() { + if (!options._weight().defined()) { + weight = register_parameter( + "weight", torch::empty({options.num_embeddings(), options.embedding_dim()})); + torch::nn::init::normal_(weight); + } else { + TORCH_CHECK( + options._weight().sizes() == torch::IntArrayRef({options.num_embeddings(), options.embedding_dim()}), + "Shape of weight does not match num_embeddings and embedding_dim"); + weight = register_parameter("weight", options._weight()); + } +} + +torch::Tensor EmbeddingBagImpl::forward( + const torch::Tensor& input, + const torch::Tensor& offsets, + const torch::Tensor& per_sample_weights) { + torch::Tensor input_ = input; + torch::Tensor offsets_ = offsets; + torch::Tensor per_sample_weights_ = per_sample_weights; + + TORCH_CHECK(!per_sample_weights_.defined() || input.sizes() == per_sample_weights_.sizes(), + "embedding_bag: If per_sample_weights (", per_sample_weights_.sizes(), ") is not null, then it must have the same shape as the input (", input.sizes(), ")"); + if (input.dim() == 2) { + TORCH_CHECK(!offsets_.defined(), + "If input is 2D, then offsets has to be null, as input is treated is a mini-batch of fixed length sequences. However, found offsets of type Tensor"); + offsets_ = torch::arange(0, input.numel(), input.size(1), + torch::TensorOptions().dtype(torch::kLong).device(input.device())); + input_ = input_.reshape(-1); + if (per_sample_weights_.defined()) { + per_sample_weights_ = per_sample_weights_.reshape(-1); + } + } else if (input.dim() == 1) { + TORCH_CHECK(offsets_.defined(), "offsets has to be a 1D Tensor but got null"); + TORCH_CHECK(offsets_.dim() == 1, "offsets has to be a 1D Tensor"); + TORCH_CHECK(offsets_[0].item() == 0, "offsets[0] has to be 0, i.e., the first sequence in the mini-batch has to start from position 0. However, got ", + offsets_[0].item()); + TORCH_CHECK(offsets_[-1].item() <= input.size(0), "offsets[-1] can not be greater than input's length({)", + input.size(0), "}), but got offsets[-1] of {", offsets_[-1].item(), "}"); + } else { + TORCH_CHECK(false, "input has to be 1D or 2D Tensor,but got Tensor of dimension ", input.dim()); + } + + int mode_enum; + if (options.mode() == "sum") { + mode_enum = 0; + } else if (options.mode() == "mean") { + mode_enum = 1; + } else if (options.mode() =="max") { + mode_enum = 2; + TORCH_CHECK(!options.scale_grad_by_freq(), "max mode does not support scaling the gradient by the frequency"); + TORCH_CHECK(!options.sparse(), "max mode does not support sparse weights"); + } else { + TORCH_CHECK(false, "mode has to be one of sum, mean or max"); + } + + if (options.max_norm() != c10::nullopt) { + torch::NoGradGuard no_grad; + torch::embedding_renorm_(weight, input_, *options.max_norm(), options.norm_type()); + } + + TORCH_CHECK( + !per_sample_weights_.defined() || options.mode() == "sum", + "embedding_bag: per_sample_weights was not null. ", + "per_sample_weights is only supported for mode='sum' (got mode='", + options.mode(), "').Please open a feature request on GitHub."); + + return std::get<0>( + torch::embedding_bag( + weight, + input_, + offsets_, + options.scale_grad_by_freq(), + mode_enum, + options.sparse(), + per_sample_weights_)); +} + +void EmbeddingBagImpl::pretty_print(std::ostream& stream) const { + stream << "torch::nn::EmbeddingBag(num_embeddings=" << options.num_embeddings() + << ", embedding_dim=" << options.embedding_dim(); + if (options.max_norm() != c10::nullopt) { + stream << ", max_norm=" << *options.max_norm(); + } + if (options.norm_type() != 2) { + stream << ", norm_type=" << options.norm_type(); + } + if (options.scale_grad_by_freq()) { + stream << ", scale_grad_by_freq=" << std::boolalpha << options.scale_grad_by_freq(); + } + if (options.sparse()) { + stream << ", sparse=" << std::boolalpha << options.sparse(); + } + if (options.mode() != "mean") { + stream << ", mode=" << options.mode(); + } + stream << ")"; } } // namespace nn } // namespace torch diff --git a/torch/csrc/api/src/nn/modules/loss.cpp b/torch/csrc/api/src/nn/modules/loss.cpp index dec096183b1..d00ea436a6b 100644 --- a/torch/csrc/api/src/nn/modules/loss.cpp +++ b/torch/csrc/api/src/nn/modules/loss.cpp @@ -34,6 +34,34 @@ Tensor HingeEmbeddingLossImpl::forward( // ============================================================================ +MultiMarginLossImpl::MultiMarginLossImpl( + const MultiMarginLossOptions& options_) // NOLINT(modernize-pass-by-value) + : options(options_) { + reset(); + } + +void MultiMarginLossImpl::reset() { + TORCH_CHECK((options.p() == 1) || (options.p() == 2), "only p == 1 and p == 2 supported"); + TORCH_CHECK(!options.weight().defined() || options.weight().dim() == 1); + + register_buffer("weight", options.weight()); +} + +void MultiMarginLossImpl::pretty_print(std::ostream& stream) const { + stream << "torch::nn::MultiMarginLoss(p=" << options.p() << + ", margin=" << options.margin() << + ", weight=" << options.weight() << + ", reduction=" << options.reduction() << ")"; +} + +Tensor MultiMarginLossImpl::forward( + const Tensor& input, + const Tensor& target) { + return F::multi_margin_loss(input, target, options); +} + +// ============================================================================ + CosineEmbeddingLossImpl::CosineEmbeddingLossImpl( const CosineEmbeddingLossOptions& options_) : options(options_) {} diff --git a/torch/csrc/api/src/nn/options/activation.cpp b/torch/csrc/api/src/nn/options/activation.cpp index 9e2cf801cc4..41f5b78d943 100644 --- a/torch/csrc/api/src/nn/options/activation.cpp +++ b/torch/csrc/api/src/nn/options/activation.cpp @@ -3,7 +3,11 @@ namespace torch { namespace nn { +SELUOptions::SELUOptions(bool inplace) : inplace_(inplace) {} + HardshrinkOptions::HardshrinkOptions(double lambda) : lambda_(lambda) {} +SoftmaxOptions::SoftmaxOptions(int64_t dim) : dim_(dim) {} + } // namespace nn } // namespace torch diff --git a/torch/csrc/jit/mobile/register_mobile_ops.cpp b/torch/csrc/jit/mobile/register_mobile_ops.cpp index 07875a8da69..5d26a130ebf 100644 --- a/torch/csrc/jit/mobile/register_mobile_ops.cpp +++ b/torch/csrc/jit/mobile/register_mobile_ops.cpp @@ -7,219 +7,31 @@ using torch::jit::peek; using torch::jit::drop; using torch::jit::pack; -namespace { -at::Tensor toOptionalTensor(const c10::IValue& v) { - if (v.isNone()) { - return at::Tensor(); - } - return v.toTensor(); -} - at::Tensor optional_to_tensor(c10::optional v) { return v.has_value() ? *v : at::Tensor(); } -} static auto registry0 = torch::RegisterOperators().op( "_aten::add.Tensor", torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, - [](at::Tensor a, at::Tensor b, at::Scalar c) -> at::Tensor { + [](at::Tensor a, at::Tensor b, at::Scalar c) ->at::Tensor { return at::add(a, b, c); }) ).op( "_aten::add.Scalar", torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, - [](at::Tensor a, at::Scalar b, at::Scalar c) -> at::Tensor { + [](at::Tensor a, at::Scalar b, at::Scalar c) ->at::Tensor { return at::add(a, b, c); }) ).op( - "_aten::add_.Tensor", + "_aten::_convolution", torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, - [](at::Tensor a, at::Tensor b, at::Scalar c) -> at::Tensor { - return at::add(a, b, c); - }) -).op( - "_aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, - [](c10::OperatorKernel* kernel, Stack* stack) { - #ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode non_var_type_mode(true); - #endif - auto result_ = at::adaptive_avg_pool2d( - (std::move(peek(*stack, 0, 2))).toTensor(), - (std::move(peek(*stack, 1, 2))).toIntListRef() - ); - drop(*stack, 2); - pack(*stack, std::move(result_)); - }) -).op( - "_aten::mm", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, - [](at::Tensor a, at::Tensor b) -> at::Tensor { - return at::mm(a, b); - }) -).op( - "_aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, - [](c10::OperatorKernel* kernel, Stack* stack) { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode non_var_type_mode(true); -#endif - auto result_ = at::_convolution( - (std::move(peek(*stack, 0, 12))).toTensor(), - (std::move(peek(*stack, 1, 12))).toTensor(), - toOptionalTensor((std::move(peek(*stack, 2, 12)))), - (std::move(peek(*stack, 3, 12))).toIntListRef(), - (std::move(peek(*stack, 4, 12))).toIntListRef(), - (std::move(peek(*stack, 5, 12))).toIntListRef(), - (std::move(peek(*stack, 6, 12))).toBool(), - (std::move(peek(*stack, 7, 12))).toIntListRef(), - (std::move(peek(*stack, 8, 12))).toInt(), - (std::move(peek(*stack, 9, 12))).toBool(), - (std::move(peek(*stack, 10, 12))).toBool(), - (std::move(peek(*stack, 11, 12))).toBool() - ); - drop(*stack, 12); - pack(*stack, std::move(result_)); - }) -).op( - "_aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, - [](c10::OperatorKernel* kernel, Stack* stack) { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode non_var_type_mode(true); -#endif - auto result_ = at::conv2d( - (std::move(peek(*stack, 0, 7))).toTensor(), - (std::move(peek(*stack, 1, 7))).toTensor(), - toOptionalTensor((std::move(peek(*stack, 2, 7)))), - (std::move(peek(*stack, 3, 7))).toIntListRef(), - (std::move(peek(*stack, 4, 7))).toIntListRef(), - (std::move(peek(*stack, 5, 7))).toIntListRef(), - (std::move(peek(*stack, 6, 7))).toInt() - ); - drop(*stack, 7); - pack(*stack, std::move(result_)); - }) -).op( - "_aten::batch_norm", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, - [] (at::Tensor input, c10::optional weight, c10::optional bias, - c10::optional running_mean, c10::optional running_var, - bool training, double momentum, double eps, bool cudnn_enabled) { - return at::batch_norm(input, optional_to_tensor(weight), optional_to_tensor(bias), - optional_to_tensor(running_mean), optional_to_tensor(running_var), - training, momentum, eps, cudnn_enabled); - }) -).op( - "_aten::max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, - [](c10::OperatorKernel* kernel, Stack* stack) { - #ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode non_var_type_mode(true); - #endif - auto result_ = at::max_pool2d_with_indices( - (std::move(peek(*stack, 0, 6))).toTensor(), - (std::move(peek(*stack, 1, 6))).toIntListRef(), - (std::move(peek(*stack, 2, 6))).toIntListRef(), - (std::move(peek(*stack, 3, 6))).toIntListRef(), - (std::move(peek(*stack, 4, 6))).toIntListRef(), - (std::move(peek(*stack, 5, 6))).toBool() - ); - drop(*stack, 6); - pack(*stack, std::move(result_)); - }) -).op( - "_aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, - [](c10::OperatorKernel* kernel, Stack* stack) { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode non_var_type_mode(true); -#endif - auto result_ = at::max_pool2d( - (std::move(peek(*stack, 0, 6))).toTensor(), - (std::move(peek(*stack, 1, 6))).toIntListRef(), - (std::move(peek(*stack, 2, 6))).toIntListRef(), - (std::move(peek(*stack, 3, 6))).toIntListRef(), - (std::move(peek(*stack, 4, 6))).toIntListRef(), - (std::move(peek(*stack, 5, 6))).toBool() - ); - drop(*stack, 6); - pack(*stack, std::move(result_)); - }) -).op( - "_aten::threshold", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, - [](at::Tensor self, at::Scalar threshold, at::Scalar value) { - return at::threshold_(self, threshold, value); - }) -).op( - "_aten::relu", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId) -).op( - "_aten::relu_", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, - [](at::Tensor a) -> at::Tensor { - return at::relu_(a); - }) -).op( - "_aten::t", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId) -).op( - "_aten::size.int", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, - [](at::Tensor a, int64_t dim) -> int64_t { - return at::size(a, dim); - }) -).op( - "_aten::addmm", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId) -).op( - "_aten::view(Tensor(a) self, int[] size) -> Tensor(a)", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, - [](c10::OperatorKernel* kernel, Stack* stack) { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode non_var_type_mode(true); -#endif - auto result_ = ((std::move(peek(*stack, 0, 2))).toTensor()).view( - (std::move(peek(*stack, 1, 2))).toIntListRef() - ); - drop(*stack, 2); - pack(*stack, std::move(result_)); - }).aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA) -).op( - "_aten::dim", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, - [](at::Tensor a) -> int64_t { - return a.dim(); - }) -).op( - "_aten::eq", - torch::RegisterOperators::options().catchAllKernel( - [](int64_t a, int64_t b) -> bool { - return a == b; - }) -).op( - "_aten::log_softmax", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, - [](at::Tensor a, int64_t b, c10::optional c) -> at::Tensor { - if (c.has_value()) { - return at::log_softmax(a, b, static_cast(c.value())); - } else { - return at::log_softmax(a, b); - } - }) -).op( - "_aten::Int", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, - [](at::Tensor a) -> int64_t { - return a.item(); - }) -).op( - "_prim::NumToTensor", - torch::RegisterOperators::options().catchAllKernel( - [](at::Scalar s) -> at::Tensor { - return at::scalar_to_tensor(s); + [](at::Tensor input, at::Tensor weight, c10::optional bias, + std::vector stride, std::vector padding, + std::vector dilation, bool transposed, std::vector output_padding, + int64_t groups, bool benchmark, bool deterministic, bool cudnn_enabled) { + return at::_convolution(input, weight, optional_to_tensor(bias), stride, padding, dilation, + transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled); }) ).op( // Dummy operator that does nothing. Used to reserve a location of an operator table. diff --git a/torch/csrc/jit/script/python_sugared_value.cpp b/torch/csrc/jit/script/python_sugared_value.cpp index 275b716bfe7..4ed98a5a090 100644 --- a/torch/csrc/jit/script/python_sugared_value.cpp +++ b/torch/csrc/jit/script/python_sugared_value.cpp @@ -660,7 +660,8 @@ std::shared_ptr toSugaredValue( // methods here have been explicitly annotated to not be compiled, // so they do not have the same overload and compile checks as for functions if (isFunction || isMethod) { - auto rcb = py::module::import("torch.jit").attr("_gen_rcb")(obj, 0); + auto rcb = py::module::import("torch._jit_internal") + .attr("createResolutionCallbackFromClosure")(obj); return std::make_shared(obj, rcb); } diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index 72ee285e7be..c94ea23ddbf 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -114,10 +114,10 @@ def save(m, f, _extra_files=DEFAULT_EXTRA_FILES_MAP): This is different from :func:`load `'s semantics and may change in the future. Arguments: - m: a ScriptModule to save - f: a file-like object (has to implement write and flush) or a string - containing a file name - _extra_files: Map from filename to contents which will be stored as part of 'f' + m: A ScriptModule to save. + f: A file-like object (has to implement write and flush) or a string + containing a file name. + _extra_files: Map from filename to contents which will be stored as part of 'f'. .. warning:: If you are using Python 2, ``torch.jit.save`` does NOT support ``StringIO.StringIO`` @@ -164,7 +164,8 @@ def save(m, f, _extra_files=DEFAULT_EXTRA_FILES_MAP): def load(f, map_location=None, _extra_files=DEFAULT_EXTRA_FILES_MAP): r""" - Load a ``ScriptModule`` previously saved with :func:`torch.jit.save ` + Load a :class:`ScriptModule` or :class:`ScriptFunction` previously + saved with :func:`torch.jit.save ` All previously saved modules, no matter their device, are first loaded onto CPU, and then are moved to the devices they were saved from. If this fails (e.g. because @@ -180,7 +181,7 @@ def load(f, map_location=None, _extra_files=DEFAULT_EXTRA_FILES_MAP): would be stored in the provided map. Returns: - A ``ScriptModule`` object. + A :class:`ScriptModule` object. Example: @@ -726,15 +727,16 @@ def trace(func, _module_class=None, _compilation_unit=_python_cu): """ - Trace a function and return an executable ``ScriptModule`` or ``torch.jit.ScriptFunction`` - that will be optimized using just-in-time compilation. + Trace a function and return an executable or :class:`ScriptFunction` + that will be optimized using just-in-time compilation. Tracing is ideal for + code that operates only on ``Tensor``\\s and lists, dictionaries, and tuples of ``Tensor``\\s. Using ``torch.jit.trace`` and :func:`torch.jit.trace_module`, you can turn an existing module or Python - function into a TorchScript ``torch.jit.ScriptFunction`` or ``ScriptModule``. You must provide example inputs, + function into a TorchScript :class:`ScriptFunction` or :class:`ScriptModule`. You must provide example inputs, and we run the function, recording the operations performed on all the tensors. - * The resulting recording of a standalone function produces ``torch.jit.ScriptFunction``. - * The resulting recording of ``forward`` function of ``nn.Module`` or ``nn.Module`` produces ``ScriptModule``. + * The resulting recording of a standalone function produces :class:`ScriptFunction`. + * The resulting recording of ``forward`` function of ``nn.Module`` or ``nn.Module`` produces :class:`ScriptModule`. This module also contains any parameters that the original module had as well. @@ -745,7 +747,7 @@ def trace(func, any untracked external dependencies (e.g., perform input/output or access global variables). Tracing only records operations done when the given function is run on the given - tensors. Therefore, the returned ``ScriptModule`` will always run the same traced + tensors. Therefore, the returned :class:`ScriptModule` will always run the same traced graph on any input. This has some important implications when your module is expected to run different sets of operations, depending on the input and/or the module state. For example, @@ -755,10 +757,10 @@ def trace(func, inlines the control-flow decisions. But sometimes the control-flow is actually part of the model itself. For instance, a recurrent network is a loop over the (possibly dynamic) length of an input sequence. - * In the returned ``ScriptModule``, operations that have different + * In the returned :class:`ScriptModule`, operations that have different behaviors in ``training`` and ``eval`` modes will always behave as if it is in the mode it was in during tracing, no matter which mode the - ``ScriptModule`` is in. + :class:`ScriptModule` is in. In cases like these, tracing would not be appropriate and :func:`scripting ` is a better choice. If you trace such models, you may silently get @@ -767,19 +769,22 @@ def trace(func, incorrect trace to be produced. Arguments: - func (callable or torch.nn.Module): a Python function or ``torch.nn.Module`` + func (callable or torch.nn.Module): A Python function or ``torch.nn.Module`` that will be run with ``example_inputs``. arguments and returns to ``func`` must be tensors or (possibly nested) tuples that - contain tensors. - example_inputs (tuple): a tuple of example inputs that will be passed to the function + contain tensors. When a module is passed to + :func:`torch.jit.trace `, only the + ``forward`` method is run and traced + (see :func:`torch.jit.trace ` for details). + example_inputs (tuple): A tuple of example inputs that will be passed to the function while tracing. The resulting trace can be run with inputs of different types and shapes assuming the traced operations support those types and shapes. ``example_inputs`` may also be a single - Tensor in which case it is automatically wrapped in a tuple + Tensor in which case it is automatically wrapped in a tuple. Keyword arguments: - check_trace (bool, optional): check if the same inputs run through + check_trace (bool, optional): Check if the same inputs run through traced code produce the same outputs. Default: ``True``. You might want to disable this if, for example, your network contains non- deterministic ops or if you are sure that the network is correct despite @@ -797,11 +802,11 @@ def trace(func, results diverge numerically for a known reason, such as operator fusion. Returns: - if ``callable`` is ``nn.Module`` or ``forward()`` of ``nn.Module``, ``trace`` returns - a ``ScriptModule`` object with a single ``forward()`` method containing the traced code. - The returned ``ScriptModule`` will have the same set of sub-modules and parameters as the + If ``callable`` is ``nn.Module`` or ``forward`` of ``nn.Module``, ``trace`` returns + a :class:`ScriptModule` object with a single ``forward`` method containing the traced code. + The returned :class:`ScriptModule` will have the same set of sub-modules and parameters as the original ``nn.Module``. - If ``callable`` is a standalone function, ``trace`` returns ``torch.jit.ScriptFunction`` + If ``callable`` is a standalone function, ``trace`` returns :class:`ScriptFunction` Example (tracing a function): @@ -842,6 +847,7 @@ def trace(func, # Trace a module (implicitly traces `forward`) and construct a # `ScriptModule` with a single `forward` method module = torch.jit.trace(n, example_forward_input) + """ if not _enabled: return func @@ -904,7 +910,7 @@ def trace_module(mod, _module_class=None, _compilation_unit=_python_cu): """ - Trace a module and return an executable ``ScriptModule`` that will be optimized + Trace a module and return an executable :class:`ScriptModule` that will be optimized using just-in-time compilation. When a module is passed to :func:`torch.jit.trace `, only the ``forward`` method is run and traced. With ``trace_module``, you can specify a dictionary of method names to example inputs to trace (see the ``example_inputs``) argument below. @@ -912,15 +918,15 @@ def trace_module(mod, See :func:`torch.jit.trace ` for more information on tracing. Arguments: - mod (torch.nn.Module): a ``torch.nn.Module`` containing methods whose names are - specified in ``example_inputs``. The given methods will be compiled - as a part of a single `ScriptModule` - example_inputs (dict): a dict containing sample inputs indexed by method names in ``mod`` - The inputs will be passed to methods whose names correspond to inputs' - keys while tracing. - ``{ 'forward' : example_forward_input, 'method2': example_method2_input}`` + mod (torch.nn.Module): A ``torch.nn.Module`` containing methods whose names are + specified in ``example_inputs``. The given methods will be compiled + as a part of a single `ScriptModule`. + example_inputs (dict): A dict containing sample inputs indexed by method names in ``mod``. + The inputs will be passed to methods whose names correspond to inputs' + keys while tracing. + ``{ 'forward' : example_forward_input, 'method2': example_method2_input}`` Keyword arguments: - check_trace (bool, optional): check if the same inputs run through + check_trace (bool, optional): Check if the same inputs run through traced code produce the same outputs. Default: ``True``. You might want to disable this if, for example, your network contains non- deterministic ops or if you are sure that the network is correct despite @@ -938,8 +944,8 @@ def trace_module(mod, results diverge numerically for a known reason, such as operator fusion. Returns: - A ``ScriptModule`` object with a single ``forward()`` method containing the traced code. - When ``func`` is a ``torch.nn.Module``, the returned ``ScriptModule`` will have the same set of + A :class:`ScriptModule` object with a single ``forward`` method containing the traced code. + When ``func`` is a ``torch.nn.Module``, the returned :class:`ScriptModule` will have the same set of sub-modules and parameters as ``func``. Example (tracing a module with multiple methods):: @@ -975,6 +981,7 @@ def trace_module(mod, # a `ScriptModule` with `forward` and `weighted_kernel_sum` methods inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight} module = torch.jit.trace_module(n, inputs) + """ if not _enabled: return mod @@ -1019,7 +1026,7 @@ class CompilationUnit(object): def define(self, lang, rcb=None, _frames_up=0): if not rcb: - rcb = _jit_internal.createResolutionCallback(_frames_up + 1) + rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1) self._c.define(lang, rcb) def __getattr__(self, attr): @@ -1081,8 +1088,8 @@ def _compile_and_register_class(obj, rcb, qualified_name): def script(obj, optimize=None, _frames_up=0, _rcb=None): r""" Scripting a function or ``nn.Module`` will inspect the source code, compile - it as TorchScript code using the TorchScript compiler, and return a ``ScriptModule`` or - ``torch.jit.ScriptFunction``. TorchScript itself is a subset of the Python language, so not all + it as TorchScript code using the TorchScript compiler, and return a :class:`ScriptModule` or + :class:`ScriptFunction`. TorchScript itself is a subset of the Python language, so not all features in Python work, but we provide enough functionality to compute on tensors and do control-dependent operations. For a complete guide, see the `TorchScript Language Reference`_. @@ -1091,7 +1098,7 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None): ``@torch.jit.script`` for `TorchScript Classes `_ and functions. **Scripting a function** - The ``@torch.jit.script`` decorator will construct a ``torch.jit.ScriptFunction`` + The ``@torch.jit.script`` decorator will construct a :class:`ScriptFunction` by compiling the body of the function. Example (scripting a function): @@ -1108,11 +1115,19 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None): r = y return r + print(type(foo)) # torch.jit.ScriptFuncion + + # See the compiled graph as Python code + print(foo.code) + + # Call the function using the TorchScript interpreter + foo(torch.ones(2, 2), torch.ones(2, 2)) + **Scripting an nn.Module** Scripting an ``nn.Module`` by default will compile the ``forward`` method and recursively compile any methods, submodules, and functions called by ``forward``. If a ``nn.Module`` only uses features supported in TorchScript, no changes to the original module code should be necessary. ``script`` - will construct ``torch.jit.ScriptModule`` that has copies of the attributes, parameters, and methods of + will construct :class:`ScriptModule` that has copies of the attributes, parameters, and methods of the original module. Example (scripting a simple module with a Parameter): @@ -1217,37 +1232,19 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None): raise RuntimeError("TorchScript classes must be new-style classes. " "Please inherit from 'object'") if _rcb is None: - _rcb = _jit_internal.createResolutionCallback(_frames_up + 1) + _rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1) _compile_and_register_class(obj, _rcb, qualified_name) return obj else: _check_directly_compile_overloaded(obj) ast = get_jit_def(obj) if _rcb is None: - _rcb = _gen_rcb(obj, _frames_up) + _rcb = _jit_internal.createResolutionCallbackFromClosure(obj) fn = torch._C._jit_script_compile(qualified_name, ast, _rcb, get_default_args(obj)) # Forward docstrings fn.__doc__ = obj.__doc__ return fn -def _gen_rcb(obj, _frames_up): - _frames_up = _frames_up + 1 # for invoking _gen_rcb() - - closure_rcb = _jit_internal.createResolutionCallbackFromClosure(obj) - stack_rcb = _jit_internal.createResolutionCallback(_frames_up + 1) - - def _rcb(name): - # since type comments aren't captured in the function's closures, - # we still need to try to the rcb based on stack frames if the - # closure rcb fails - result = closure_rcb(name) - if result: - return result - return stack_rcb(name) - - return _rcb - - def interface(obj): if not inspect.isclass(obj): raise RuntimeError("interface must be applied to a class") @@ -1255,7 +1252,7 @@ def interface(obj): raise RuntimeError("TorchScript interfaces must inherit from 'object'") qualified_name = _qualified_name(obj) ast = get_jit_class_def(obj, obj.__name__) - rcb = _jit_internal.createResolutionCallback(1) + rcb = _jit_internal.createResolutionCallbackFromFrame(1) torch._C._jit_script_interface_compile(qualified_name, ast, rcb) obj.__torch_script_interface__ = True return obj @@ -1279,7 +1276,7 @@ def script_method(fn, _rcb=None): # createResolutionCallback internally adds 1 to get us to the scope of this # function (the calling function). Adding 2 gets us to the proper surrounding scope. if _rcb is None: - _rcb = _jit_internal.createResolutionCallback(frames_up=2) + _rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=2) ast = get_jit_def(fn, self_name="ScriptModule") return ScriptMethodStub(_rcb, ast, fn) @@ -1499,6 +1496,15 @@ if _enabled: ``ScriptModule``\s should not be created manually, instead use either :func:`tracing ` or :func:`scripting `. + Tracing and scripting can be applied incrementally and :ref:`composed as necessary `. + + * Tracing records the tensor operations as executed with a set of example inputs and uses these + operations to construct a computation graph. You can use the full dynamic behavior of Python with tracing, + but values other than Tensors and control flow aren't captured in the graph. + + * Scripting inspects the Python code of the model + and compiles it to TorchScript. Scripting allows the use of many `types`_ of values and supports dynamic control flow. + Many, but not all features of Python are supported by the compiler, so changes to the source code may be necessary. """ def __init__(self, optimize=None, _qualified_name=None, _compilation_unit=None, _cpp_module=None): if _qualified_name is None: @@ -1640,7 +1646,7 @@ if _enabled: # # createResolutionCallback internally adds 1 to get us to our frame, then # we add 1 to get to the proper surrounding scope. - rcb = _jit_internal.createResolutionCallback(frames_up=1) + rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=1) self._c._define(self, lang, rcb) def copy(self): @@ -2013,7 +2019,7 @@ _compiled_overloaded_fns = {} def _compile_function_with_overload(qual_name, impl_fn, overload_decl, overload_defaults): impl_ast = torch.jit.get_jit_def(impl_fn) _frames_up = 0 - _rcb = _gen_rcb(impl_fn, _frames_up) + _rcb = _jit_internal.createResolutionCallbackFromClosure(impl_fn) fn = torch._C._jit_script_compile_overload(qual_name, overload_decl, impl_ast, _rcb, overload_defaults) return fn @@ -2118,6 +2124,10 @@ def _graph_for(self, *args, **kwargs): torch._C.ScriptMethod.graph_for = _graph_for torch._C.ScriptFunction.graph_for = _graph_for ScriptFunction = torch._C.ScriptFunction +ScriptFunction.__doc__ = """ +Functionally equivalent to a :class:`ScriptModule`, but represents a single +function and does not have any attributes or Parameters. +""" set_module(ScriptFunction, "torch.jit") if not torch._C._jit_init(): diff --git a/torch/tensor.py b/torch/tensor.py index 4c45299700d..f287d9da290 100644 --- a/torch/tensor.py +++ b/torch/tensor.py @@ -522,18 +522,110 @@ class Tensor(torch._C._TensorBase): return dict(typestr=typestr, shape=shape, strides=strides, data=data, version=1) def refine_names(self, *names): + r"""Refines the dimension names of :attr:`self` according to :attr:`names`. + + Refining is a special case of renaming that "lifts" unnamed dimensions. + A ``None`` dim can be refined to have any name; a named dim can only be + refined to have the same name. + + Because named tensors can coexist with unnamed tensors, refining names + gives a nice way to write named-tensor-aware code that works with both + named and unnamed tensors. + + :attr:`names` may contain up to one Ellipsis (``...``). + The Ellipsis is expanded greedily; it is expanded in-place to fill + :attr:`names` to the same length as ``self.dim()`` using names from the + corresponding indices of ``self.names``. + + Python 2 does not support Ellipsis but one may use a string literal + instead (``'...'``). + + Arguments: + names (iterable of str): The desired names of the output tensor. May + contain up to one Ellipsis. + + Examples:: + + >>> imgs = torch.randn(32, 3, 128, 128) + >>> named_imgs = imgs.refine_names('N', 'C', 'H', 'W') + >>> named_imgs.names + ('N', 'C', 'H', 'W') + + >>> tensor = torch.randn(2, 3, 5, 7, 11) + >>> tensor = tensor.refine_names('A', ..., 'B', 'C') + >>> tensor.names + ('A', None, None, 'B', 'C') + + .. warning:: + The named tensor API is experimental and subject to change. + + """ names = resolve_ellipsis(names, self.names, 'refine_names') return super(Tensor, self).refine_names(names) def align_to(self, *names): + r"""Permutes the dimensions of the :attr:`self` tensor to match the order + specified in :attr:`names`, adding size-one dims for any new names. + + All of the dims of :attr:`self` must be named in order to use this method. + The resulting tensor is a view on the original tensor. + + All dimension names of :attr:`self` must be present in :attr:`names`. + :attr:`names` may contain additional names that are not in ``self.names``; + the output tensor has a size-one dimension for each of those new names. + + :attr:`names` may contain up to one Ellipsis (``...``). + The Ellipsis is expanded to be equal to all dimension names of :attr:`self` + that are not mentioned in :attr:`names`, in the order that they appear + in :attr:`self`. + + Python 2 does not support Ellipsis but one may use a string literal + instead (``'...'``). + + Arguments: + names (iterable of str): The desired dimension ordering of the + output tensor. May contain up to one Ellipsis that is expanded + to all unmentioned dim names of :attr:`self`. + + Examples:: + + >>> tensor = torch.randn(2, 2, 2, 2, 2, 2) + >>> named_tensor = tensor.refine_names('A', 'B', 'C', 'D', 'E', 'F') + + # Move the F and E dims to the front while keeping the rest in order + >>> named_tensor.align_to('F', 'E', ...) + + .. warning:: + The named tensor API is experimental and subject to change. + + """ return super(Tensor, self).align_to( resolve_ellipsis(names, self.names, 'align_to', is_positional=False)) def unflatten(self, dim, namedshape): + r"""Unflattens the named dimension :attr:`dim`, viewing it in the shape + specified by :attr:`namedshape`. + + Arguments: + namedshape: (iterable of ``(name, size)`` tuples). + + Examples:: + + >>> flat_imgs = torch.rand(32, 3 * 128 * 128, names=('N', 'features')) + >>> imgs = flat_imgs.unflatten('features', (('C', 3), ('H', 128), ('W', 128))) + >>> imgs.names, images.shape + (('N', 'C', 'H', 'W'), torch.Size([32, 3, 128, 128])) + + .. warning:: + The named tensor API is experimental and subject to change. + + """ names, sizes = unzip_namedshape(namedshape) return super(Tensor, self).unflatten(dim, sizes, names) def rename_(self, *names, **rename_map): + """In-place version of :meth:`~Tensor.rename`.""" + # Note [rename_ / rename API] # The Python API for these is different from the C++ API. In Python: # 1) tensor.rename(*names) takes a vararglist of names @@ -542,6 +634,39 @@ class Tensor(torch._C._TensorBase): return update_names(self, names, rename_map, inplace=True) def rename(self, *names, **rename_map): + """Renames dimension names of :attr:`self`. + + There are two main usages: + + ``self.rename(**rename_map)`` returns a view on tensor that has dims + renamed as specified in the mapping :attr:`rename_map`. + + ``self.rename(*names)`` returns a view on tensor, renaming all + dimensions positionally using :attr:`names`. + Use ``self.rename(None)`` to drop names on a tensor. + + One cannot specify both positional args :attr:`names` and keyword args + :attr:`rename_map`. + + Examples:: + + >>> imgs = torch.rand(2, 3, 5, 7, names=('N', 'C', 'H', 'W')) + >>> renamed_imgs = imgs.rename(N='batch', C='channels') + >>> renamed_imgs.names + ('batch', 'channels', 'H', 'W') + + >>> renamed_imgs = imgs.rename(None) + >>> renamed_imgs.names + (None,) + + >>> renamed_imgs = imgs.rename('batch', 'channel', 'height', 'width') + >>> renamed_imgs.names + ('batch', 'channel', 'height', 'width') + + .. warning:: + The named tensor API is experimental and subject to change. + + """ # See Note [rename_ / rename API] return update_names(self, names, rename_map, inplace=False) diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py index 54848b88884..6181e9751cb 100644 --- a/torch/utils/hipify/cuda_to_hip_mappings.py +++ b/torch/utils/hipify/cuda_to_hip_mappings.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 import collections from .constants import *