[Reland] [5/N] Change static functions in headers to inline (#131010)

Reland of #130673

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131010
Approved by: https://github.com/Skylion007
Authored by cyy on 2024-07-18 15:53:48 +00:00; committed by PyTorch MergeBot
parent d6ae8bbf16
commit 7c90a82970
14 changed files with 110 additions and 110 deletions
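
As background, here is a minimal sketch (the header and function names are hypothetical, not taken from this patch) of the distinction the change targets: a static inline function defined in a header has internal linkage, so every translation unit that includes the header carries its own copy of the definition, whereas a plain inline function keeps external linkage and the linker folds the identical definitions into a single symbol.

// example.h -- hypothetical header included from many translation units.
#pragma once

// Internal linkage: each .cpp including this header gets its own private
// copy of the function (and its own copies of any static locals inside it).
static inline int answer_static() {
  return 42;
}

// External linkage + inline: the definition may legally appear in every
// translation unit, and the linker deduplicates it into one symbol.
inline int answer_inline() {
  return 42;
}

Under this sketch, &answer_static can differ between translation units while &answer_inline is unique program-wide, which is the usual motivation for dropping static from header-defined functions.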

View File

@@ -432,73 +432,73 @@ class TORCH_API Context {
 TORCH_API Context& globalContext();
-static inline void init() {
+inline void init() {
   globalContext();
 }
 TORCH_API Allocator* getCPUAllocator();
-static inline DeprecatedTypeProperties& getDeprecatedTypeProperties(
+inline DeprecatedTypeProperties& getDeprecatedTypeProperties(
     Backend p,
     ScalarType s) {
   return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
       p, s);
 }
-static inline DeprecatedTypeProperties& CPU(ScalarType s) {
+inline DeprecatedTypeProperties& CPU(ScalarType s) {
   return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
       Backend::CPU, s);
 }
-static inline DeprecatedTypeProperties& CUDA(ScalarType s) {
+inline DeprecatedTypeProperties& CUDA(ScalarType s) {
   return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
       Backend::CUDA, s);
 }
-static inline DeprecatedTypeProperties& HIP(ScalarType s) {
+inline DeprecatedTypeProperties& HIP(ScalarType s) {
   return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
       Backend::HIP, s);
 }
-static inline DeprecatedTypeProperties& MPS(ScalarType s) {
+inline DeprecatedTypeProperties& MPS(ScalarType s) {
   return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
       Backend::MPS, s);
 }
-static inline bool hasCUDA() {
+inline bool hasCUDA() {
   return globalContext().hasCUDA();
 }
-static inline bool hasMTIA() {
+inline bool hasMTIA() {
   return globalContext().hasMTIA();
 }
-static inline bool hasHIP() {
+inline bool hasHIP() {
   return globalContext().hasHIP();
 }
-static inline bool hasIPU() {
+inline bool hasIPU() {
   return globalContext().hasIPU();
 }
-static inline bool hasXLA() {
+inline bool hasXLA() {
   return globalContext().hasXLA();
 }
-static inline bool hasMPS() {
+inline bool hasMPS() {
   return globalContext().hasMPS();
 }
-static inline bool hasMAIA() {
+inline bool hasMAIA() {
   return globalContext().hasMAIA();
 }
-static inline bool hasXPU() {
+inline bool hasXPU() {
   return globalContext().hasXPU();
 }
 // Despite its name, this function returns the number of *CUDA* GPUs.
-static inline size_t getNumGPUs() {
+inline size_t getNumGPUs() {
   // WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS
   // FUNCTION. If you are interested in interrogating the number of
   // devices for a specific device type, add that function to the
@@ -517,27 +517,27 @@ static inline size_t getNumGPUs() {
   }
 }
-static inline bool hasOpenMP() {
+inline bool hasOpenMP() {
   return globalContext().hasOpenMP();
 }
-static inline bool hasMKL() {
+inline bool hasMKL() {
   return globalContext().hasMKL();
 }
-static inline bool hasLAPACK() {
+inline bool hasLAPACK() {
   return globalContext().hasLAPACK();
 }
-static inline bool hasMAGMA() {
+inline bool hasMAGMA() {
   return globalContext().hasMAGMA();
 }
-static inline bool hasMKLDNN() {
+inline bool hasMKLDNN() {
   return globalContext().hasMKLDNN();
 }
-static inline void manual_seed(uint64_t seed) {
+inline void manual_seed(uint64_t seed) {
   auto gen = globalContext().defaultGenerator(c10::DeviceType::CPU);
   {
     // See Note [Acquire lock when using random generators]

View File

@@ -499,7 +499,7 @@ inline Tensor sum_to(
   return _sum_to(std::move(tensor), shape, always_return_non_view);
 }
-static inline bool is_expandable_to(
+inline bool is_expandable_to(
     SymIntArrayRef shape,
     c10::SymIntArrayRef desired) {
   size_t ndim = shape.size();
@@ -517,7 +517,7 @@ static inline bool is_expandable_to(
   return true;
 }
-static inline bool is_expandable_to(IntArrayRef shape, IntArrayRef desired) {
+inline bool is_expandable_to(IntArrayRef shape, IntArrayRef desired) {
   auto sym_shape = c10::SymIntArrayRef(
       reinterpret_cast<const c10::SymInt*>(shape.data()), shape.size());
   auto sym_desired = c10::SymIntArrayRef(

View File

@@ -33,15 +33,15 @@ namespace at {
   _(==, x.eq(y), y.eq(x)) \
   _(!=, x.ne(y), y.ne(x))
 #define DEFINE_OPERATOR(op, body, reverse_scalar_body) \
-static inline Tensor operator op(const Tensor& x, const Tensor& y) { \
+inline Tensor operator op(const Tensor& x, const Tensor& y) { \
   return body; \
 } \
-static inline Tensor operator op(const Tensor& x, const Scalar& y) { \
+inline Tensor operator op(const Tensor& x, const Scalar& y) { \
   return body; \
 } \
-static inline Tensor operator op(const Scalar& x, const Tensor& y) { \
+inline Tensor operator op(const Scalar& x, const Tensor& y) { \
   return reverse_scalar_body; \
 }
 AT_FORALL_BINARY_OPS(DEFINE_OPERATOR)

View File

@@ -113,12 +113,12 @@
 namespace at::tracer::impl {
-static inline bool is_dispatch_enabled() {
+inline bool is_dispatch_enabled() {
   return c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Tracer) &&
       !c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer);
 }
-static inline void set_dispatch_enabled(bool enabled) {
+inline void set_dispatch_enabled(bool enabled) {
   TORCH_INTERNAL_ASSERT(
       !c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer),
       "Cannot enable tracing within the scope of NoTracerDispatchMode!");

View File

@@ -29,7 +29,7 @@ TORCH_API int _crash_if_asan(int);
 // Converts a TensorList (i.e. ArrayRef<Tensor> to vector of TensorImpl*)
 // NB: This is ONLY used by legacy TH bindings, and ONLY used by cat.
 // Once cat is ported entirely to ATen this can be deleted!
-static inline std::vector<TensorImpl*> checked_dense_tensor_list_unwrap(
+inline std::vector<TensorImpl*> checked_dense_tensor_list_unwrap(
     ArrayRef<Tensor> tensors,
     const char* name,
     int pos,

View File

@@ -13,7 +13,7 @@ namespace at {
 constexpr size_t dim_bitset_size = 64;
-static inline std::bitset<dim_bitset_size> dim_list_to_bitset(
+inline std::bitset<dim_bitset_size> dim_list_to_bitset(
     OptionalIntArrayRef opt_dims,
     size_t ndims) {
   TORCH_CHECK(

View File

@@ -18,7 +18,7 @@ TORCH_API std::ostream& print(
     std::ostream& stream,
     const Tensor& tensor,
     int64_t linesize);
-static inline std::ostream& operator<<(std::ostream & out, const Tensor & t) {
+inline std::ostream& operator<<(std::ostream & out, const Tensor & t) {
   return print(out,t,80);
 }
 TORCH_API void print(const Tensor & t, int64_t linesize=80);

View File

@@ -93,7 +93,7 @@ torch::jit::Stack boxArgs(Args... args) {
 }
 template <class T>
-static inline constexpr size_t boxed_size_one() {
+inline constexpr size_t boxed_size_one() {
   static_assert(!std::is_same<std::decay_t<T>, c10::TensorOptions>::value, "need to patch this path to support TensorOptions passed by reference");
   return 1;
 }

View File

@@ -159,7 +159,7 @@ struct Atomic##NAME##IntegerImpl<T, 8> {
 # define GPU_ATOMIC_INTEGER(NAME, OP, DTYPE) \
-static inline __device__ void gpuAtomic##NAME(DTYPE *address, DTYPE val) { \
+inline __device__ void gpuAtomic##NAME(DTYPE *address, DTYPE val) { \
   Atomic##NAME##IntegerImpl<DTYPE, sizeof(DTYPE)>()(address, \
       val, \
       [](DTYPE a, DTYPE b) { \
@@ -171,7 +171,7 @@ ATOMIC_INTEGER_IMPL(Add)
 GPU_ATOMIC_INTEGER(Add, a || b, bool)
 // Don't instantiate gpuAtomicAdd with the macro as it seems non-standard (see int32, int64)
-static inline __device__ void gpuAtomicAdd(uint8_t *address, uint8_t val) {
+inline __device__ void gpuAtomicAdd(uint8_t *address, uint8_t val) {
   AtomicAddIntegerImpl<uint8_t, sizeof(uint8_t)>()(address,
       val,
       [](uint8_t a, uint8_t b) {
@@ -179,7 +179,7 @@ static inline __device__ void gpuAtomicAdd(uint8_t *address, uint8_t val) {
       });
 }
-static inline __device__ void gpuAtomicAdd(int8_t *address, int8_t val) {
+inline __device__ void gpuAtomicAdd(int8_t *address, int8_t val) {
   AtomicAddIntegerImpl<int8_t, sizeof(int8_t)>()(address,
       val,
       [](int8_t a, int8_t b) {
@@ -187,7 +187,7 @@ static inline __device__ void gpuAtomicAdd(int8_t *address, int8_t val) {
       });
 }
-static inline __device__ void gpuAtomicAdd(int16_t *address, int16_t val) {
+inline __device__ void gpuAtomicAdd(int16_t *address, int16_t val) {
   AtomicAddIntegerImpl<int16_t, sizeof(int16_t)>()(address,
       val,
       [](int16_t a, int16_t b) {
@@ -195,11 +195,11 @@ static inline __device__ void gpuAtomicAdd(int16_t *address, int16_t val) {
       });
 }
-static inline __device__ int32_t gpuAtomicAdd(int32_t *address, int32_t val) {
+inline __device__ int32_t gpuAtomicAdd(int32_t *address, int32_t val) {
   return atomicAdd(address, val);
 }
-static inline __device__ void gpuAtomicAdd(int64_t *address, int64_t val) {
+inline __device__ void gpuAtomicAdd(int64_t *address, int64_t val) {
 #if defined(USE_ROCM)
   __atomic_fetch_add(address, val, __ATOMIC_RELAXED);
 #else
@@ -208,7 +208,7 @@ static inline __device__ void gpuAtomicAdd(int64_t *address, int64_t val) {
 #endif
 }
-static inline __device__ at::Half gpuAtomicAdd(at::Half *address, at::Half val) {
+inline __device__ at::Half gpuAtomicAdd(at::Half *address, at::Half val) {
 #if defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
   return AtomicFPOp<at::Half>()(address, val,
       [](at::Half hsum, at::Half val) {
@@ -219,7 +219,7 @@ static inline __device__ at::Half gpuAtomicAdd(at::Half *address, at::Half val)
 #endif
 }
-static inline __device__ at::BFloat16 gpuAtomicAdd(at::BFloat16 *address, at::BFloat16 val) {
+inline __device__ at::BFloat16 gpuAtomicAdd(at::BFloat16 *address, at::BFloat16 val) {
 #if defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))
   return AtomicFPOp<at::BFloat16>()(address, val,
       [](at::BFloat16 bsum, at::BFloat16 val) {
@@ -233,7 +233,7 @@ return AtomicFPOp<at::BFloat16>()(address, val,
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600)
 // from CUDA C Programmic Guide
-static inline __device__ double atomicAdd(double* address, double val)
+inline __device__ double atomicAdd(double* address, double val)
 #if defined(__clang__) && defined(__CUDA__)
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wgcc-compat"
@@ -261,20 +261,20 @@ static inline __device__ double atomicAdd(double* address, double val)
 #if defined(USE_ROCM) && __hcc_workweek__ < 18312 && !__HIP__
 // This needs to be defined for the host side pass
-static inline __device__ double atomicAdd(double *address, double val) { }
+inline __device__ double atomicAdd(double *address, double val) { }
 #endif
 #endif
-static inline __device__ double gpuAtomicAdd(double *address, double val) {
+inline __device__ double gpuAtomicAdd(double *address, double val) {
   return atomicAdd(address, val);
 }
-static inline __device__ float gpuAtomicAdd(float *address, float val) {
+inline __device__ float gpuAtomicAdd(float *address, float val) {
   return atomicAdd(address, val);
 }
 template<typename T>
-static inline __device__ void gpuAtomicAdd(c10::complex<T> *address, c10::complex<T> val) {
+inline __device__ void gpuAtomicAdd(c10::complex<T> *address, c10::complex<T> val) {
   gpuAtomicAdd(&address->real_, val.real_);
   gpuAtomicAdd(&address->imag_, val.imag_);
 }
@@ -285,31 +285,31 @@ static inline __device__ void gpuAtomicAdd(c10::complex<T> *address, c10::comple
  * directly and require non-library provided data type support. Only for these, we
  * continue to provide atomicAdd overloads.
  */
-static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
+inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
   return gpuAtomicAdd(address, val);
 }
-static inline __device__ at::BFloat16 atomicAdd(at::BFloat16 *address, at::BFloat16 val) {
+inline __device__ at::BFloat16 atomicAdd(at::BFloat16 *address, at::BFloat16 val) {
   return gpuAtomicAdd(address, val);
 }
-static inline __device__ void atomicAdd(uint8_t *address, uint8_t val) {
+inline __device__ void atomicAdd(uint8_t *address, uint8_t val) {
   gpuAtomicAdd(address, val);
 }
-static inline __device__ void atomicAdd(int8_t *address, int8_t val) {
+inline __device__ void atomicAdd(int8_t *address, int8_t val) {
   gpuAtomicAdd(address, val);
 }
-static inline __device__ void atomicAdd(int16_t *address, int16_t val) {
+inline __device__ void atomicAdd(int16_t *address, int16_t val) {
   gpuAtomicAdd(address, val);
 }
-static inline __device__ void atomicAdd(int64_t *address, int64_t val) {
+inline __device__ void atomicAdd(int64_t *address, int64_t val) {
   gpuAtomicAdd(address, val);
 }
-static inline __device__ void atomicAdd(bool *address, bool val) {
+inline __device__ void atomicAdd(bool *address, bool val) {
   gpuAtomicAdd(address, val);
 }
@@ -321,20 +321,20 @@ static inline __device__ void atomicAdd(bool *address, bool val) {
 * therefore we need a new API 'gpuAtomicAddNoReturn'.
 */
 template<typename T>
-static inline __device__ void gpuAtomicAddNoReturn(c10::complex<T> *address, c10::complex<T> val) { gpuAtomicAdd(address, val); }
+inline __device__ void gpuAtomicAddNoReturn(c10::complex<T> *address, c10::complex<T> val) { gpuAtomicAdd(address, val); }
-static inline __device__ void gpuAtomicAddNoReturn(uint8_t *address, uint8_t val) { gpuAtomicAdd(address, val); }
+inline __device__ void gpuAtomicAddNoReturn(uint8_t *address, uint8_t val) { gpuAtomicAdd(address, val); }
-static inline __device__ void gpuAtomicAddNoReturn(int8_t *address, int8_t val) { gpuAtomicAdd(address, val); }
+inline __device__ void gpuAtomicAddNoReturn(int8_t *address, int8_t val) { gpuAtomicAdd(address, val); }
-static inline __device__ void gpuAtomicAddNoReturn(int16_t *address, int16_t val) { gpuAtomicAdd(address, val); }
+inline __device__ void gpuAtomicAddNoReturn(int16_t *address, int16_t val) { gpuAtomicAdd(address, val); }
-static inline __device__ void gpuAtomicAddNoReturn(int32_t *address, int32_t val) { gpuAtomicAdd(address, val); }
+inline __device__ void gpuAtomicAddNoReturn(int32_t *address, int32_t val) { gpuAtomicAdd(address, val); }
-static inline __device__ void gpuAtomicAddNoReturn(int64_t *address, int64_t val) { gpuAtomicAdd(address, val); }
+inline __device__ void gpuAtomicAddNoReturn(int64_t *address, int64_t val) { gpuAtomicAdd(address, val); }
-static inline __device__ void gpuAtomicAddNoReturn(bool *address, bool val) { gpuAtomicAdd(address, val); }
+inline __device__ void gpuAtomicAddNoReturn(bool *address, bool val) { gpuAtomicAdd(address, val); }
-static inline __device__ void gpuAtomicAddNoReturn(at::Half *address, at::Half val) { gpuAtomicAdd(address, val); }
+inline __device__ void gpuAtomicAddNoReturn(at::Half *address, at::Half val) { gpuAtomicAdd(address, val); }
-static inline __device__ void gpuAtomicAddNoReturn(at::BFloat16 *address, at::BFloat16 val) { gpuAtomicAdd(address, val); }
+inline __device__ void gpuAtomicAddNoReturn(at::BFloat16 *address, at::BFloat16 val) { gpuAtomicAdd(address, val); }
-static inline __device__ void gpuAtomicAddNoReturn(double *address, double val) { gpuAtomicAdd(address, val); }
+inline __device__ void gpuAtomicAddNoReturn(double *address, double val) { gpuAtomicAdd(address, val); }
 /* Special case fp32 atomic. */
 #if defined(USE_ROCM)
-static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) {
+inline __device__ void gpuAtomicAddNoReturn(float *address, float val) {
 #if defined(__gfx908__)
   atomicAddNoRet(address, val);
 #else
@@ -342,7 +342,7 @@ static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) {
 #endif
 }
 #else
-static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { gpuAtomicAdd(address, val); }
+inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { gpuAtomicAdd(address, val); }
 #endif
 // Atomic multiplication implementation.

View File

@@ -11,7 +11,7 @@ namespace {
 // checks whether index.dtype == int64
 // and self.dtype == src.dtype if src is a Tensor
-static void scatter_gather_dtype_check(
+inline void scatter_gather_dtype_check(
     const std::string& method_name,
     const Tensor& self,
     const Tensor& index,
@@ -38,7 +38,7 @@ static void scatter_gather_dtype_check(
 // Test:
 // 1. index.size(d) <= self.size(d) for all d != dim
 // 2. index.dim() == self.dim()
-static C10_UNUSED void gather_shape_check(const Tensor& self, int64_t dim,
+inline void gather_shape_check(const Tensor& self, int64_t dim,
     const Tensor& index
 ) {
   auto self_dims = ensure_nonempty_dim(self.dim());
@@ -64,7 +64,7 @@ static C10_UNUSED void gather_shape_check(const Tensor& self, int64_t dim,
 // 1. index.size(d) <= self.size(d) for all d != dim
 // 2. index.size(d) <= src.size(d) for all d if src is a Tensor
 // 3. index.dim() == self.dim() == src.dim()
-static C10_UNUSED void scatter_shape_check(
+inline void scatter_shape_check(
     const Tensor& self, int64_t dim, const Tensor& index,
     const std::optional<Tensor>& src_opt = std::nullopt
 ) {

View File

@@ -18,7 +18,7 @@ using detail::GridSamplerPadding;
 // +1 --> (size - 1) + 0.5 == size - 0.5
 // scale_factor = size / 2
 template <typename scalar_t>
-static __forceinline__ __device__
+__forceinline__ __device__
 scalar_t grid_sampler_unnormalize(scalar_t coord, int size, bool align_corners) {
   if (align_corners) {
     // unnormalize coord from [-1, 1] to [0, size - 1]
@@ -34,7 +34,7 @@ scalar_t grid_sampler_unnormalize(scalar_t coord, int size, bool align_corners)
 // `grad_in`.
 // This is useful in the backward pass of grid_sampler.
 template <typename scalar_t>
-static __forceinline__ __device__
+__forceinline__ __device__
 scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int size,
     bool align_corners, scalar_t *grad_in) {
   if (align_corners) {
@@ -50,7 +50,7 @@ scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int size,
 // Clips coordinates to between 0 and clip_limit - 1
 template <typename scalar_t>
-static __forceinline__ __device__
+__forceinline__ __device__
 scalar_t clip_coordinates(scalar_t in, int clip_limit) {
   return ::min(static_cast<scalar_t>(clip_limit - 1), ::max(in, static_cast<scalar_t>(0)));
 }
@@ -59,7 +59,7 @@ scalar_t clip_coordinates(scalar_t in, int clip_limit) {
 // it also returns the `d output / d input` via pointer argument `grad_in`.
 // This is useful in the backward pass of grid_sampler.
 template <typename scalar_t>
-static __forceinline__ __device__
+__forceinline__ __device__
 scalar_t clip_coordinates_set_grad(scalar_t in, int clip_limit, scalar_t *grad_in) {
   // Note that it is important for the gradient calculation that borders
   // are considered out of bounds.
@@ -82,7 +82,7 @@ scalar_t clip_coordinates_set_grad(scalar_t in, int clip_limit, scalar_t *grad_i
 // The bounds are passed as twice their value so that half-integer values
 // can be represented as ints.
 template <typename scalar_t>
-static __forceinline__ __device__
+__forceinline__ __device__
 scalar_t reflect_coordinates(scalar_t in, int twice_low, int twice_high) {
   if (twice_low == twice_high) {
     return static_cast<scalar_t>(0);
@@ -105,7 +105,7 @@ scalar_t reflect_coordinates(scalar_t in, int twice_low, int twice_high) {
 // `grad_in`.
 // This is useful in the backward pass of grid_sampler.
 template <typename scalar_t>
-static __forceinline__ __device__
+__forceinline__ __device__
 scalar_t reflect_coordinates_set_grad(scalar_t in, int twice_low, int twice_high,
     scalar_t *grad_in) {
   if (twice_low == twice_high) {
@@ -135,7 +135,7 @@ scalar_t reflect_coordinates_set_grad(scalar_t in, int twice_low, int twice_high
 }
 template<typename scalar_t>
-static __forceinline__ __device__
+__forceinline__ __device__
 scalar_t safe_downgrade_to_int_range(scalar_t x){
   // -100.0 does not have special meaning. This is just to make sure
   // it's not within_bounds_2d or within_bounds_3d, and does not cause
@@ -146,7 +146,7 @@ scalar_t safe_downgrade_to_int_range(scalar_t x){
 }
 template<typename scalar_t>
-static __forceinline__ __device__
+__forceinline__ __device__
 scalar_t compute_coordinates(scalar_t coord, int size,
     GridSamplerPadding padding_mode,
     bool align_corners) {
@@ -170,7 +170,7 @@ scalar_t compute_coordinates(scalar_t coord, int size,
 // Computes the pixel source index value for a grid coordinate
 template <typename scalar_t>
-static __forceinline__ __device__
+__forceinline__ __device__
 scalar_t grid_sampler_compute_source_index(
     scalar_t coord,
     int size,
@@ -186,7 +186,7 @@ scalar_t grid_sampler_compute_source_index(
 // `d output / d input` via pointer argument `grad_in`.
 // This is useful in the backward pass of grid_sampler.
 template <typename scalar_t>
-static __forceinline__ __device__
+__forceinline__ __device__
 scalar_t grid_sampler_compute_source_index_set_grad(
     scalar_t coord,
     int size,
@@ -215,18 +215,18 @@ scalar_t grid_sampler_compute_source_index_set_grad(
   return coord;
 }
-static __forceinline__ __device__
+__forceinline__ __device__
 bool within_bounds_2d(int h, int w, int H, int W) {
   return h >= 0 && h < H && w >= 0 && w < W;
 }
-static __forceinline__ __device__
+__forceinline__ __device__
 bool within_bounds_3d(int d, int h, int w, int D, int H, int W) {
   return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
 }
 template<typename scalar_t>
-static __forceinline__ __device__
+__forceinline__ __device__
 scalar_t get_value_bounded(
     const scalar_t *data, scalar_t x, scalar_t y, int W, int H, int sW, int sH,
     GridSamplerPadding padding_mode,
@@ -245,7 +245,7 @@ scalar_t get_value_bounded(
 }
 template<typename scalar_t, typename index_t>
-static __forceinline__ __device__
+__forceinline__ __device__
 void safe_add_2d(scalar_t *data, int h, int w,
     int sH, int sW, int H, int W,
     scalar_t delta,
@@ -261,7 +261,7 @@ void safe_add_2d(scalar_t *data, int h, int w,
 }
 template<typename scalar_t, typename index_t>
-static __forceinline__ __device__
+__forceinline__ __device__
 void safe_add_3d(scalar_t *data, int d, int h, int w,
     int sD, int sH, int sW, int D, int H, int W,
     scalar_t delta,
@@ -277,7 +277,7 @@ void safe_add_3d(scalar_t *data, int d, int h, int w,
 }
 template<typename scalar_t, typename index_t>
-static __forceinline__ __device__
+__forceinline__ __device__
 void add_value_bounded(
     scalar_t* data, scalar_t x, scalar_t y, int W, int H, int sW, int sH,
     scalar_t delta,
@@ -297,7 +297,7 @@ void add_value_bounded(
 // Calculate the differential of the cubic convolution, i.e. `d coeff / d x`
 template<typename scalar_t>
-static __forceinline__ __device__
+__forceinline__ __device__
 void get_cubic_coefficients_grad(
     scalar_t coeffs[4],
     scalar_t t) {

View File

@@ -21,7 +21,7 @@ constexpr int MAX_BLOCK_SIZE = 1024;
 // Maximum size per grid dimension that we assume (compute capability >= 2.0)
 constexpr int64_t MAX_GRID_SIZE = 65535LL;
-static bool getGridFromTiles(int64_t gridTiles, dim3& grid) {
+inline bool getGridFromTiles(int64_t gridTiles, dim3& grid) {
   if (gridTiles > MAX_GRID_SIZE * MAX_GRID_SIZE * MAX_GRID_SIZE) {
     return false;
   }
@@ -92,7 +92,7 @@ struct GlobalIndexToPerSliceIndex {
 };
 // Returns 2^(ceil(lg(n)) from Stanford bit twiddling hacks
-static uint64_t nextHighestPowerOf2(uint64_t n) {
+inline uint64_t nextHighestPowerOf2(uint64_t n) {
   n--;
   n |= n >> 1;
   n |= n >> 2;

View File

@@ -72,7 +72,7 @@ __device__ inline scalar_t max(scalar_t a, scalar_t b) {
 // see NOTE [ Nearest neighbor upsampling kernel implementation ]
 template <typename accscalar_t>
-__host__ __forceinline__ static accscalar_t compute_scales_value(
+__host__ __forceinline__ accscalar_t compute_scales_value(
     const std::optional<double> scale,
     int64_t src_size,
     int64_t dst_size) {
@@ -83,7 +83,7 @@ __host__ __forceinline__ static accscalar_t compute_scales_value(
 // see NOTE [ Nearest neighbor upsampling kernel implementation ]
 template <typename accscalar_t>
-__host__ __forceinline__ static accscalar_t compute_scales_value_backwards(
+__host__ __forceinline__ accscalar_t compute_scales_value_backwards(
     const std::optional<double> scale,
     int64_t src_size,
     int64_t dst_size) {
@@ -93,7 +93,7 @@ __host__ __forceinline__ static accscalar_t compute_scales_value_backwards(
 }
 template <typename accscalar_t>
-__host__ __forceinline__ static accscalar_t area_pixel_compute_scale(
+__host__ __forceinline__ accscalar_t area_pixel_compute_scale(
     int input_size,
     int output_size,
     bool align_corners,
@@ -112,7 +112,7 @@ __host__ __forceinline__ static accscalar_t area_pixel_compute_scale(
 }
 template <typename accscalar_t>
-__device__ __forceinline__ static accscalar_t area_pixel_compute_source_index(
+__device__ __forceinline__ accscalar_t area_pixel_compute_source_index(
     accscalar_t scale,
     int dst_index,
     bool align_corners,
@@ -130,7 +130,7 @@ __device__ __forceinline__ static accscalar_t area_pixel_compute_source_index(
 }
 // see NOTE [ Nearest neighbor upsampling kernel implementation ]
-__device__ __forceinline__ static int nearest_neighbor_compute_source_index(
+__device__ __forceinline__ int nearest_neighbor_compute_source_index(
     const float scale,
     int dst_index,
     int input_size) {
@@ -144,7 +144,7 @@ __device__ __forceinline__ static int nearest_neighbor_compute_source_index(
   return src_index;
 }
-__device__ __forceinline__ static int nearest_neighbor_exact_compute_source_index(
+__device__ __forceinline__ int nearest_neighbor_exact_compute_source_index(
     const float scale,
     int dst_index,
     int input_size) {
@@ -157,7 +157,7 @@ __device__ __forceinline__ static int nearest_neighbor_exact_compute_source_inde
 }
 // see NOTE [ Nearest neighbor upsampling kernel implementation ]
-__device__ __forceinline__ static int nearest_neighbor_bw_compute_source_index(
+__device__ __forceinline__ int nearest_neighbor_bw_compute_source_index(
     const float scale,
     int dst_index,
     int output_size) {
@@ -170,7 +170,7 @@ __device__ __forceinline__ static int nearest_neighbor_bw_compute_source_index(
 }
 // see NOTE [ Nearest neighbor upsampling kernel implementation ]
-__device__ __forceinline__ static int nearest_neighbor_exact_bw_compute_source_index(
+__device__ __forceinline__ int nearest_neighbor_exact_bw_compute_source_index(
     const float scale,
     int dst_index,
     int output_size) {
@@ -182,7 +182,7 @@ __device__ __forceinline__ static int nearest_neighbor_exact_bw_compute_source_i
 /* Used by UpSampleBicubic2d.cu */
 template <typename scalar_t>
-__device__ __forceinline__ static scalar_t upsample_get_value_bounded(
+__device__ __forceinline__ scalar_t upsample_get_value_bounded(
     const PackedTensorAccessor64<const scalar_t, 4>& data,
     int batch,
     int channel,
@@ -197,7 +197,7 @@ __device__ __forceinline__ static scalar_t upsample_get_value_bounded(
 /* Used by UpSampleBicubic2d.cu */
 template <typename scalar_t, typename accscalar_t>
-__device__ __forceinline__ static void upsample_increment_value_bounded(
+__device__ __forceinline__ void upsample_increment_value_bounded(
     PackedTensorAccessor64<scalar_t, 4>& data,
     int batch,
     int channel,
@@ -218,21 +218,21 @@ __device__ __forceinline__ static void upsample_increment_value_bounded(
 // Based on
 // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
 template <typename accscalar_t>
-__device__ __forceinline__ static accscalar_t cubic_convolution1(
+__device__ __forceinline__ accscalar_t cubic_convolution1(
     accscalar_t x,
     accscalar_t A) {
   return ((A + 2) * x - (A + 3)) * x * x + 1;
 }
 template <typename accscalar_t>
-__device__ __forceinline__ static accscalar_t cubic_convolution2(
+__device__ __forceinline__ accscalar_t cubic_convolution2(
     accscalar_t x,
     accscalar_t A) {
   return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
 }
 template <typename accscalar_t>
-__device__ __forceinline__ static void get_cubic_upsampling_coefficients(
+__device__ __forceinline__ void get_cubic_upsampling_coefficients(
     accscalar_t coeffs[4],
     accscalar_t t) {
   accscalar_t A = -0.75;
@@ -248,7 +248,7 @@ __device__ __forceinline__ static void get_cubic_upsampling_coefficients(
 }
 template <typename scalar_t, typename accscalar_t>
-__device__ __forceinline__ static accscalar_t cubic_interp1d(
+__device__ __forceinline__ accscalar_t cubic_interp1d(
     scalar_t x0,
     scalar_t x1,
     scalar_t x2,
@@ -306,7 +306,7 @@ struct BicubicFilterFunctor {
 };
 template <typename accscalar_t>
-__device__ __forceinline__ static void _compute_weights_span(
+__device__ __forceinline__ void _compute_weights_span(
     const int i,
     const int input_size,
     const accscalar_t scale,
@@ -320,7 +320,7 @@ __device__ __forceinline__ static void _compute_weights_span(
 }
 template <typename scalar_t, typename accscalar_t, typename interp_filter_t>
-__device__ __forceinline__ static void _compute_weights(
+__device__ __forceinline__ void _compute_weights(
     scalar_t* wt_ptr,
     const accscalar_t scale,
     int interp_size,
@@ -347,7 +347,7 @@ __device__ __forceinline__ static void _compute_weights(
 }
 template <typename scalar_t, typename accscalar_t>
-__device__ __forceinline__ static accscalar_t interpolate_aa_single_dim(
+__device__ __forceinline__ accscalar_t interpolate_aa_single_dim(
     const scalar_t* src,
     const scalar_t* weights,
     int size) {

View File

@@ -14,7 +14,7 @@
 #ifdef __CUDACC_RTC__
 #define __MATH_FUNCTIONS_DECL__ C10_HOST_DEVICE
 #else /* __CUDACC_RTC__ */
-#define __MATH_FUNCTIONS_DECL__ static inline C10_HOST_DEVICE
+#define __MATH_FUNCTIONS_DECL__ inline C10_HOST_DEVICE
 #endif /* __CUDACC_RTC__ */
 #endif /* __HIPCC__ */