mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
convert guard_size_oblivious to runtime check in infer_size_impl (#148872)
its ok to check the requirement numel == newsize at runtime in case of unbacked instead of at compile time and assume that its true. Pull Request resolved: https://github.com/pytorch/pytorch/pull/148872 Approved by: https://github.com/bobrenjc93
This commit is contained in:
parent
0cf61ca7e4
commit
8b507a9809
|
|
@ -25,6 +25,8 @@ inline void infer_size_impl(
|
|||
// N.B. this is an index, not a sym dim!
|
||||
std::optional<int64_t> infer_dim;
|
||||
for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
|
||||
// We can avoid failing on unbacked shape[dim] and assert that it is >=0
|
||||
// following python behaviour.
|
||||
if (shape[dim] == -1) {
|
||||
if (infer_dim) {
|
||||
throw std::runtime_error("only one dimension can be inferred");
|
||||
|
|
@ -37,31 +39,39 @@ inline void infer_size_impl(
|
|||
}
|
||||
}
|
||||
|
||||
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, newsize)) ||
|
||||
(infer_dim && newsize > 0 && numel % newsize == 0)) {
|
||||
if (infer_dim) {
|
||||
// We have a degree of freedom here to select the dimension size; follow
|
||||
// NumPy semantics and just bail. However, a nice error message is needed
|
||||
// because users often use `view` as a way to flatten & unflatten
|
||||
// dimensions and will otherwise be confused why
|
||||
// empty_tensor.view( 0, 0)
|
||||
// works yet
|
||||
// empty_tensor.view(-1, 0)
|
||||
// doesn't.
|
||||
TORCH_CHECK(
|
||||
newsize != 0,
|
||||
"cannot reshape tensor of 0 elements into shape ",
|
||||
shape,
|
||||
" because the unspecified dimension size -1 can be any "
|
||||
"value and is ambiguous");
|
||||
res[*infer_dim] = numel / newsize;
|
||||
}
|
||||
auto set_infer_dim = [&]() {
|
||||
// We have a degree of freedom here to select the dimension size; follow
|
||||
// NumPy semantics and just bail. However, a nice error message is needed
|
||||
// because users often use `view` as a way to flatten & unflatten
|
||||
// dimensions and will otherwise be confused why
|
||||
// empty_tensor.view( 0, 0)
|
||||
// works yet
|
||||
// empty_tensor.view(-1, 0)
|
||||
// doesn't.
|
||||
TORCH_CHECK(
|
||||
newsize != 0,
|
||||
"cannot reshape tensor of 0 elements into shape ",
|
||||
shape,
|
||||
" because the unspecified dimension size -1 can be any "
|
||||
"value and is ambiguous");
|
||||
res[*infer_dim] = numel / newsize;
|
||||
return;
|
||||
};
|
||||
|
||||
if (infer_dim && newsize > 0 && numel % newsize == 0) {
|
||||
set_infer_dim();
|
||||
return;
|
||||
}
|
||||
|
||||
std::ostringstream ss;
|
||||
ss << "shape '" << shape << "' is invalid for input of size " << numel;
|
||||
throw std::runtime_error(ss.str());
|
||||
TORCH_MAYBE_SYM_CHECK(
|
||||
sym_eq(numel, newsize),
|
||||
"shape '",
|
||||
shape,
|
||||
"' is invalid for input of size ",
|
||||
numel);
|
||||
if (infer_dim) {
|
||||
set_infer_dim();
|
||||
}
|
||||
}
|
||||
|
||||
inline std::vector<int64_t> infer_size(IntArrayRef shape, int64_t numel) {
|
||||
|
|
|
|||
|
|
@ -4198,7 +4198,7 @@ Tensor ravel(const Tensor& self) {
|
|||
}
|
||||
|
||||
static inline void handle_unflatten_exception(
|
||||
const std::runtime_error& e,
|
||||
const std::exception& e,
|
||||
const Tensor& self,
|
||||
int64_t dim,
|
||||
SymIntArrayRef sizes) {
|
||||
|
|
@ -4251,7 +4251,7 @@ static Tensor unflatten_impl(
|
|||
SymDimVector inferred_size;
|
||||
try {
|
||||
inferred_size = at::infer_size_dv(sizes, self.sym_size(dim));
|
||||
} catch (const std::runtime_error& e) {
|
||||
} catch (const std::exception& e) {
|
||||
// at::infer_size would throw std::runtime_error for invalid size,
|
||||
// catch the runtime_error and display the error message in a more
|
||||
// user-friendly way for both tensors and named tensors
|
||||
|
|
|
|||
|
|
@ -7072,7 +7072,7 @@ class TestTorch(TestCase):
|
|||
torch.tensor([1]).unflatten(0, [])
|
||||
with self.assertRaisesRegex(RuntimeError, r"Provided sizes \[2, 2\] don't multiply up to the size of dim 0 \(1\)"):
|
||||
torch.tensor([1]).unflatten(0, [2, 2])
|
||||
with self.assertRaisesRegex(IndexError, r"Dimension specified as 0 but tensor has no dimensions"):
|
||||
with self.assertRaisesRegex(RuntimeError, r".*Dimension specified as 0 but tensor has no dimensions.*"):
|
||||
torch.tensor(1).unflatten(0, [0])
|
||||
with self.assertRaisesRegex(RuntimeError, r"only one dimension can be inferred"):
|
||||
torch.randn(5, 10).unflatten(1, (-1, -1))
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user