mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This improves `c10::maybe_wrap_dim` to short-cut the "happy path" where dim is in the correct range, and also moves the error and scalar edge-cases out-of-line. These changes cut callgrind instruction counts for `size(i)` from 5200 to 2000. In the `size` and `stride` methods themselves, I also avoid calling `TensorImpl::dim()` since it may be a virtual call. This further reduced the instruction count from 2000 to 1500. For comparison, `tensor.sizes()[0]` takes 1200 instructions so `tensor.size(0)` is still marginally slower. This is unavoidable though since it has to handle dimension wrapping. Pull Request resolved: https://github.com/pytorch/pytorch/pull/75416 Approved by: https://github.com/Lezcano, https://github.com/ngimel
26 lines
673 B
C++
26 lines
673 B
C++
#pragma once
|
|
|
|
#include <c10/util/Exception.h>
|
|
|
|
namespace c10 {
|
|
|
|
namespace detail {
|
|
C10_API int64_t
|
|
maybe_wrap_dim_slow(int64_t dim, int64_t dim_post_expr, bool wrap_scalar);
|
|
}
|
|
|
|
static inline int64_t maybe_wrap_dim(
|
|
int64_t dim,
|
|
int64_t dim_post_expr,
|
|
bool wrap_scalar = true) {
|
|
// Inline the fast paths
|
|
if (C10_LIKELY(-dim_post_expr <= dim && dim < dim_post_expr)) {
|
|
// Branch-less version of dim + (dim < 0 ? dim_post_expr : 0)
|
|
return dim + dim_post_expr * (dim < 0);
|
|
}
|
|
// Check edge-cases out-of-line (wrapping scalars and out-of-bounds errors)
|
|
return c10::detail::maybe_wrap_dim_slow(dim, dim_post_expr, wrap_scalar);
|
|
}
|
|
|
|
} // namespace c10
|