mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[PyTorch] Optimize DictType::annotation_str_impl (#96498)
stringstream construction is expensive, and we can exactly reserve space for the output string while doing the same number of string copies. (If we wanted to improve performance further, we could introduce annotation_str_out to append the output to a given std::string and thus avoid copying subtype annotation strings. It occurs to me that the existing approach is quadratic in the number of layers of nesting, so we should probably do this!) Differential Revision: [D43919651](https://our.internmc.facebook.com/intern/diff/D43919651/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/96498 Approved by: https://github.com/Skylion007
This commit is contained in:
parent
000cfeb848
commit
a66625da3b
|
|
@ -999,12 +999,7 @@ struct TORCH_API DictType : public SharedType {
|
|||
types.push_back(std::move(value));
|
||||
}
|
||||
|
||||
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
|
||||
std::stringstream ss;
|
||||
ss << "Dict[" << getKeyType()->annotation_str(printer) << ", ";
|
||||
ss << getValueType()->annotation_str(std::move(printer)) << "]";
|
||||
return ss.str();
|
||||
}
|
||||
std::string annotation_str_impl(TypePrinter printer = nullptr) const override;
|
||||
|
||||
std::vector<TypePtr> types;
|
||||
bool has_free_variables;
|
||||
|
|
|
|||
|
|
@ -304,6 +304,21 @@ TypePtr DictType::get(std::string identifier, TypePtr key, TypePtr value) {
|
|||
return containerTypePtrs[map_key];
|
||||
}
|
||||
|
||||
std::string DictType::annotation_str_impl(TypePrinter printer) const {
|
||||
auto keyAnnotation = getKeyType()->annotation_str(printer);
|
||||
auto valueAnnotation = getValueType()->annotation_str(std::move(printer));
|
||||
|
||||
std::string result;
|
||||
result.reserve(5 /* "Dict[" */ + keyAnnotation.size() + 2 /* ", " */ + valueAnnotation.size() + 1 /* "]" */);
|
||||
result = "Dict[";
|
||||
result += keyAnnotation;
|
||||
result.push_back(',');
|
||||
result.push_back(' ');
|
||||
result += valueAnnotation;
|
||||
result.push_back(']');
|
||||
return result;
|
||||
}
|
||||
|
||||
AnyListTypePtr AnyListType::get() {
|
||||
static AnyListTypePtr value(new AnyListType());
|
||||
return value;
|
||||
|
|
|
|||
|
|
@ -11528,6 +11528,10 @@ dedent """
|
|||
tuple2_type = torch._C.TupleType([torch._C.StringType.get(), torch._C.StringType.get()])
|
||||
self.assertEqual(tuple2_type.annotation_str, "Tuple[str, str]")
|
||||
|
||||
def test_dict_str(self):
|
||||
dict_type = torch._C.DictType(torch._C.StringType.get(), torch._C.StringType.get())
|
||||
self.assertEqual(dict_type.annotation_str, "Dict[str, str]")
|
||||
|
||||
def test_none_type_str(self):
|
||||
none_type = torch._C.NoneType.get()
|
||||
g = {'NoneType' : type(None)}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user