Replace std::function object with regular function.

The function is called recursively, and the std::function object had only
existed to allow recursion from within a lambda expression. A regular
function should be cheaper than a polymorphic function wrapper.

PiperOrigin-RevId: 158292415
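For context, a minimal sketch of the two recursion patterns the commit message contrasts; this is illustrative code, not part of the commit. A lambda cannot name itself, so making it recursive requires storing it in a std::function that the lambda captures by reference, and every call then goes through the type-erased wrapper. A regular function can simply call itself.

    #include <functional>

    // Recursion from inside a lambda: the lambda must capture the
    // std::function that holds it, and each call pays for type erasure.
    int FactorialViaLambda(int n) {
      std::function<int(int)> fact = [&fact](int k) {
        return k <= 1 ? 1 : k * fact(k - 1);
      };
      return fact(n);
    }

    // A regular function can recurse directly; no polymorphic wrapper.
    int Factorial(int n) { return n <= 1 ? 1 : n * Factorial(n - 1); }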
parent ba656b2611
commit 94085bee74
tensorflow/compiler/xla/service/hlo_instruction.cc

@@ -1306,7 +1306,7 @@ Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
 
 void HloInstruction::DetachFromOperands() {
   CHECK_EQ(0, user_count());
-  // An intruction may be repeated as an operand. To avoid calling RemoveUser
+  // An instruction may be repeated as an operand. To avoid calling RemoveUser
   // twice on the same operand, keep a set of already detached operands.
   std::set<HloInstruction*> detached_operands;
   for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
@@ -2162,6 +2162,70 @@ bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const {
   return true;
 }
 
+// A helper class for memoized, recursive computation of HloOpcode::kFusion
+// in HloInstruction::OperandElementUse below.
+class HloInstruction::FusionReusesParamElements {
+ public:
+  using UseKind = HloInstruction::UseKind;
+
+  // We could rather iterate backwards thru fused_instructions_ here, as it is
+  // in reverse postorder, and compute whether each fused instruction reuses the
+  // value of this parameter, which would save stack space but not allow us to
+  // finish early if we find a reuse.
+  static UseKind Compute(int64 i, const HloInstruction& hlo) {
+    tensorflow::gtl::FlatMap<const HloInstruction*, UseKind> memoization_cache;
+    return ComputeInternal(i, hlo, &memoization_cache);
+  }
+
+ private:
+  static UseKind ComputeInternal(
+      int64 i, const HloInstruction& hlo,
+      tensorflow::gtl::FlatMap<const HloInstruction*, UseKind>* cache) {
+    if (hlo.opcode_ == HloOpcode::kParameter && hlo.parameter_number_ == i) {
+      return UseKind::kUse;
+    }
+
+    auto p = cache->emplace(&hlo, UseKind{});
+    auto value_it = p.first;
+    const bool key_is_new = p.second;
+
+    if (key_is_new) {
+      for (int64 j = 0; j < hlo.operands_.size(); ++j) {
+        UseKind old_val = value_it->second;
+
+        // The next operation invalidates iterators.
+        UseKind new_val =
+            Plus(old_val, std::min(hlo.OperandElementUse(j),
+                                   ComputeInternal(i, *hlo.operand(j), cache)));
+
+        // Re-acquire the iterator. We could work harder to do this only if
+        // absolutely necessary, but this code is not hot enough to warrant
+        // that.
+        value_it = cache->find(&hlo);
+        value_it->second = new_val;
+      }
+    }
+    return value_it->second;
+  }
+
+  // Fold operation for UseKinds.
+  static UseKind Plus(UseKind a, UseKind b) {
+    if (a == UseKind::kNoUse) {
+      return b;
+    } else if (b == UseKind::kNoUse) {
+      return a;
+    } else if (a == UseKind::kReuse || b == UseKind::kReuse) {
+      return UseKind::kReuse;
+    } else if (a == UseKind::kUsePermutingElements ||
+               b == UseKind::kUsePermutingElements) {
+      return UseKind::kReuse;
+    } else {
+      CHECK(a == UseKind::kUse && b == UseKind::kUse);
+      return UseKind::kUse;
+    }
+  }
+};
+
 HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const {
   switch (opcode_) {
     case HloOpcode::kBitcast:
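The emplace-then-re-find dance in ComputeInternal above exists because the recursive call can add entries to the FlatMap, which, as an open-addressing table, invalidates outstanding iterators. The memoization itself is what keeps the traversal cheap: each fused instruction is computed once even when the fusion computation is a DAG with shared subexpressions. A standalone sketch of that pattern, with a hypothetical Node type and TotalCost helper, and std::unordered_map standing in for tensorflow::gtl::FlatMap:

    #include <unordered_map>
    #include <vector>

    // Hypothetical DAG node, standing in for HloInstruction in this sketch.
    struct Node {
      int cost = 1;
      std::vector<const Node*> operands;
    };

    // Memoized recursion over a DAG: each node is computed once, so total
    // work is linear in the number of edges instead of exponential in depth.
    int TotalCost(const Node& node,
                  std::unordered_map<const Node*, int>* cache) {
      auto it = cache->find(&node);
      if (it != cache->end()) return it->second;
      int total = node.cost;
      for (const Node* operand : node.operands) {
        // The recursive call may insert into the cache; with a flat
        // (open-addressing) map such as tensorflow::gtl::FlatMap, holding
        // an iterator across this call would be unsafe, hence the re-find
        // pattern in ComputeInternal above.
        total += TotalCost(*operand, cache);
      }
      (*cache)[&node] = total;
      return total;
    }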
@@ -2176,46 +2240,9 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const {
       // Pad reuses the padding value but not the padded array elements.
       // Reduce reuses the init value but not the operand array elements.
       return i > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements;
-    case HloOpcode::kFusion: {
-      tensorflow::gtl::FlatMap<const HloInstruction*, UseKind> cache;
-      // We could rather iterate backwards thru fused_instructions_ here, as it
-      // is in reverse postorder, and compute whether each fused instruction
-      // reuses the value of this parameter, which would save stack space but
-      // not allow us to finish early if we find a reuse.
-      std::function<UseKind(const HloInstruction&)> reuses_parameter_elements =
-          [i, &cache, &reuses_parameter_elements](const HloInstruction& hlo) {
-            auto plus = [](const UseKind& a, const UseKind& b) {
-              if (a == UseKind::kNoUse) {
-                return b;
-              } else if (b == UseKind::kNoUse) {
-                return a;
-              } else if (a == UseKind::kReuse || b == UseKind::kReuse) {
-                return UseKind::kReuse;
-              } else if (a == UseKind::kUsePermutingElements ||
-                         b == UseKind::kUsePermutingElements) {
-                return UseKind::kReuse;
-              }
-              CHECK(UseKind::kUse == a && UseKind::kUse == b);
-              return UseKind::kUse;
-            };
-
-            if (hlo.opcode_ == HloOpcode::kParameter &&
-                hlo.parameter_number_ == i) {
-              return UseKind::kUse;
-            }
-            if (!ContainsKey(cache, &hlo)) {
-              for (int64 j = 0; j < hlo.operands_.size(); ++j) {
-                UseKind old = cache[&hlo];
-                UseKind updated = plus(
-                    old, std::min(hlo.OperandElementUse(j),
-                                  reuses_parameter_elements(*hlo.operand(j))));
-                cache[&hlo] = updated;
-              }
-            }
-            return cache[&hlo];
-          };
-      return reuses_parameter_elements(*fused_expression_root());
-    }
+    case HloOpcode::kFusion:
+      // Uses the memoizing, recursive computation defined above.
+      return FusionReusesParamElements::Compute(i, *fused_expression_root());
     default:
       return IsElementwise() ? UseKind::kUse : UseKind::kReuse;
   }
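The Plus fold treats UseKind roughly as a join on a small lattice: kNoUse is the identity, kReuse absorbs everything, and, notably, two element-permuting uses escalate to kReuse, presumably because in aggregate the same elements end up being read more than once. A standalone illustration with a stand-in enum (HloInstruction::UseKind itself is private, so this mirrors rather than reuses it):

    #include <cassert>

    // Stand-in for HloInstruction::UseKind, for illustration only.
    enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };

    // Mirrors the fold in FusionReusesParamElements::Plus.
    UseKind Plus(UseKind a, UseKind b) {
      if (a == UseKind::kNoUse) return b;
      if (b == UseKind::kNoUse) return a;
      if (a == UseKind::kReuse || b == UseKind::kReuse) return UseKind::kReuse;
      if (a == UseKind::kUsePermutingElements ||
          b == UseKind::kUsePermutingElements) {
        return UseKind::kReuse;
      }
      return UseKind::kUse;
    }

    int main() {
      // Two plain uses stay a plain use; a permuting use combined with any
      // other use (other than kNoUse) escalates to reuse.
      assert(Plus(UseKind::kUse, UseKind::kUse) == UseKind::kUse);
      assert(Plus(UseKind::kUsePermutingElements, UseKind::kUse) ==
             UseKind::kReuse);
      assert(Plus(UseKind::kNoUse, UseKind::kReuse) == UseKind::kReuse);
    }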
tensorflow/compiler/xla/service/hlo_instruction.h

@@ -775,6 +775,9 @@ class HloInstruction {
  private:
   enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };
 
+  // Helper class for computing OperandElementUse for kFusion.
+  class FusionReusesParamElements;
+
   // Creates an n-ary elementwise operation.
   static std::unique_ptr<HloInstruction> CreateNary(
       const Shape& shape, HloOpcode opcode,