pytorch/c10/core/SymIntArrayRef.cpp
Michael Suo 49979c4021 [symint] Make TensorImpl::sizes_and_strides_ contain SymInt
Change our representation of sizes and strides to contain SymInts
instead of int64_t.

Right now it's not actually possible to create a Tensor with symbolic
shape, so this change is intended to be a no-op.

But the intended behavior is:
- If you create a Tensor with symbolic shape, a `CustomSizes` policy
and the `has_symbolic_sizes_strides_` bit will both be set (not
currently implemented).
- Calling any TensorImpl function that naively interacts with sizes and
strides will throw. For hot-path functions (`sizes()`, `strides()`), we
make use of the existing policy check to throw. For others, we just have
a regular `TORCH_CHECK(!has_symbolic_sizes_strides_)`.
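
The guard described above can be sketched roughly as follows. This is only a
toy model, not the actual TensorImpl code: `ToySymInt`, `ToyTensorImpl`, and
its `numel()` are hypothetical names made up for illustration, and a plain
`throw` stands in for `TORCH_CHECK`.

```cpp
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical stand-in for c10::SymInt: a plain value plus a flag saying
// whether the value is symbolic.
struct ToySymInt {
  std::int64_t value;
  bool symbolic;
  bool is_symbolic() const { return symbolic; }
};

// Hypothetical stand-in for TensorImpl; it only models the guard bit.
class ToyTensorImpl {
 public:
  explicit ToyTensorImpl(std::vector<ToySymInt> sizes)
      : sizes_(std::move(sizes)) {
    // Set the bit if any size is symbolic.
    for (const ToySymInt& s : sizes_) {
      if (s.is_symbolic()) {
        has_symbolic_sizes_strides_ = true;
      }
    }
  }

  // Stands in for a non-hot-path function that reads sizes as plain integers.
  // The throw plays the role of TORCH_CHECK(!has_symbolic_sizes_strides_).
  std::int64_t numel() const {
    if (has_symbolic_sizes_strides_) {
      throw std::runtime_error("sizes/strides are symbolic");
    }
    std::int64_t n = 1;
    for (const ToySymInt& s : sizes_) {
      n *= s.value;
    }
    return n;
  }

 private:
  std::vector<ToySymInt> sizes_;
  bool has_symbolic_sizes_strides_ = false;
};
```

With concrete sizes `{2, 3}` the toy `numel()` returns 6; if any size is marked
symbolic, the same call throws instead of silently reading a meaningless value.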

This also undoes the explicit constructor I made in
https://github.com/pytorch/pytorch/pull/77666; it ended up being more
annoying than useful when making these changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78272

Approved by: https://github.com/Krovatkin, https://github.com/Chillee
2022-05-25 20:54:51 +00:00

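#include <c10/core/SymIntArrayRef.h>
#include <iostream>

namespace c10 {

// Checked conversion: verifies that no element is symbolic, then
// reinterprets the array as plain integers. Costs O(n) for the check.
at::IntArrayRef asIntArrayRefSlow(c10::SymIntArrayRef ar) {
  for (c10::SymInt sci : ar) {
    TORCH_CHECK(!sci.is_symbolic());
  }
  return asIntArrayRefUnchecked(ar);
}

// Unchecked conversion: reinterprets the SymInt storage as int64_t in place.
// The caller must guarantee that no element is symbolic.
at::IntArrayRef asIntArrayRefUnchecked(c10::SymIntArrayRef ar) {
  return IntArrayRef(reinterpret_cast<const int64_t*>(ar.data()), ar.size());
}

// Prints the raw payload of a single SymInt.
std::ostream& operator<<(std::ostream& os, SymInt s) {
  os << "SymInt(" << s.data() << ")";
  return os;
}

// Prints a SymIntArrayRef by delegating to the wrapped ArrayRef<SymInt>.
std::ostream& operator<<(std::ostream& out, const c10::SymIntArrayRef& list) {
  return out << list.wrapped_symint_array_ref;
}

} // namespace c10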