diff --git a/AK/HashMap.h b/AK/HashMap.h index 731aa34680..bd2ccdd633 100644 --- a/AK/HashMap.h +++ b/AK/HashMap.h @@ -276,14 +276,9 @@ public: } template - V& ensure(K const& key, Callback initialization_callback) + V& ensure(K const& key, Callback initialization_callback, HashSetExistingEntryBehavior existing_entry_behavior = HashSetExistingEntryBehavior::Keep) { - auto it = find(key); - if (it != end()) - return it->value; - auto result = set(key, initialization_callback()); - VERIFY(result == HashSetResult::InsertedNewEntry); - return find(key)->value; + return m_table.ensure(KeyTraits::hash(key), [&](auto& entry) { return KeyTraits::equals(entry.key, key); }, [&] -> Entry { return { key, initialization_callback() }; }, existing_entry_behavior).value; } template diff --git a/AK/HashTable.h b/AK/HashTable.h index 8bb5cacabe..44aa24f756 100644 --- a/AK/HashTable.h +++ b/AK/HashTable.h @@ -392,6 +392,34 @@ public: return MUST(try_set(forward(value), existing_entry_behavior)); } + template + T& ensure(U&& value, HashSetExistingEntryBehavior existing_entry_behavior = HashSetExistingEntryBehavior::Replace) + { + return MUST(try_set(forward(value), existing_entry_behavior)); + } + + template + [[nodiscard]] T& ensure(unsigned hash, TUnaryPredicate predicate, InitializationCallback initialization_callback, HashSetExistingEntryBehavior existing_entry_behavior) + { + if (should_grow()) + rehash(m_capacity * (100 + grow_capacity_increase_percent) / 100); + + auto [result, bucket] = lookup_for_writing(hash, move(predicate), existing_entry_behavior); + switch (result) { + case HashSetResult::InsertedNewEntry: + new (bucket.slot()) T(initialization_callback()); + break; + case HashSetResult::ReplacedExistingEntry: + (*bucket.slot()) = T(initialization_callback()); + break; + case HashSetResult::KeptExistingEntry: + break; + default: + __builtin_unreachable(); + } + return *bucket.slot(); + } + template [[nodiscard]] Iterator find(unsigned hash, TUnaryPredicate predicate) { @@ -642,8 +670,13 @@ private: return static_cast(probe_length + 1); } - template - HashSetResult write_value(U&& value, HashSetExistingEntryBehavior existing_entry_behavior) + struct LookupForWritingResult { + HashSetResult result; + BucketType& bucket; + }; + + template + LookupForWritingResult lookup_for_writing(u32 const hash, TUnaryPredicate predicate, HashSetExistingEntryBehavior existing_entry_behavior) { auto update_collection_for_new_bucket = [&](BucketType& bucket) { if constexpr (IsOrdered) { @@ -685,7 +718,6 @@ private: } }; - u32 const hash = TraitsForT::hash(value); auto bucket_index = hash % m_capacity; size_t probe_length = 0; for (;;) { @@ -693,22 +725,19 @@ private: // We found a free bucket, write to it and stop if (bucket->state == BucketState::Free) { - new (bucket->slot()) T(forward(value)); bucket->state = bucket_state_for_probe_length(probe_length); bucket->hash.set(hash); update_collection_for_new_bucket(*bucket); ++m_size; - return HashSetResult::InsertedNewEntry; + return { HashSetResult::InsertedNewEntry, *bucket }; } // The bucket is already used, does it have an identical value? - if (bucket->hash.check(hash) - && TraitsForT::equals(*bucket->slot(), static_cast(value))) { + if (bucket->hash.check(hash) && predicate(*bucket->slot())) { if (existing_entry_behavior == HashSetExistingEntryBehavior::Replace) { - (*bucket->slot()) = forward(value); - return HashSetResult::ReplacedExistingEntry; + return { HashSetResult::ReplacedExistingEntry, *bucket }; } - return HashSetResult::KeptExistingEntry; + return { HashSetResult::KeptExistingEntry, *bucket }; } // Robin hood: if our probe length is larger (poor) than this bucket's (rich), steal its position! @@ -720,7 +749,7 @@ private: update_collection_for_swapped_buckets(bucket, &bucket_to_move); // Write new bucket - new (bucket->slot()) T(forward(value)); + BucketType* inserted_bucket = bucket; bucket->state = bucket_state_for_probe_length(probe_length); bucket->hash.set(hash); probe_length = target_probe_length; @@ -752,7 +781,7 @@ private: } } - return HashSetResult::InsertedNewEntry; + return { HashSetResult::InsertedNewEntry, *inserted_bucket }; } // Try next bucket @@ -762,6 +791,26 @@ private: } } + template + HashSetResult write_value(U&& value, HashSetExistingEntryBehavior existing_entry_behavior) + { + u32 const hash = TraitsForT::hash(value); + auto [result, bucket] = lookup_for_writing(hash, [&](auto& candidate) { return TraitsForT::equals(candidate, static_cast(value)); }, existing_entry_behavior); + switch (result) { + case HashSetResult::ReplacedExistingEntry: + (*bucket.slot()) = forward(value); + break; + case HashSetResult::InsertedNewEntry: + new (bucket.slot()) T(forward(value)); + break; + case HashSetResult::KeptExistingEntry: + break; + default: + __builtin_unreachable(); + } + return result; + } + void delete_bucket(auto& bucket) { VERIFY(bucket.state != BucketState::Free);