Fix broken indexing when using None and ellipses indexing together (#22905)
Summary:
https://github.com/pytorch/pytorch/issues/20153

I believe you need 2 passes for this. Take this example:

```python
@torch.jit.script
def f():
    x = torch.ones(10, 9, 8, 7, 6)
    return x[..., None, None].shape
```

which results in `[10, 9, 8, 7, 6, 1, 1]`, vs.

```python
@torch.jit.script
def f():
    x = torch.ones(10, 9, 8, 7, 6)
    return x[..., None, None, :].shape
```

which results in `[10, 9, 8, 7, 1, 1, 6]`.

After processing only `x[..., None, None`, we don't know whether we should be creating a new dimension at the end of the dimension list or somewhere in the middle; what we do depends on the elements to its right. Thus, I do 2 passes: one to collect the dimensions that the index operations operate on, and another that executes the index operations.

This still doesn't work for an ellipsis index followed by a tensor index, but it wasn't working previously either.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/22905

Differential Revision: D16433558

Pulled By: Chillee

fbshipit-source-id: c1b303cb97b1af8b6e405bad33495ef3b4c27c4a
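For reference, the two shapes quoted in the summary can be reproduced in eager mode (plain PyTorch, outside TorchScript); this snippet is illustrative and not part of the patch:

```python
import torch

x = torch.ones(10, 9, 8, 7, 6)

# No index after the ellipsis: the ellipsis absorbs all five dims,
# so both `None`s append new size-1 dims at the end.
print(x[..., None, None].shape)     # torch.Size([10, 9, 8, 7, 6, 1, 1])

# A trailing `:` claims the last dim, so the ellipsis now covers only
# the first four dims and the `None`s are inserted before it.
print(x[..., None, None, :].shape)  # torch.Size([10, 9, 8, 7, 1, 1, 6])
```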
This commit is contained in:
parent 648f10be16
commit a24f6c13a3
@@ -3761,6 +3761,25 @@ a")
         check_dynamic_indexing("[i + j]", consec((3, 3)), 0, 1)
         check_dynamic_indexing("[i:j, i]", consec((3, 3, 2)), 0, 2)
 
+    def test_index_ellipses(self):
+        vals = [":", 1, None]
+        for _ in range(100):
+            indices = [random.choice(vals) for _ in range(4)]
+            indices[random.randint(0, len(indices) - 1)] = "..."
+            test_str = dedent("""
+            def f():
+                x = torch.ones(10, 9, 8, 7, 6)
+                return x{indices}.shape
+            """.format(indices=indices))
+            test_str = test_str.replace(r"'", r'')
+            scope = {}
+            execWrapper(test_str, globals(), scope)
+            cu = torch.jit.CompilationUnit(test_str)
+            res1 = cu.f()
+            res2 = scope['f']()
+            self.assertEqual(res1, res2)
+
+
     def test_tensor_item(self):
         def test_scalar_cast(x):
             scalar = x.item()
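The test above exercises random combinations of `:`, `1`, `None`, and `...` against eager Python as the reference. What a generated program looks like is easiest to see standalone; a sketch of just the string construction (`execWrapper` in the real test wraps `exec` for Python 2/3 compatibility, and the exact program varies with the random draw):

```python
import random
from textwrap import dedent

vals = [":", 1, None]
indices = [random.choice(vals) for _ in range(4)]
indices[random.randint(0, len(indices) - 1)] = "..."

test_str = dedent("""
def f():
    x = torch.ones(10, 9, 8, 7, 6)
    return x{indices}.shape
""".format(indices=indices))

# str() renders the list as e.g. [':', 1, '...', None]; stripping the
# quotes turns it into real subscript syntax: x[:, 1, ..., None]
test_str = test_str.replace(r"'", r'')
print(test_str)
```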
@@ -2617,13 +2617,13 @@ struct to_ir {
         loc, *graph, aten::slice, c10::nullopt, args, {step_nv}, true);
   }
 
-  Value* emitUnsqueeze(const SourceRange& loc, Value* input, int64_t dim) {
+  Value* emitUnsqueeze(const SourceRange& loc, Value* input, Value* dim_val) {
     return emitBuiltinCall(
         loc,
         *graph,
         aten::unsqueeze,
         c10::nullopt,
-        {input, graph->insertConstant(dim, nullptr, loc)},
+        {input, dim_val},
         {},
         true);
   }
@@ -2653,71 +2653,119 @@ struct to_ir {
       const SourceRange& loc,
       Value* sliceable,
       const List<Expr>& subscript_exprs) {
+    // Overall, to handle indexing (other than Tensors), we need to handle a couple different things.
+    // For example, for x[1:3, None, 4], each of these different index types
+    // (slice, None, and integer) result in different number of dimensions.
+    // Slicing doesn't change the number of dimensions, None adds a dimension,
+    // and integer removes a dimension. As these indexing operations are applied
+    // left to right, the actual index that it's being applied to depends on the
+    // previous operations.
+    // Ellipses indexing throws another wrinkle. Ellipses selects any remaining
+    // unspecified dimensions. Thus, for indexes following an ellipses, the
+    // actual index an indexing operation is being applied to depends on the
+    // operations to the right.
+    // Thus, we do two passes, one from left to right up until the ellipses, and
+    // one from right to left.
     std::vector<Value*> tensor_indices;
-    size_t dim = 0;
-
-    auto handle_tensor = [&](Value* tensor) {
-      // NB: tensor_indices can have None holes because of how at::index works.
-      tensor_indices.resize(dim + 1);
-      tensor_indices[dim] = tensor;
-      dim++;
-    };
 
-    // before ellipsis, dimension index should be `dim`
-    // after ellipsis, dimension index should be `-offset`
-    int offset = 0;
-    size_t ellipsis_dim = 0;
     auto insert_value_for_dim = [&](int64_t dim) {
-      return (offset == 0)
-          ? graph->insertConstant(dim, nullptr, loc)
-          :
-          // NB: offset is incremented to move to the next dimension index
-          graph->insertConstant(offset++, nullptr, loc);
+      return graph->insertConstant(dim, nullptr, loc);
     };
+    std::vector<int64_t> dims(subscript_exprs.size());
+    std::vector<c10::optional<Value*>> exprs(
+        subscript_exprs.size(), c10::nullopt);
 
-    for (const auto& subscript_expr : subscript_exprs) {
-      // NB: ellipsis_dim is **always** incremented
-      // (comparing to dim) in order to compute
-      // the correct offsets for the remaining
-      // dimension indices following an ellipsis "..."
-      // token
-      ellipsis_dim++;
-      if (subscript_expr.kind() == TK_DOTS) {
-        offset = -(subscript_exprs.size() - ellipsis_dim);
-        ++dim;
-        continue;
-      }
+    auto handle_indexing = [&](const Expr& subscript_expr,
+                               int expr_idx,
+                               int64_t dim,
+                               bool is_reverse = false) {
+      dims[expr_idx] = dim;
       if (subscript_expr.kind() == TK_SLICE_EXPR) {
-        auto dim_val = insert_value_for_dim(dim);
-        sliceable =
-            emitSlice(loc, sliceable, dim_val, SliceExpr(subscript_expr));
-        ++dim;
-        continue;
+        if (is_reverse) {
+          return dim - 1;
+        } else {
+          return dim + 1;
+        }
       }
       TypePtr type_hint = OptionalType::ofTensor();
       if (subscript_expr.kind() == TK_NONE) {
         type_hint = NoneType::get();
       }
       auto index = emitExpr(subscript_expr, type_hint);
-      if (index->type() == IntType::get()) {
-        // NB: note, select squeezes out a dimension,
-        // so dim is **not** incremented
-        auto dim_val = insert_value_for_dim(dim);
-        sliceable = emitSelect(loc, sliceable, dim_val, index);
-        continue;
-      } else if (index->type()->isSubtypeOf(NoneType::get())) {
-        sliceable = emitUnsqueeze(loc, sliceable, dim);
-        dim++;
-        continue;
+      exprs[expr_idx] = index;
+      if (index->type()->isSubtypeOf(NoneType::get())) {
+        if (is_reverse) {
+          return dim;
+        } else {
+          return dim + 1;
+        }
+      } else if (index->type() == IntType::get()) {
+        if (is_reverse) {
+          return dim - 1;
+        } else {
+          return dim;
+        }
       } else if (index->type()->isSubtypeOf(OptionalType::ofTensor())) {
-        // NB: index type can either be a Tensor or : (None of Optional Tensor)
-        handle_tensor(index);
-        continue;
+        if (is_reverse) {
+          throw ErrorReport(loc)
+              << "Ellipses followed by tensor indexing is currently not supported";
+        } else {
+          return dim + 1;
+        }
+      } else {
+        throw ErrorReport(loc)
+            << "Unsupported operation: indexing tensor with unsupported index type '"
+            << index->type()->python_str()
+            << "'. Only ints, slices, and tensors are supported";
       }
-      throw ErrorReport(loc)
-          << "Unsupported operation: indexing tensor with unsupported index type '"
-          << index->type()->python_str()
-          << "'. Only ints, slices, and tensors are supported";
-    }
+    };
+
+    size_t idx = 0;
+    int64_t dim = 0;
+    for (; idx < subscript_exprs.size(); idx++) {
+      auto subscript_expr = subscript_exprs[idx];
+      if (subscript_expr.kind() == TK_DOTS) {
+        break;
+      }
+      dim = handle_indexing(subscript_expr, idx, dim, /*is_reverse=*/false);
+    }
+    int64_t rdim = -1;
+    for (size_t rev_idx = subscript_exprs.size() - 1; rev_idx > idx;
+         rev_idx--) {
+      auto subscript_expr = subscript_exprs[rev_idx];
+      if (subscript_expr.kind() == TK_DOTS) {
+        throw ErrorReport(loc)
+            << "An index can only have a single ellipsis ('...')";
+      }
+      rdim =
+          handle_indexing(subscript_expr, rev_idx, rdim, /*is_reverse=*/true);
+    }
+    for (size_t i = 0; i < exprs.size(); i++) {
+      if (!exprs[i].has_value()) {
+        if (subscript_exprs[i].kind() == TK_SLICE_EXPR) {
+          sliceable = emitSlice(
+              loc,
+              sliceable,
+              insert_value_for_dim(dims[i]),
+              SliceExpr(subscript_exprs[i]));
+        }
+        continue;
+      }
+      auto expr = exprs[i].value();
+      if (expr->type()->isSubtypeOf(NoneType::get())) {
+        sliceable =
+            emitUnsqueeze(loc, sliceable, insert_value_for_dim(dims[i]));
+      } else if (expr->type() == IntType::get()) {
+        sliceable =
+            emitSelect(loc, sliceable, insert_value_for_dim(dims[i]), expr);
+      } else if (expr->type()->isSubtypeOf(OptionalType::ofTensor())) {
+        tensor_indices.resize(dims[i] + 1);
+        tensor_indices[dims[i]] = expr;
+      } else {
+        TORCH_INTERNAL_ASSERT(
+            "Trying to process index type that we don't support.");
+      }
+    }
     // at::index takes in a List[Optional[Tensor]] where some dims can be None.
     // create None node with optional tensor output type and pass to at::index.
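To make the two-pass bookkeeping concrete, here is a small Python model of just the dimension assignment (function and variable names are illustrative, not from compiler.cpp): the forward pass mirrors `handle_indexing` with `is_reverse=false`, the reverse pass starts at `rdim = -1`, and the negative dims it produces are later handed to `aten::slice`/`select`/`unsqueeze`, which accept dims counted from the end.

```python
# Hypothetical model of the dimension assignment done by the two passes;
# each subscript is ':', an int, None, or '...'.
def assign_dims(subscripts):
    dims = [None] * len(subscripts)  # the '...' slot keeps None

    def advance(expr, dim, reverse):
        # Mirrors handle_indexing's return values: a slice consumes a
        # dim, None inserts one, an int selects (and squeezes) one.
        if expr == ":":
            return dim - 1 if reverse else dim + 1
        if expr is None:
            return dim if reverse else dim + 1
        return dim - 1 if reverse else dim  # int index

    # Pass 1: left to right, stop at the ellipsis.
    idx, dim = 0, 0
    while idx < len(subscripts) and subscripts[idx] != "...":
        dims[idx] = dim
        dim = advance(subscripts[idx], dim, reverse=False)
        idx += 1

    # Pass 2: right to left; dims after the ellipsis count from the end.
    rdim = -1
    for rev in range(len(subscripts) - 1, idx, -1):
        dims[rev] = rdim
        rdim = advance(subscripts[rev], rdim, reverse=True)
    return dims

print(assign_dims(["...", None, None]))       # [None, -1, -1]
print(assign_dims(["...", None, None, ":"]))  # [None, -2, -2, -1]
```

Unsqueezing at `-1` twice appends both new dimensions (the `[10, 9, 8, 7, 6, 1, 1]` case), while the trailing `:` in the second call shifts both `None`s to dim `-2`, inserting them before the last dimension (`[10, 9, 8, 7, 1, 1, 6]`).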