mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Refactor ElementalIrEmitter's slice index finding code into
IrArray::Index::SourceIndexOfSlice(). PiperOrigin-RevId: 161140653
This commit is contained in:
parent
ba297aec99
commit
f9c9cacb06
|
|
@ -948,23 +948,9 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
||||||
case HloOpcode::kSlice:
|
case HloOpcode::kSlice:
|
||||||
return [this, hlo, &operand_to_generator](
|
return [this, hlo, &operand_to_generator](
|
||||||
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
|
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
|
||||||
IrArray::Index sliced_index(index.size());
|
IrArray::Index sliced_index = index.SourceIndexOfSlice(
|
||||||
for (int i = 0; i < index.size(); ++i) {
|
/*shape=*/hlo->shape(), /*starts=*/hlo->slice_starts(),
|
||||||
int64 stride = hlo->slice_stride(i);
|
/*strides=*/hlo->slice_strides(), /*builder=*/ir_builder_);
|
||||||
if (stride != 1) {
|
|
||||||
sliced_index[i] = ir_builder_->CreateAdd(
|
|
||||||
ir_builder_->CreateMul(
|
|
||||||
index[i], llvm::ConstantInt::get(index[i]->getType(),
|
|
||||||
stride)),
|
|
||||||
llvm::ConstantInt::get(index[i]->getType(),
|
|
||||||
hlo->slice_starts(i)));
|
|
||||||
} else {
|
|
||||||
sliced_index[i] = ir_builder_->CreateAdd(
|
|
||||||
index[i],
|
|
||||||
llvm::ConstantInt::get(index[i]->getType(),
|
|
||||||
hlo->slice_starts(i)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return operand_to_generator.at(hlo->operand(0))(sliced_index);
|
return operand_to_generator.at(hlo->operand(0))(sliced_index);
|
||||||
};
|
};
|
||||||
case HloOpcode::kDynamicSlice:
|
case HloOpcode::kDynamicSlice:
|
||||||
|
|
|
||||||
|
|
@ -153,6 +153,28 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape(
|
||||||
return Index(source_multidim_index);
|
return Index(source_multidim_index);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
IrArray::Index IrArray::Index::SourceIndexOfSlice(
|
||||||
|
const Shape& shape, tensorflow::gtl::ArraySlice<int64> starts,
|
||||||
|
tensorflow::gtl::ArraySlice<int64> strides,
|
||||||
|
llvm::IRBuilder<>* builder) const {
|
||||||
|
Index source_index(multidim_.size());
|
||||||
|
for (int i = 0; i < multidim_.size(); ++i) {
|
||||||
|
int64 stride = strides[i];
|
||||||
|
auto type = multidim_[i]->getType();
|
||||||
|
|
||||||
|
if (stride != 1) {
|
||||||
|
source_index[i] = builder->CreateAdd(
|
||||||
|
builder->CreateMul(multidim_[i],
|
||||||
|
llvm::ConstantInt::get(type, stride)),
|
||||||
|
llvm::ConstantInt::get(type, starts[i]));
|
||||||
|
} else {
|
||||||
|
source_index[i] = builder->CreateAdd(
|
||||||
|
multidim_[i], llvm::ConstantInt::get(type, starts[i]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return source_index;
|
||||||
|
}
|
||||||
|
|
||||||
IrArray::Index IrArray::Index::SourceIndexOfTranspose(
|
IrArray::Index IrArray::Index::SourceIndexOfTranspose(
|
||||||
const Shape& shape, const Shape& operand_shape,
|
const Shape& shape, const Shape& operand_shape,
|
||||||
tensorflow::gtl::ArraySlice<int64> dimension_mapping,
|
tensorflow::gtl::ArraySlice<int64> dimension_mapping,
|
||||||
|
|
|
||||||
|
|
@ -115,6 +115,16 @@ class IrArray {
|
||||||
Index SourceIndexOfReshape(const Shape& shape, const Shape& operand_shape,
|
Index SourceIndexOfReshape(const Shape& shape, const Shape& operand_shape,
|
||||||
llvm::IRBuilder<>* builder) const;
|
llvm::IRBuilder<>* builder) const;
|
||||||
|
|
||||||
|
// Returns the index into the source operand from which a slice operation
|
||||||
|
// selects a value to be placed into index "this". The slice is described
|
||||||
|
// by starting indices `starts` and stride values `strides`.
|
||||||
|
//
|
||||||
|
// Precondition: "this" is an index into a slice whose shape is `shape`.
|
||||||
|
Index SourceIndexOfSlice(const Shape& shape,
|
||||||
|
tensorflow::gtl::ArraySlice<int64> starts,
|
||||||
|
tensorflow::gtl::ArraySlice<int64> strides,
|
||||||
|
llvm::IRBuilder<>* builder) const;
|
||||||
|
|
||||||
// Given that "this" is the target index of a transpose from `operand_shape`
|
// Given that "this" is the target index of a transpose from `operand_shape`
|
||||||
// to `shape` with the given dimension mapping, returns the source index.
|
// to `shape` with the given dimension mapping, returns the source index.
|
||||||
Index SourceIndexOfTranspose(
|
Index SourceIndexOfTranspose(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user