Replace std::function object with regular function.

The function is called recursively; the std::function object existed only to allow recursion from within a lambda expression. A regular function should be cheaper than a polymorphic function wrapper.
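For illustration only, a minimal sketch with invented names (SumTo is a hypothetical stand-in, not code from this commit): a lambda cannot name itself, so recursion from inside a lambda is typically routed through a std::function, which type-erases the callable behind an indirect call, whereas a regular function recurses directly.

#include <cstdint>
#include <functional>

// Before: the lambda captures the std::function so it can call itself.
int64_t SumToViaStdFunction(int64_t n) {
  std::function<int64_t(int64_t)> go = [&go](int64_t k) -> int64_t {
    return k == 0 ? 0 : k + go(k - 1);
  };
  return go(n);
}

// After: a regular function simply names itself; no polymorphic wrapper is
// constructed and the compiler can emit (or inline) a direct call.
int64_t SumTo(int64_t k) { return k == 0 ? 0 : k + SumTo(k - 1); }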

PiperOrigin-RevId: 158292415
commit 94085bee74
parent ba656b2611
Authored by A. Unique TensorFlower on 2017-06-07 11:20:44 -07:00; committed by TensorFlower Gardener
2 changed files with 71 additions and 41 deletions

tensorflow/compiler/xla/service/hlo_instruction.cc

@@ -1306,7 +1306,7 @@ Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
 void HloInstruction::DetachFromOperands() {
   CHECK_EQ(0, user_count());
-  // An intruction may be repeated as an operand. To avoid calling RemoveUser
+  // An instruction may be repeated as an operand. To avoid calling RemoveUser
   // twice on the same operand, keep a set of already detached operands.
   std::set<HloInstruction*> detached_operands;
   for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
@@ -2162,6 +2162,70 @@ bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const {
   return true;
 }
 
+// A helper class for memoized, recursive computation of HloOpcode::kFusion
+// in HloInstruction::OperandElementUse below.
+class HloInstruction::FusionReusesParamElements {
+ public:
+  using UseKind = HloInstruction::UseKind;
+
+  // We could rather iterate backwards thru fused_instructions_ here, as it is
+  // in reverse postorder, and compute whether each fused instruction reuses the
+  // value of this parameter, which would save stack space but not allow us to
+  // finish early if we find a reuse.
+  static UseKind Compute(int64 i, const HloInstruction& hlo) {
+    tensorflow::gtl::FlatMap<const HloInstruction*, UseKind> memoization_cache;
+    return ComputeInternal(i, hlo, &memoization_cache);
+  }
+
+ private:
+  static UseKind ComputeInternal(
+      int64 i, const HloInstruction& hlo,
+      tensorflow::gtl::FlatMap<const HloInstruction*, UseKind>* cache) {
+    if (hlo.opcode_ == HloOpcode::kParameter && hlo.parameter_number_ == i) {
+      return UseKind::kUse;
+    }
+
+    auto p = cache->emplace(&hlo, UseKind{});
+    auto value_it = p.first;
+    const bool key_is_new = p.second;
+
+    if (key_is_new) {
+      for (int64 j = 0; j < hlo.operands_.size(); ++j) {
+        UseKind old_val = value_it->second;
+
+        // The next operation invalidates iterators.
+        UseKind new_val =
+            Plus(old_val, std::min(hlo.OperandElementUse(j),
+                                   ComputeInternal(i, *hlo.operand(j), cache)));
+
+        // Re-acquire the iterator. We could work harder to do this only if
+        // absolutely necessary, but this code is not hot enough to warrant
+        // that.
+        value_it = cache->find(&hlo);
+        value_it->second = new_val;
+      }
+    }
+    return value_it->second;
+  }
+
+  // Fold operation for UseKinds.
+  static UseKind Plus(UseKind a, UseKind b) {
+    if (a == UseKind::kNoUse) {
+      return b;
+    } else if (b == UseKind::kNoUse) {
+      return a;
+    } else if (a == UseKind::kReuse || b == UseKind::kReuse) {
+      return UseKind::kReuse;
+    } else if (a == UseKind::kUsePermutingElements ||
+               b == UseKind::kUsePermutingElements) {
+      return UseKind::kReuse;
+    } else {
+      CHECK(a == UseKind::kUse && b == UseKind::kUse);
+      return UseKind::kUse;
+    }
+  }
+};
+
 HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const {
   switch (opcode_) {
     case HloOpcode::kBitcast:
@@ -2176,46 +2240,9 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const {
       // Pad reuses the padding value but not the padded array elements.
       // Reduce reuses the init value but not the operand array elements.
       return i > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements;
-    case HloOpcode::kFusion: {
-      tensorflow::gtl::FlatMap<const HloInstruction*, UseKind> cache;
-      // We could rather iterate backwards thru fused_instructions_ here, as it
-      // is in reverse postorder, and compute whether each fused instruction
-      // reuses the value of this parameter, which would save stack space but
-      // not allow us to finish early if we find a reuse.
-      std::function<UseKind(const HloInstruction&)> reuses_parameter_elements =
-          [i, &cache, &reuses_parameter_elements](const HloInstruction& hlo) {
-            auto plus = [](const UseKind& a, const UseKind& b) {
-              if (a == UseKind::kNoUse) {
-                return b;
-              } else if (b == UseKind::kNoUse) {
-                return a;
-              } else if (a == UseKind::kReuse || b == UseKind::kReuse) {
-                return UseKind::kReuse;
-              } else if (a == UseKind::kUsePermutingElements ||
-                         b == UseKind::kUsePermutingElements) {
-                return UseKind::kReuse;
-              }
-              CHECK(UseKind::kUse == a && UseKind::kUse == b);
-              return UseKind::kUse;
-            };
-            if (hlo.opcode_ == HloOpcode::kParameter &&
-                hlo.parameter_number_ == i) {
-              return UseKind::kUse;
-            }
-            if (!ContainsKey(cache, &hlo)) {
-              for (int64 j = 0; j < hlo.operands_.size(); ++j) {
-                UseKind old = cache[&hlo];
-                UseKind updated = plus(
-                    old, std::min(hlo.OperandElementUse(j),
-                                  reuses_parameter_elements(*hlo.operand(j))));
-                cache[&hlo] = updated;
-              }
-            }
-            return cache[&hlo];
-          };
-      return reuses_parameter_elements(*fused_expression_root());
-    }
+    case HloOpcode::kFusion:
+      // Uses the memoizing, recursive computation defined above.
+      return FusionReusesParamElements::Compute(i, *fused_expression_root());
     default:
       return IsElementwise() ? UseKind::kUse : UseKind::kReuse;
   }
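The subtle point in ComputeInternal above is iterator invalidation: the recursive call can insert into the cache while an iterator into it is live. A standalone sketch of the same memoization pattern, assuming std::unordered_map as a stand-in for tensorflow::gtl::FlatMap and a hypothetical Node type:

#include <unordered_map>
#include <vector>

// Hypothetical DAG node standing in for HloInstruction.
struct Node {
  int cost = 0;
  std::vector<const Node*> operands;
};

// Memoized cost over the DAG rooted at `n`; each node's result is computed
// once and re-read from the cache on later visits.
int TotalCost(const Node& n, std::unordered_map<const Node*, int>* cache) {
  // emplace() returns {iterator, inserted}; if the key already existed we
  // return the memoized value without recursing.
  auto [it, inserted] = cache->emplace(&n, n.cost);
  if (inserted) {
    for (const Node* operand : n.operands) {
      const int old_val = it->second;  // Read *before* recursing.
      const int new_val = old_val + TotalCost(*operand, cache);
      // The recursion may have grown the map and invalidated `it`;
      // re-acquire it before writing back, exactly as the diff does.
      it = cache->find(&n);
      it->second = new_val;
    }
  }
  return it->second;
}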

tensorflow/compiler/xla/service/hlo_instruction.h

@@ -775,6 +775,9 @@ class HloInstruction {
  private:
   enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };
 
+  // Helper class for computing OperandElementUse for kFusion.
+  class FusionReusesParamElements;
+
   // Creates an n-ary elementwise operation.
   static std::unique_ptr<HloInstruction> CreateNary(
       const Shape& shape, HloOpcode opcode,
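For reference, Plus behaves like a small join on UseKind: kNoUse is the identity, kReuse is absorbing, a permuting use combined with any other use escalates to kReuse, and two plain uses stay kUse. A standalone restatement with a usage check (assumption: bare enum and free function rather than the private members in the diff):

#include <cassert>

enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };

// Same case analysis as FusionReusesParamElements::Plus in the diff above.
UseKind Plus(UseKind a, UseKind b) {
  if (a == UseKind::kNoUse) return b;
  if (b == UseKind::kNoUse) return a;
  if (a == UseKind::kReuse || b == UseKind::kReuse) return UseKind::kReuse;
  if (a == UseKind::kUsePermutingElements ||
      b == UseKind::kUsePermutingElements) {
    return UseKind::kReuse;
  }
  return UseKind::kUse;  // Both are kUse: a repeated plain use stays kUse.
}

int main() {
  assert(Plus(UseKind::kNoUse, UseKind::kUse) == UseKind::kUse);
  assert(Plus(UseKind::kUse, UseKind::kUse) == UseKind::kUse);
  assert(Plus(UseKind::kUsePermutingElements, UseKind::kUse) ==
         UseKind::kReuse);
  return 0;
}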