pytorch/c10/core/SymIntArrayRef.cpp
Nikolay Korovaiko d2c47d559c Revert "Revert "Enabling SymInt in autograd; take 3 (#81145)"" ; make sure is_intlist checks for symintnodes (#82189)
### Description
<!-- What did you change and why was it needed? -->

### Issue
<!-- Link to Issue ticket or RFP -->

### Testing
<!-- How did you test your change? -->

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82189
Approved by: https://github.com/ezyang
2022-07-26 20:47:11 +00:00

35 lines
815 B
C++

#include <c10/core/SymIntArrayRef.h>
#include <c10/util/Optional.h>
#include <iostream>
namespace c10 {

// Converts `ar` to an IntArrayRef, failing via TORCH_CHECK if any element is
// symbolic. "Slow" because every element is inspected (see
// asIntArrayRefSlowOpt) before the conversion happens.
at::IntArrayRef asIntArrayRefSlow(c10::SymIntArrayRef ar) {
  auto r = asIntArrayRefSlowOpt(ar);
  TORCH_CHECK(
      r.has_value(),
      "SymIntArrayRef expected to contain only concrete integers");
  return *r;
}

// Returns an IntArrayRef view over `ar` when every element is a concrete
// (non-symbolic) integer, or c10::nullopt as soon as a symbolic element is
// found. The returned view aliases the storage of `ar`; nothing is copied, so
// it must not outlive the data `ar` refers to.
c10::optional<at::IntArrayRef> asIntArrayRefSlowOpt(c10::SymIntArrayRef ar) {
  // Bind by const reference: there is no need to copy each SymInt just to
  // query it.
  for (const c10::SymInt& sci : ar) {
    if (sci.is_symbolic()) {
      return c10::nullopt;
    }
  }
  return {asIntArrayRefUnchecked(ar)};
}

// Reinterprets the SymInt storage of `ar` as int64_t storage without checking
// whether any element is symbolic. Callers must guarantee every element is
// concrete; otherwise the resulting view exposes whatever int64_t bit pattern
// the symbolic SymInt stores rather than a meaningful integer.
at::IntArrayRef asIntArrayRefUnchecked(c10::SymIntArrayRef ar) {
  // The cast below treats an array of SymInt as an array of int64_t, which is
  // only valid if the two types have the same size. Enforce that at compile
  // time so a future change to SymInt's layout cannot silently break this.
  static_assert(
      sizeof(c10::SymInt) == sizeof(int64_t),
      "asIntArrayRefUnchecked requires SymInt to have the same size as int64_t");
  return IntArrayRef(reinterpret_cast<const int64_t*>(ar.data()), ar.size());
}

// Streams `s` as "SymInt(<data>)", where <data> is the value returned by
// s.data().
// NOTE(review): for a symbolic SymInt this prints the raw stored payload, not
// a human-readable description of the symbolic expression — confirm whether
// that is intended.
std::ostream& operator<<(std::ostream& os, SymInt s) {
  os << "SymInt(" << s.data() << ')';
  return os;
}

} // namespace c10