[pt-vulkan][ez] Replace ska::flat_hash_map, c10::get_hash with std::unordered_map, std::hash (#117177)

## Context

This change is part of a set of changes that removes all references to the `c10` library in the `api/`, `graph/`, and `impl/` folders of the PyTorch Vulkan codebase. This is to ensure that these components can be built as a standalone library such that they can be used as the foundations of a Android GPU delegate for ExecuTorch.

## Notes for Reviewers

The majority of the changes in this changeset are:

* Replacing instances of `ska::flat_hash_map` with `std::unordered_map`
   * `ska::flat_hash_map` is an optimized hash map, but the optimizations shouldn't be too impactful so `std::unordered_map` should suffice. Performance regression testing will be done at the final change in this stack to verify this.
* Replacing `c10::get_hash` with `std::hash` where only one variable is getting hashed or the `utils::hash_combine()` function added to `api/Utils.h` (which was copied from `c10/util/hash.h`)

Differential Revision: [D52662231](https://our.internmc.facebook.com/intern/diff/D52662231/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117177
Approved by: https://github.com/yipjustin
ghstack dependencies: #117176
This commit is contained in:
SS-JIA 2024-01-11 13:03:34 -08:00 committed by PyTorch MergeBot
parent 57b76b970b
commit fe298e901a
9 changed files with 60 additions and 26 deletions

View File

@ -1,6 +1,7 @@
#include <ATen/native/vulkan/api/Descriptor.h>
#include <ATen/native/vulkan/api/Utils.h>
#include <algorithm>
#include <utility>
namespace at {

View File

@ -7,7 +7,7 @@
#include <ATen/native/vulkan/api/Common.h>
#include <ATen/native/vulkan/api/Resource.h>
#include <ATen/native/vulkan/api/Shader.h>
#include <c10/util/flat_hash_map.h>
#include <unordered_map>
namespace at {
namespace native {
@ -114,7 +114,7 @@ class DescriptorPool final {
DescriptorPoolConfig config_;
// New Descriptors
std::mutex mutex_;
ska::flat_hash_map<VkDescriptorSetLayout, DescriptorSetPile> piles_;
std::unordered_map<VkDescriptorSetLayout, DescriptorSetPile> piles_;
public:
DescriptorSet get_descriptor_set(

View File

@ -7,7 +7,7 @@
#include <ATen/native/vulkan/api/Common.h>
#include <ATen/native/vulkan/api/Resource.h>
#include <ATen/native/vulkan/api/Shader.h>
#include <c10/util/flat_hash_map.h>
#include <unordered_map>
#include <mutex>
@ -124,7 +124,7 @@ class PipelineLayoutCache final {
struct Hasher {
inline size_t operator()(VkDescriptorSetLayout descriptor_layout) const {
return c10::get_hash(descriptor_layout);
return std::hash<VkDescriptorSetLayout>()(descriptor_layout);
}
};
@ -134,7 +134,7 @@ class PipelineLayoutCache final {
std::mutex cache_mutex_;
VkDevice device_;
ska::flat_hash_map<Key, Value, Hasher> cache_;
std::unordered_map<Key, Value, Hasher> cache_;
public:
VkPipelineLayout retrieve(const Key&);
@ -159,12 +159,19 @@ class ComputePipelineCache final {
struct Hasher {
inline size_t operator()(
const ComputePipeline::Descriptor& descriptor) const {
return c10::get_hash(
descriptor.pipeline_layout,
descriptor.shader_module,
descriptor.local_work_group.data[0u],
descriptor.local_work_group.data[1u],
descriptor.local_work_group.data[2u]);
size_t seed = 0;
seed = utils::hash_combine(
seed, std::hash<VkPipelineLayout>()(descriptor.pipeline_layout));
seed = utils::hash_combine(
seed, std::hash<VkShaderModule>()(descriptor.shader_module));
seed = utils::hash_combine(
seed, std::hash<uint32_t>()(descriptor.local_work_group.data[0u]));
seed = utils::hash_combine(
seed, std::hash<uint32_t>()(descriptor.local_work_group.data[1u]));
seed = utils::hash_combine(
seed, std::hash<uint32_t>()(descriptor.local_work_group.data[2u]));
return seed;
}
};
@ -175,7 +182,7 @@ class ComputePipelineCache final {
VkDevice device_;
VkPipelineCache pipeline_cache_;
ska::flat_hash_map<Key, Value, Hasher> cache_;
std::unordered_map<Key, Value, Hasher> cache_;
public:
VkPipeline retrieve(const Key&);

View File

@ -6,6 +6,7 @@
#endif // USE_KINETO
#include <cmath>
#include <iomanip>
#include <iostream>
#include <utility>

View File

@ -304,8 +304,15 @@ ImageSampler::~ImageSampler() {
size_t ImageSampler::Hasher::operator()(
const ImageSampler::Properties& props) const {
return c10::get_hash(
props.filter, props.mipmap_mode, props.address_mode, props.border_color);
size_t seed = 0;
seed = utils::hash_combine(seed, std::hash<VkFilter>()(props.filter));
seed = utils::hash_combine(
seed, std::hash<VkSamplerMipmapMode>()(props.mipmap_mode));
seed = utils::hash_combine(
seed, std::hash<VkSamplerAddressMode>()(props.address_mode));
seed =
utils::hash_combine(seed, std::hash<VkBorderColor>()(props.border_color));
return seed;
}
void swap(ImageSampler& lhs, ImageSampler& rhs) noexcept {

View File

@ -8,10 +8,10 @@
#include <ATen/native/vulkan/api/Utils.h>
#include <c10/core/ScalarType.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/typeid.h>
#include <stack>
#include <unordered_map>
namespace at {
namespace native {
@ -359,7 +359,7 @@ class SamplerCache final {
std::mutex cache_mutex_;
VkDevice device_;
ska::flat_hash_map<Key, Value, Hasher> cache_;
std::unordered_map<Key, Value, Hasher> cache_;
public:
VkSampler retrieve(const Key&);

View File

@ -7,10 +7,9 @@
#include <ATen/native/vulkan/api/Common.h>
#include <ATen/native/vulkan/api/Types.h>
#include <ATen/native/vulkan/api/Utils.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/hash.h>
#include <mutex>
#include <unordered_map>
namespace at {
namespace native {
@ -128,7 +127,8 @@ class ShaderLayoutCache final {
size_t hashed = 0u;
for (const VkDescriptorType type : signature) {
hashed = c10::hash_combine(hashed, c10::get_hash(type));
hashed =
utils::hash_combine(hashed, std::hash<VkDescriptorType>()(type));
}
return hashed;
@ -141,7 +141,7 @@ class ShaderLayoutCache final {
std::mutex cache_mutex_;
VkDevice device_;
ska::flat_hash_map<Key, Value, Hasher> cache_;
std::unordered_map<Key, Value, Hasher> cache_;
public:
VkDescriptorSetLayout retrieve(const Key&);
@ -165,7 +165,13 @@ class ShaderCache final {
struct Hasher {
inline size_t operator()(const ShaderInfo& source) const {
return c10::get_hash(source.src_code.bin, source.src_code.size);
size_t seed = 0;
seed = utils::hash_combine(
seed, std::hash<const uint32_t*>()(source.src_code.bin));
seed = utils::hash_combine(
seed, std::hash<uint32_t>()(source.src_code.size));
return seed;
}
};
@ -175,7 +181,7 @@ class ShaderCache final {
std::mutex cache_mutex_;
VkDevice device_;
ska::flat_hash_map<Key, Value, Hasher> cache_;
std::unordered_map<Key, Value, Hasher> cache_;
public:
VkShaderModule retrieve(const Key&);

View File

@ -23,6 +23,18 @@ namespace vulkan {
namespace api {
namespace utils {
//
// Hashing
//
/**
* hash_combine is taken from c10/util/hash.h, which in turn is based on
* implementation from Boost
*/
inline size_t hash_combine(size_t seed, size_t value) {
return seed ^ (value + 0x9e3779b9 + (seed << 6u) + (seed >> 2u));
}
//
// Alignment
//

View File

@ -520,8 +520,8 @@ def gen_cpp_files(
h = "#pragma once\n"
h += "#include <ATen/native/vulkan/api/Types.h>\n"
h += "#include <ATen/native/vulkan/api/vk_api.h>\n"
h += "#include <c10/util/flat_hash_map.h>\n"
h += "#include <string>\n"
h += "#include <unordered_map>\n"
nsbegin = "namespace at {\nnamespace native {\nnamespace vulkan {\n"
nsend = "} // namespace vulkan\n} // namespace native\n} // namespace at\n"
@ -533,9 +533,9 @@ def gen_cpp_files(
# Forward declaration of ShaderInfo
h += "namespace api {\nstruct ShaderInfo;\n} // namespace api\n"
h += "typedef ska::flat_hash_map<std::string, api::ShaderInfo> ShaderListing;\n"
h += "typedef ska::flat_hash_map<std::string, std::string> RegistryKeyMap;\n"
h += "typedef ska::flat_hash_map<std::string, RegistryKeyMap> ShaderRegistry;\n"
h += "typedef std::unordered_map<std::string, api::ShaderInfo> ShaderListing;\n"
h += "typedef std::unordered_map<std::string, std::string> RegistryKeyMap;\n"
h += "typedef std::unordered_map<std::string, RegistryKeyMap> ShaderRegistry;\n"
h += "extern const ShaderListing shader_infos;\n"
h += "extern ShaderRegistry shader_registry;\n"
h += "inline const ShaderListing& get_shader_infos() {\n return shader_infos;\n}\n"