Mirror of https://github.com/zebrajr/pytorch.git (last synced 2025-12-07 00:21:07 +01:00).
Change our representation of sizes and strides to contain SymInts instead of int64_t. Right now it's not actually possible to create a Tensor with symbolic shape, so this change is intended to be a no-op. But the intended behavior is: - If you create a Tensor with symbolic shape, a `CustomSizes` policy will be set, and the `has_symbolic_sizes_strides_` bit will be set. (not currently implemented) - Calling any TensorImpl function that naively interacts with sizes and strides will throw. For hot-path functions (`sizes()`, `strides()`), we make use of the existing policy check to throw. For others, we just have a regular `TORCH_CHECK(!has_symbolic_sizes_strides_)`. This also undoes the explicit constructor I made in https://github.com/pytorch/pytorch/pull/77666; it ended up being more annoying than useful when making these changes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/78272 Approved by: https://github.com/Krovatkin, https://github.com/Chillee
27 lines
650 B
C++
27 lines
650 B
C++
#include <c10/core/SymIntArrayRef.h>
|
|
#include <iostream>
|
|
|
|
namespace c10 {
|
|
|
|
at::IntArrayRef asIntArrayRefSlow(c10::SymIntArrayRef ar) {
|
|
for (c10::SymInt sci : ar) {
|
|
TORCH_CHECK(!sci.is_symbolic());
|
|
}
|
|
return asIntArrayRefUnchecked(ar);
|
|
}
|
|
|
|
at::IntArrayRef asIntArrayRefUnchecked(c10::SymIntArrayRef ar) {
|
|
return IntArrayRef(reinterpret_cast<const int64_t*>(ar.data()), ar.size());
|
|
}
|
|
|
|
std::ostream& operator<<(std::ostream& os, SymInt s) {
|
|
os << "SymInt(" << s.data() << ")";
|
|
return os;
|
|
}
|
|
|
|
std::ostream& operator<<(std::ostream& out, const c10::SymIntArrayRef& list) {
|
|
return out << list.wrapped_symint_array_ref;
|
|
}
|
|
|
|
} // namespace c10
|