Relanding shape cache (75400) (#75710)
Summary: https://github.com/pytorch/pytorch/pull/75400

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75710

Reviewed By: malfet

Differential Revision: D35598920

Pulled By: Krovatkin

fbshipit-source-id: 2bbbb3d0c24214b5dbb4ca605e7daa94671f96b0
(cherry picked from commit 572f2f9df5bfd73cd7b83536f619bc86d820ccd8)
This commit is contained in:
parent db1801099b
commit ce842f43f2
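For context on how the pieces in this diff fit together, here is a minimal, hypothetical usage sketch of the entry points the commit adds: calculateSymbolicShapesOnOp now consults the new shape cache, and clear_shape_cache / get_shape_cache_size are test-only helpers. It mirrors the new ShapeAnalysisTest cases further down; the function name shape_cache_demo and the exact include set are assumptions, not part of the commit.

#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_cache.h>
#include <torch/csrc/jit/runtime/operator.h> // assumed location of getOperatorForLiteral

using namespace torch::jit;

// Hypothetical demo, modeled on the SymbolicShapeCaching test in this diff.
void shape_cache_demo() {
  clear_shape_cache();

  const FunctionSchema* schema =
      &(getOperatorForLiteral("aten::mm(Tensor self, Tensor mat2) -> Tensor")
            ->schema());

  c10::optional<int64_t> sym_dim = c10::nullopt;
  c10::SymbolicShape ss = c10::SymbolicShape({sym_dim, 64}); // [*, 64]
  c10::IValue mat2_size = std::vector<int64_t>{64, 56};

  // First call partially evaluates the shape compute graph and stores the
  // canonicalized result in the cache.
  auto res = calculateSymbolicShapesOnOp(schema, {ss, mat2_size});

  // A second call with equivalent (canonicalized) arguments is answered from
  // the cache; get_shape_cache_size() stays at 1.
  res = calculateSymbolicShapesOnOp(schema, {ss, mat2_size});
}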
@@ -467,6 +467,14 @@ struct TORCH_API SymbolicShape {
  // result will be unranked.
  SymbolicShape merge(const SymbolicShape& other) const;

  friend bool operator==(const SymbolicShape& lhs, const SymbolicShape& rhs) {
    return lhs.dims_ == rhs.dims_;
  }

  friend bool operator!=(const SymbolicShape& lhs, const SymbolicShape& rhs) {
    return !(lhs == rhs);
  }

 private:
  c10::optional<std::vector<ShapeSymbol>> dims_;
};
@@ -8,6 +8,8 @@
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_cache.h>
#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
@@ -293,27 +295,34 @@ TEST(ShapeAnalysisTest, MovingConstantOutOfFusionGroups) {

namespace {

// NOLINTNEXTLINE(bugprone-easily-swappable-parameters)
void assertShapeEqual(c10::SymbolicShape& a, c10::SymbolicShape& e) {
  auto a_canonical = CanonicalizedSymbolicShape(a);
  auto e_canonical = CanonicalizedSymbolicShape(e);
  EXPECT_EQ(a_canonical, e_canonical);
}

void assertShapeEqual(
    c10::optional<std::vector<c10::SymbolicShape>>& actual,
    std::vector<c10::optional<int64_t>> expected) {
  ASSERT_TRUE(actual.has_value());
  ASSERT_EQ(actual->size(), 1);
  auto a_canonical = CanonicalizedSymbolicShape(actual->at(0));

  auto symb_expected = c10::SymbolicShape(expected);
  auto b_canonical = CanonicalizedSymbolicShape(symb_expected);
  ASSERT_EQ(a_canonical, b_canonical);
  assertShapeEqual(actual->at(0), symb_expected);
}

const FunctionSchema* getSchema(const char* name) {
  return &(getOperatorForLiteral(name)->schema());
}
} // namespace

TEST(ShapeAnalysisTest, SymbolicShapeAPI) {
  // Figure out how to fetch a function schema

  // Ask someone else how to create a function schema / operator in C++
  std::shared_ptr<Operator> op = getOperatorForLiteral(
  auto schema = getSchema(
      "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor");
  const FunctionSchema* schema = &(op->schema());

  c10::IValue const_size_1 = std::vector<int64_t>{64, 56, 56};
  c10::IValue const_size_2 = std::vector<int64_t>{1, 56, 56};
@@ -352,5 +361,123 @@ TEST(ShapeAnalysisTest, SymbolicShapeAPI) {
  assertShapeEqual(res, {sym_dim, 64, sym_dim, sym_dim});
}

TEST(ShapeAnalysisTest, SymbolicShapeCaching) {
  clear_shape_cache();
  auto schema = getSchema("aten::mm(Tensor self, Tensor mat2) -> Tensor");

  c10::IValue const_size_1 = std::vector<int64_t>{64, 56};
  c10::IValue const_size_2 = std::vector<int64_t>{64, 56};
  c10::IValue const_size_3 = std::vector<int64_t>{64, 20};

  c10::optional<int64_t> sym_dim = c10::nullopt;
  c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64});
  c10::SymbolicShape ss2 = c10::SymbolicShape({sym_dim, 64});
  c10::SymbolicShape ss3 = c10::SymbolicShape({sym_dim, sym_dim});

  auto res = calculateSymbolicShapesOnOp(schema, {ss1, const_size_1});
  assertShapeEqual(res, {sym_dim, 56});
  auto res1_val = res->at(0);

  // The exact same arguments should return the exact same result
  res = calculateSymbolicShapesOnOp(schema, {ss1, const_size_1});
  auto res2_val = res->at(0);
  EXPECT_EQ(res1_val, res2_val);
  EXPECT_EQ(get_shape_cache_size(), 1);

  // Same shape but different symbols should return same shape
  // but different symbolic indices
  res = calculateSymbolicShapesOnOp(schema, {ss2, const_size_2});
  auto res3_val = res->at(0);

  assertShapeEqual(res3_val, res2_val);
  EXPECT_NE(res3_val, res2_val);
  EXPECT_EQ(get_shape_cache_size(), 1);

  // Different concrete shape should be cached separately
  res = calculateSymbolicShapesOnOp(schema, {ss1, const_size_3});
  assertShapeEqual(res, {sym_dim, 20});
  EXPECT_EQ(get_shape_cache_size(), 2);

  res = calculateSymbolicShapesOnOp(schema, {ss3, const_size_3});
  assertShapeEqual(res, {sym_dim, 20});
  EXPECT_EQ(get_shape_cache_size(), 3);

  res = calculateSymbolicShapesOnOp(schema, {ss3, ss3});
  assertShapeEqual(res, {sym_dim, sym_dim});
  EXPECT_EQ(get_shape_cache_size(), 4);
}

TEST(ShapeAnalysisTest, ShapeCacheMultipleFns) {
  clear_shape_cache();

  auto squeeze_op =
      getSchema("aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)");
  auto mul_tensor =
      getSchema("aten::mul.Tensor(Tensor self, Tensor other) -> Tensor");
  auto mul_scalar =
      getSchema("aten::mul.Scalar(Tensor self, Scalar other) -> Tensor");
  auto div_tensor =
      getSchema("aten::div.Tensor(Tensor self, Tensor other) -> Tensor");
  auto matmul = getSchema("aten::mm(Tensor self, Tensor mat2) -> Tensor");

  c10::IValue const_int = 1;

  c10::optional<int64_t> sym_dim = c10::nullopt;
  c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64});

  auto res = calculateSymbolicShapesOnOp(squeeze_op, {ss1, const_int});
  assertShapeEqual(res, {sym_dim, 64});

  // Show that cache can handle multiple functions
  res = calculateSymbolicShapesOnOp(mul_scalar, {ss1, const_int});
  assertShapeEqual(res, {sym_dim, 64});
  EXPECT_EQ(get_shape_cache_size(), 2);

  res = calculateSymbolicShapesOnOp(mul_tensor, {ss1, ss1});
  assertShapeEqual(res, {sym_dim, 64});
  EXPECT_EQ(get_shape_cache_size(), 3);

  // Even when the expected outcome is the same, should not collide
  res = calculateSymbolicShapesOnOp(div_tensor, {ss1, ss1});
  assertShapeEqual(res, {sym_dim, 64});
  EXPECT_EQ(get_shape_cache_size(), 4);

  // Don't lose cached objects
  res = calculateSymbolicShapesOnOp(mul_scalar, {ss1, const_int});
  assertShapeEqual(res, {sym_dim, 64});
  EXPECT_EQ(get_shape_cache_size(), 4);

  res = calculateSymbolicShapesOnOp(matmul, {ss1, ss1});
  // SSA can infer that sym_dim is 64 as both tensors
  // use the same sym_dim
  assertShapeEqual(res, {64, 64});
  EXPECT_EQ(get_shape_cache_size(), 5);
}

TEST(ShapeAnalysisTest, TestShapeMultipleReturns) {
  clear_shape_cache();

  auto max_dim_op = getSchema(
      "aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)");
  c10::IValue const_int = 1;
  c10::IValue false_ival = false;

  c10::optional<int64_t> sym_dim = c10::nullopt;
  c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64});
  c10::SymbolicShape ss2 = c10::SymbolicShape({sym_dim, 64});

  auto res =
      calculateSymbolicShapesOnOp(max_dim_op, {ss1, const_int, false_ival});
  c10::SymbolicShape expected_res = c10::SymbolicShape({sym_dim});
  assertShapeEqual(res->at(0), expected_res);
  // res0 and res1 should share the same symbolic symbol
  EXPECT_EQ(res->at(0), res->at(1));

  // Also test that the shape cache returns consistent result shapes
  res = calculateSymbolicShapesOnOp(max_dim_op, {ss2, const_int, false_ival});
  assertShapeEqual(res->at(0), expected_res);
  EXPECT_EQ(res->at(0), res->at(1));
  EXPECT_EQ(get_shape_cache_size(), 1);
}
} // namespace jit
} // namespace torch
@@ -306,6 +306,7 @@ core_sources_full_mobile_no_backend_interface = [
    "torch/csrc/jit/passes/integer_value_refinement.cpp",
    "torch/csrc/jit/passes/replacement_of_old_operators.cpp",
    "torch/csrc/jit/passes/symbolic_shape_analysis.cpp",
    "torch/csrc/jit/passes/symbolic_shape_cache.cpp",
    "torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp",
    "torch/csrc/jit/passes/specialize_autogradzero.cpp",
    "torch/csrc/jit/passes/update_differentiable_graph_requires_grad.cpp",
@@ -19,6 +19,7 @@
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_cache.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/runtime/exception_message.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
@@ -174,6 +175,17 @@ bool symbolicShapeAnalysisTestModeEnabled() {
  return symbolic_shape_analysis_test_mode;
}

using SSArgument = c10::variant<ShapeArguments, IValue>;

std::ostream& operator<<(std::ostream& out, const SSArgument& sa) {
  if (const IValue* iv = c10::get_if<IValue>(&sa)) {
    out << *iv;
  } else {
    out << c10::get<ShapeArguments>(sa);
  }
  return out;
}

namespace {

bool isListOfInts(const TypePtr& type) {
@@ -244,8 +256,6 @@ c10::SymbolicShape extractListShape(
  return c10::SymbolicShape(output_shape);
}

} // namespace

// Symbolic Shape Analysis works through iteratively partially evaluating
// a TorchScript shape compute graph by inputting properties from input
// Tensors. We can substitute in properties like `len(x)` and `x[1]`
@@ -260,17 +270,6 @@ c10::SymbolicShape extractListShape(
// means that we do not know its concrete value statically but we can assign
// sets of tensor dimensions which must be equal at runtime.

using SSArgument = c10::variant<ShapeArguments, IValue>;

std::ostream& operator<<(std::ostream& out, const SSArgument& sa) {
  if (const IValue* iv = c10::get_if<IValue>(&sa)) {
    out << *iv;
  } else {
    out << c10::get<ShapeArguments>(sa);
  }
  return out;
}

struct SymbolicShapeOpAnalyzer {
  std::shared_ptr<Graph> shape_compute_graph_;
  const FunctionSchema* schema_;
@@ -1058,6 +1057,7 @@ void PropagateShapesOnBlock(Block* b, const AliasDb& db) {
    }
  }
}
} // namespace

void PropagateShapesOnGraph(std::shared_ptr<Graph>& graph) {
  AliasDb db(graph);
@@ -1076,6 +1076,16 @@ TORCH_API c10::optional<std::vector<c10::SymbolicShape>>
calculateSymbolicShapesOnOp(
    const FunctionSchema* schema,
    const std::vector<SSAInput>& inputs) {
  if (shapeComputeGraphForSchema(*schema) == c10::nullopt) {
    // Avoid doing all this work for functions that don't have a
    // supported schema
    return c10::nullopt;
  }

  if (auto cached_ret_vec = get_cached_shape_function(schema, inputs)) {
    return cached_ret_vec;
  }

  std::vector<SSArgument> ssa_args;
  for (auto& arg : inputs) {
    if (const IValue* ival = c10::get_if<IValue>(&arg)) {
@@ -1087,7 +1097,11 @@ calculateSymbolicShapesOnOp(
  }

  auto op_analyzer = SymbolicShapeOpAnalyzer(schema);
  return op_analyzer.run(ssa_args);
  auto res = op_analyzer.run(ssa_args);
  if (res.has_value()) {
    cache_shape_function(schema, inputs, res.value());
  }
  return res;
}

} // namespace jit
@@ -53,66 +53,5 @@ TORCH_API c10::optional<std::vector<c10::SymbolicShape>>
calculateSymbolicShapesOnOp(
    const FunctionSchema* schema,
    const std::vector<SSAInput>& inputs);

struct TORCH_API CanonicalizedSymbolicShape {
  CanonicalizedSymbolicShape(
      c10::SymbolicShape& orig_shape,
      std::unordered_map<int64_t, int64_t>& ss_map) {
    init(orig_shape, ss_map);
  }

  CanonicalizedSymbolicShape(c10::SymbolicShape& orig_shape) {
    std::unordered_map<int64_t, int64_t> new_ssmap;
    init(orig_shape, new_ssmap);
  }

 private:
  c10::optional<std::vector<int64_t>> values_;
  std::vector<bool> is_symbolic_;

  void init(
      c10::SymbolicShape& orig_shape,
      std::unordered_map<int64_t, int64_t>& ss_map) {
    auto sizes = orig_shape.sizes();
    if (!sizes) {
      values_ = c10::nullopt;
      return;
    }
    values_ = std::vector<int64_t>();
    int64_t cur_symbolic_index = -(int64_t)ss_map.size() - 1;
    for (auto& cur_shape : *sizes) {
      if (cur_shape.is_static()) {
        is_symbolic_.emplace_back(false);
        values_->push_back(cur_shape.static_size());
      } else {
        // Check for aliasing
        is_symbolic_.emplace_back(true);
        auto it = ss_map.find(cur_shape.value());

        if (it == ss_map.end()) {
          values_->push_back(cur_symbolic_index);
          ss_map.insert({cur_shape.value(), cur_symbolic_index});
          cur_symbolic_index--;
        } else {
          values_->push_back(it->second);
        }
      }
    }
  }

  friend bool operator==(
      const CanonicalizedSymbolicShape& a,
      const CanonicalizedSymbolicShape& b) {
    if (a.values_.has_value() != b.values_.has_value()) {
      return false;
    }
    if (!a.values_.has_value()) {
      return true;
    }
    return (
        a.values_.value() == b.values_.value() &&
        a.is_symbolic_ == b.is_symbolic_);
  };
};
} // namespace jit
} // namespace torch
torch/csrc/jit/passes/symbolic_shape_cache.cpp (new file, 208 lines)

@@ -0,0 +1,208 @@
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_cache.h>
#include <torch/csrc/lazy/core/cache.h>

// SHAPE CACHING CODE
namespace torch {
namespace jit {
namespace {
using CanonicalArg = c10::variant<CanonicalizedSymbolicShape, IValue>;
using CanonicalArgVec = std::vector<CanonicalArg>;
using CanonicalRet = std::vector<CanonicalizedSymbolicShape>;
using ShapeCacheKey = std::tuple<c10::OperatorName, CanonicalArgVec>;

CanonicalArgVec cannonicalizeVec(
    const std::vector<SSAInput>& arg_vec,
    std::unordered_map<int64_t, int64_t>& ss_map,
    bool deep_copy = true) {
  CanonicalArgVec canonical_args;
  canonical_args.reserve(arg_vec.size());
  for (auto& arg : arg_vec) {
    if (const IValue* iv = c10::get_if<IValue>(&arg)) {
      if (deep_copy) {
        canonical_args.push_back(iv->deepcopy());
      } else {
        canonical_args.push_back(*iv);
      }
    } else {
      auto& ss = c10::get<at::SymbolicShape>(arg);
      canonical_args.emplace_back(CanonicalizedSymbolicShape(ss, ss_map));
    }
  }
  return canonical_args;
}

std::vector<CanonicalizedSymbolicShape> cannonicalizeVec(
    const std::vector<at::SymbolicShape>& ret_vec,
    std::unordered_map<int64_t, int64_t>& ss_map) {
  std::vector<CanonicalizedSymbolicShape> canonical_rets;
  canonical_rets.reserve(ret_vec.size());
  for (auto& ss : ret_vec) {
    canonical_rets.emplace_back(CanonicalizedSymbolicShape(ss, ss_map));
  }
  return canonical_rets;
}

struct ArgumentsHasher {
  size_t operator()(const ShapeCacheKey& cacheKey) const {
    // TODO: ignore arguments that are not used in shape function (not needed
    // initially)
    auto& op_name = std::get<0>(cacheKey);
    auto& arg_vec = std::get<1>(cacheKey);

    size_t hash_val = c10::hash<c10::OperatorName>()(op_name);

    hash_val = at::hash_combine(std::hash<size_t>{}(arg_vec.size()), hash_val);
    for (const CanonicalArg& arg : arg_vec) {
      size_t cur_arg = 0;
      if (const IValue* ival = c10::get_if<IValue>(&arg)) {
        // IValue doesn't hash List (as Python doesn't), so we will do a custom
        // list hash
        if (ival->isList()) {
          TORCH_INTERNAL_ASSERT(ival->isIntList(), "Unexpected Args in List");
          cur_arg = ival->toListRef().size();
          for (const IValue& elem_ival : ival->toListRef()) {
            cur_arg = at::hash_combine(cur_arg, IValue::hash(elem_ival));
          }
        } else {
          cur_arg = IValue::hash(ival);
        }
      } else {
        cur_arg = c10::get<CanonicalizedSymbolicShape>(arg).hash();
      }
      hash_val = at::hash_combine(hash_val, cur_arg);
    }
    return hash_val;
  }
};

using ShapeCache = lazy::Cache<
    ShapeCacheKey,
    std::vector<CanonicalizedSymbolicShape>,
    ArgumentsHasher>;

constexpr size_t kShapeCacheSize = 1024;
ShapeCache shapeCache(kShapeCacheSize);

ShapeCacheKey get_cache_key(
    const FunctionSchema* schema,
    const std::vector<SSAInput>& arg_vec,
    std::unordered_map<int64_t, int64_t>& ss_map,
    bool deep_copy = true) {
  CanonicalArgVec canonical_args = cannonicalizeVec(arg_vec, ss_map, deep_copy);
  return std::make_tuple(schema->operator_name(), canonical_args);
}

} // namespace

TORCH_API void cache_shape_function(
    const FunctionSchema* schema,
    const std::vector<SSAInput>& arg_vec,
    const std::vector<at::SymbolicShape>& ret_vec) {
  // TODO: compare perf using std::vector<std::tuple<int64_t, int64_t>>
  auto ss_map = std::unordered_map<int64_t, int64_t>();
  auto cache_key = get_cache_key(schema, arg_vec, ss_map, /* deep_copy */ true);
  auto can_ret_vec = std::make_shared<std::vector<CanonicalizedSymbolicShape>>(
      cannonicalizeVec(ret_vec, ss_map));
  shapeCache.Add(cache_key, can_ret_vec);
}

TORCH_API c10::optional<std::vector<at::SymbolicShape>>
get_cached_shape_function(
    const FunctionSchema* schema,
    const std::vector<SSAInput>& arg_vec) {
  // TODO: compare perf using std::vector<std::tuple<int64_t, int64_t>> for both
  // ss_map and inverse_ss_map
  auto ss_map = std::unordered_map<int64_t, int64_t>();
  auto cache_key =
      get_cache_key(schema, arg_vec, ss_map, /* deep_copy */ false);
  auto cached_ret_vec = shapeCache.Get(cache_key);
  if (cached_ret_vec == nullptr) {
    return c10::nullopt;
  }
  // Decanonicalize the return values
  auto inverse_ss_map = std::unordered_map<int64_t, int64_t>();
  for (auto& ss_val : ss_map) {
    inverse_ss_map[ss_val.second] = ss_val.first;
  }
  std::vector<at::SymbolicShape> ret_vec;
  for (auto& css : *cached_ret_vec) {
    ret_vec.emplace_back(css.toSymbolicShape(inverse_ss_map));
  }
  return ret_vec;
}

// Functions only to access the cache, used for testing
TORCH_API void clear_shape_cache() {
  shapeCache.Clear();
}

TORCH_API size_t get_shape_cache_size() {
  return shapeCache.Numel();
}

void CanonicalizedSymbolicShape::init(
    const c10::SymbolicShape& orig_shape,
    std::unordered_map<int64_t, int64_t>& ss_map) {
  auto sizes = orig_shape.sizes();
  if (!sizes) {
    values_ = c10::nullopt;
    return;
  }
  values_ = std::vector<int64_t>();
  int64_t cur_symbolic_index = -static_cast<int64_t>(ss_map.size()) - 1;
  for (auto& cur_shape : *sizes) {
    if (cur_shape.is_static()) {
      values_->push_back(cur_shape.static_size());
    } else {
      // Check for aliasing
      auto it = ss_map.find(cur_shape.value());

      if (it == ss_map.end()) {
        values_->push_back(cur_symbolic_index);
        ss_map.insert({cur_shape.value(), cur_symbolic_index});
        cur_symbolic_index--;
      } else {
        values_->push_back(it->second);
      }
    }
  }
}

c10::SymbolicShape CanonicalizedSymbolicShape::toSymbolicShape(
    std::unordered_map<int64_t, int64_t>& inverse_ss_map) const {
  if (!values_.has_value()) {
    return c10::SymbolicShape();
  }
  std::vector<at::ShapeSymbol> sizes;
  for (long long cur_val : *values_) {
    if (cur_val >= 0) {
      sizes.push_back(at::ShapeSymbol::fromStaticSize(cur_val));
      continue;
    }
    auto res = inverse_ss_map.find(cur_val);
    if (res != inverse_ss_map.end()) {
      sizes.push_back(at::ShapeSymbol::fromStaticSize(res->second));
    } else {
      auto new_symbol = at::ShapeSymbol::newSymbol();
      inverse_ss_map.insert({cur_val, new_symbol.value()});
      sizes.push_back(new_symbol);
    }
  }
  return c10::SymbolicShape(std::move(sizes));
}

size_t CanonicalizedSymbolicShape::hash() const {
  if (!values_.has_value()) {
    return 0x8cc80c80; // random value to prevent hash collisions
  }
  return c10::hash<std::vector<int64_t>>()(values_.value());
}

bool operator==(
    const CanonicalizedSymbolicShape& a,
    const CanonicalizedSymbolicShape& b) {
  return a.values_ == b.values_;
};
} // namespace jit
} // namespace torch
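As a concrete illustration of the canonicalization above (a hypothetical snippet, assuming the gtest macros and headers used by the tests in this diff): static sizes are kept as-is and each distinct symbolic dimension is renumbered -1, -2, ... in order of first appearance, so two shapes that differ only in their ShapeSymbol ids canonicalize to the same value.

  c10::optional<int64_t> sym_dim = c10::nullopt;
  c10::SymbolicShape a = c10::SymbolicShape({sym_dim, 64}); // symbols e.g. [s1, 64]
  c10::SymbolicShape b = c10::SymbolicShape({sym_dim, 64}); // fresh symbol: [s2, 64]

  // SymbolicShape::operator== (added above) compares dims_, and s1 != s2.
  EXPECT_NE(a, b);

  // Both canonicalize to values_ = [-1, 64], so the canonical forms match.
  EXPECT_EQ(CanonicalizedSymbolicShape(a), CanonicalizedSymbolicShape(b));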
torch/csrc/jit/passes/symbolic_shape_cache.h (new file, 57 lines)

@@ -0,0 +1,57 @@
#pragma once

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>

namespace torch {
namespace jit {

struct TORCH_API CanonicalizedSymbolicShape {
  // TODO: Consider in the future if it is reasonable to
  // merge code with SymbolicShape or VaryingShape while keeping
  // the two not implicitly convertible (and cause bugs).
  CanonicalizedSymbolicShape(
      const c10::SymbolicShape& orig_shape,
      std::unordered_map<int64_t, int64_t>& ss_map) {
    init(orig_shape, ss_map);
  }

  CanonicalizedSymbolicShape(c10::SymbolicShape& orig_shape) {
    std::unordered_map<int64_t, int64_t> new_ssmap;
    init(orig_shape, new_ssmap);
  }

  size_t hash() const;

  c10::SymbolicShape toSymbolicShape(
      std::unordered_map<int64_t, int64_t>& inverse_ss_map) const;

  TORCH_API friend bool operator==(
      const CanonicalizedSymbolicShape& a,
      const CanonicalizedSymbolicShape& b);

 private:
  c10::optional<std::vector<int64_t>> values_;

  void init(
      const c10::SymbolicShape& orig_shape,
      std::unordered_map<int64_t, int64_t>& ss_map);
};

// SHAPE CACHE API
TORCH_API c10::optional<std::vector<at::SymbolicShape>>
get_cached_shape_function(
    const FunctionSchema* schema,
    const std::vector<SSAInput>& arg_vec);

TORCH_API void cache_shape_function(
    const FunctionSchema* schema,
    const std::vector<SSAInput>& arg_vec,
    const std::vector<at::SymbolicShape>& ret_vec);

// For use in test code
TORCH_API void clear_shape_cache();
TORCH_API size_t get_shape_cache_size();

} // namespace jit
} // namespace torch