diff --git a/c10/test/util/string_util_test.cpp b/c10/test/util/string_util_test.cpp index 44dc2b794db..c928809c9a1 100644 --- a/c10/test/util/string_util_test.cpp +++ b/c10/test/util/string_util_test.cpp @@ -133,4 +133,34 @@ TEST(tryToTest, Double) { EXPECT_FALSE(c10::tryToNumber(nullptr).has_value()); } } // namespace test_try_to + +namespace test_split { +TEST(SplitTest, NormalCase) { + std::string str = "torch.ops.aten.linear"; + auto result = c10::split(str, '.'); + ASSERT_EQ(4, result.size()); + EXPECT_EQ("torch", result[0]); + EXPECT_EQ("ops", result[1]); + EXPECT_EQ("aten", result[2]); + EXPECT_EQ("linear", result[3]); +} +TEST(SplitTest, EmptyString) { + auto result = c10::split("", '.'); + EXPECT_TRUE(result.empty()); +} +TEST(SplitTest, NoDelimiter) { + std::string str = "single"; + auto result = c10::split(str, '.'); + ASSERT_EQ(1, result.size()); + EXPECT_EQ("single", result[0]); +} +TEST(SplitTest, ConsecutiveDelimiters) { + std::string str = "atom1..atom2"; + auto result = c10::split(str, '.'); + ASSERT_EQ(3, result.size()); + EXPECT_EQ("atom1", result[0]); + EXPECT_EQ("", result[1]); + EXPECT_EQ("atom2", result[2]); +} +} // namespace test_split } // namespace diff --git a/c10/util/StringUtil.cpp b/c10/util/StringUtil.cpp index facc2e963b5..063a8fc93ea 100644 --- a/c10/util/StringUtil.cpp +++ b/c10/util/StringUtil.cpp @@ -200,4 +200,19 @@ std::optional tryToNumber(const char* symbol) { return value; } +std::vector split(std::string_view target, char delimiter) { + std::vector atoms; + std::string_view buffer = target; + while (!buffer.empty()) { + auto i = buffer.find(delimiter); + if (i == std::string_view::npos) { + atoms.push_back(buffer); + buffer.remove_prefix(buffer.size()); + } else { + atoms.push_back(buffer.substr(0, i)); + buffer.remove_prefix(i + 1); + } + } + return atoms; +} } // namespace c10 diff --git a/c10/util/StringUtil.h b/c10/util/StringUtil.h index aa545040b5b..0c81bc6f62c 100644 --- a/c10/util/StringUtil.h +++ b/c10/util/StringUtil.h @@ -10,6 +10,7 @@ #include #include #include +#include C10_CLANG_DIAGNOSTIC_PUSH() #if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32") @@ -238,6 +239,9 @@ C10_API std::optional tryToNumber(const char* symbol); template <> C10_API std::optional tryToNumber(const std::string& symbol); +C10_API std::vector split( + std::string_view target, + char delimiter); } // namespace c10 C10_CLANG_DIAGNOSTIC_POP()