Relanding shape cache (75400) (#75710)

Summary:
https://github.com/pytorch/pytorch/pull/75400

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75710

Reviewed By: malfet

Differential Revision: D35598920

Pulled By: Krovatkin

fbshipit-source-id: 2bbbb3d0c24214b5dbb4ca605e7daa94671f96b0
(cherry picked from commit 572f2f9df5bfd73cd7b83536f619bc86d820ccd8)
This commit is contained in:
Nikolay Korovaiko 2022-04-13 00:24:52 -07:00 committed by PyTorch MergeBot
parent db1801099b
commit ce842f43f2
7 changed files with 434 additions and 80 deletions

View File

@ -467,6 +467,14 @@ struct TORCH_API SymbolicShape {
// result will be unranked.
SymbolicShape merge(const SymbolicShape& other) const;
// Two SymbolicShapes compare equal when their optional dimension lists
// (`dims_`) compare equal.
friend bool operator==(const SymbolicShape& lhs, const SymbolicShape& rhs) {
return lhs.dims_ == rhs.dims_;
}
// Negation of operator== for the same pair of shapes.
friend bool operator!=(const SymbolicShape& lhs, const SymbolicShape& rhs) {
return !(lhs == rhs);
}
private:
c10::optional<std::vector<ShapeSymbol>> dims_;
};

View File

@ -8,6 +8,8 @@
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_cache.h>
#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
@ -293,27 +295,34 @@ TEST(ShapeAnalysisTest, MovingConstantOutOfFusionGroups) {
namespace {
// Expects (non-fatally, via EXPECT_EQ) that the actual shape `a` and the
// expected shape `e` are equal after each one is converted to a
// CanonicalizedSymbolicShape.
// NOLINTNEXTLINE(bugprone-easily-swappable-parameters)
void assertShapeEqual(c10::SymbolicShape& a, c10::SymbolicShape& e) {
// Each shape is canonicalized on its own, with a fresh symbol map.
auto a_canonical = CanonicalizedSymbolicShape(a);
auto e_canonical = CanonicalizedSymbolicShape(e);
EXPECT_EQ(a_canonical, e_canonical);
}
// Asserts that `actual` holds exactly one output shape and that this shape
// matches `expected`.
//
// `expected` lists one entry per dimension; the tests in this file pass a
// concrete size for a static dimension and c10::nullopt for a symbolic one.
void assertShapeEqual(
    c10::optional<std::vector<c10::SymbolicShape>>& actual,
    std::vector<c10::optional<int64_t>> expected) {
  ASSERT_TRUE(actual.has_value());
  ASSERT_EQ(actual->size(), 1);

  auto symb_expected = c10::SymbolicShape(expected);
  // Canonicalization and comparison live in the two-shape overload; doing
  // them here as well would only repeat the same check.
  assertShapeEqual(actual->at(0), symb_expected);
}
// Resolves the operator for the schema literal `name` through
// getOperatorForLiteral and returns the address of that operator's schema.
// NOTE(review): the returned pointer outlives the local operator handle;
// presumably the operator is kept alive elsewhere — confirm.
const FunctionSchema* getSchema(const char* name) {
  const auto op = getOperatorForLiteral(name);
  const FunctionSchema& schema = op->schema();
  return &schema;
}
} // namespace
TEST(ShapeAnalysisTest, SymbolicShapeAPI) {
// Figure out how to fetch a function schema
// Ask someone else how to create a function schema / operator in C++
std::shared_ptr<Operator> op = getOperatorForLiteral(
auto schema = getSchema(
"aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor");
const FunctionSchema* schema = &(op->schema());
c10::IValue const_size_1 = std::vector<int64_t>{64, 56, 56};
c10::IValue const_size_2 = std::vector<int64_t>{1, 56, 56};
@ -352,5 +361,123 @@ TEST(ShapeAnalysisTest, SymbolicShapeAPI) {
assertShapeEqual(res, {sym_dim, 64, sym_dim, sym_dim});
}
// Exercises the shape cache through calculateSymbolicShapesOnOp on aten::mm:
// repeated and symbol-renamed queries must not add cache entries, while
// queries with different concrete or symbolic arguments must each add one.
// The cache-size expectations depend on the exact order of the calls below.
TEST(ShapeAnalysisTest, SymbolicShapeCaching) {
// Start from an empty cache so the size checks below are absolute.
clear_shape_cache();
auto schema = getSchema("aten::mm(Tensor self, Tensor mat2) -> Tensor");
c10::IValue const_size_1 = std::vector<int64_t>{64, 56};
c10::IValue const_size_2 = std::vector<int64_t>{64, 56};
c10::IValue const_size_3 = std::vector<int64_t>{64, 20};
c10::optional<int64_t> sym_dim = c10::nullopt;
// ss1 and ss2 are built from identical arguments but are separate objects;
// ss3 has both dimensions symbolic.
c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64});
c10::SymbolicShape ss2 = c10::SymbolicShape({sym_dim, 64});
c10::SymbolicShape ss3 = c10::SymbolicShape({sym_dim, sym_dim});
auto res = calculateSymbolicShapesOnOp(schema, {ss1, const_size_1});
assertShapeEqual(res, {sym_dim, 56});
auto res1_val = res->at(0);
// The exact same arguments should return the exact same result
res = calculateSymbolicShapesOnOp(schema, {ss1, const_size_1});
auto res2_val = res->at(0);
EXPECT_EQ(res1_val, res2_val);
EXPECT_EQ(get_shape_cache_size(), 1);
// Same shape but different symbols should return same shape
// but different symbolic indices
res = calculateSymbolicShapesOnOp(schema, {ss2, const_size_2});
auto res3_val = res->at(0);
// Equal after canonicalization, but not equal as raw SymbolicShapes.
assertShapeEqual(res3_val, res2_val);
EXPECT_NE(res3_val, res2_val);
EXPECT_EQ(get_shape_cache_size(), 1);
// Different concrete shape should be cached separately
res = calculateSymbolicShapesOnOp(schema, {ss1, const_size_3});
assertShapeEqual(res, {sym_dim, 20});
EXPECT_EQ(get_shape_cache_size(), 2);
// A different symbolic first argument is expected to add its own entry.
res = calculateSymbolicShapesOnOp(schema, {ss3, const_size_3});
assertShapeEqual(res, {sym_dim, 20});
EXPECT_EQ(get_shape_cache_size(), 3);
// Both arguments symbolic: expected to add one more entry.
res = calculateSymbolicShapesOnOp(schema, {ss3, ss3});
assertShapeEqual(res, {sym_dim, sym_dim});
EXPECT_EQ(get_shape_cache_size(), 4);
}
// Checks that the shape cache keeps separate entries for different operators
// even when they are queried with the same arguments, and that earlier
// entries are still present after later insertions. The cache-size
// expectations depend on the exact order of the calls below.
TEST(ShapeAnalysisTest, ShapeCacheMultipleFns) {
// Start from an empty cache so the size checks below are absolute.
clear_shape_cache();
auto squeeze_op =
getSchema("aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)");
auto mul_tensor =
getSchema("aten::mul.Tensor(Tensor self, Tensor other) -> Tensor");
auto mul_scalar =
getSchema("aten::mul.Scalar(Tensor self, Scalar other) -> Tensor");
auto div_tensor =
getSchema("aten::div.Tensor(Tensor self, Tensor other) -> Tensor");
auto matmul = getSchema("aten::mm(Tensor self, Tensor mat2) -> Tensor");
c10::IValue const_int = 1;
c10::optional<int64_t> sym_dim = c10::nullopt;
c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64});
auto res = calculateSymbolicShapesOnOp(squeeze_op, {ss1, const_int});
assertShapeEqual(res, {sym_dim, 64});
// Show that cache can handle multiple functions
res = calculateSymbolicShapesOnOp(mul_scalar, {ss1, const_int});
assertShapeEqual(res, {sym_dim, 64});
EXPECT_EQ(get_shape_cache_size(), 2);
res = calculateSymbolicShapesOnOp(mul_tensor, {ss1, ss1});
assertShapeEqual(res, {sym_dim, 64});
EXPECT_EQ(get_shape_cache_size(), 3);
// Even when the expected outcome is the same, should not collide
res = calculateSymbolicShapesOnOp(div_tensor, {ss1, ss1});
assertShapeEqual(res, {sym_dim, 64});
EXPECT_EQ(get_shape_cache_size(), 4);
// Don't lose cached objects
res = calculateSymbolicShapesOnOp(mul_scalar, {ss1, const_int});
assertShapeEqual(res, {sym_dim, 64});
EXPECT_EQ(get_shape_cache_size(), 4);
res = calculateSymbolicShapesOnOp(matmul, {ss1, ss1});
// SSA can infer that sym_dim is 64 as both tensors
// use the same sym_dim
assertShapeEqual(res, {64, 64});
EXPECT_EQ(get_shape_cache_size(), 5);
}
// Checks an operator with two tensor outputs (aten::max.dim): both returned
// shapes must be equal, on the first computation and again on a second query
// that uses a separately constructed but identically built input shape, and
// the two queries together must leave a single cache entry.
TEST(ShapeAnalysisTest, TestShapeMultipleReturns) {
// Start from an empty cache so the size check at the end is absolute.
clear_shape_cache();
auto max_dim_op = getSchema(
"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)");
c10::IValue const_int = 1;
c10::IValue false_ival = false;
c10::optional<int64_t> sym_dim = c10::nullopt;
c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64});
c10::SymbolicShape ss2 = c10::SymbolicShape({sym_dim, 64});
auto res =
calculateSymbolicShapesOnOp(max_dim_op, {ss1, const_int, false_ival});
c10::SymbolicShape expected_res = c10::SymbolicShape({sym_dim});
assertShapeEqual(res->at(0), expected_res);
// res0 and res1 should share the same symbolic symbol
EXPECT_EQ(res->at(0), res->at(1));
// Also test that the shape cache also returns consistent result shapes
res = calculateSymbolicShapesOnOp(max_dim_op, {ss2, const_int, false_ival});
assertShapeEqual(res->at(0), expected_res);
EXPECT_EQ(res->at(0), res->at(1));
EXPECT_EQ(get_shape_cache_size(), 1);
}
} // namespace jit
} // namespace torch

View File

@ -306,6 +306,7 @@ core_sources_full_mobile_no_backend_interface = [
"torch/csrc/jit/passes/integer_value_refinement.cpp",
"torch/csrc/jit/passes/replacement_of_old_operators.cpp",
"torch/csrc/jit/passes/symbolic_shape_analysis.cpp",
"torch/csrc/jit/passes/symbolic_shape_cache.cpp",
"torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp",
"torch/csrc/jit/passes/specialize_autogradzero.cpp",
"torch/csrc/jit/passes/update_differentiable_graph_requires_grad.cpp",

View File

@ -19,6 +19,7 @@
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_cache.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/runtime/exception_message.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
@ -174,6 +175,17 @@ bool symbolicShapeAnalysisTestModeEnabled() {
return symbolic_shape_analysis_test_mode;
}
using SSArgument = c10::variant<ShapeArguments, IValue>;
std::ostream& operator<<(std::ostream& out, const SSArgument& sa) {
if (const IValue* iv = c10::get_if<IValue>(&sa)) {
out << *iv;
} else {
out << c10::get<ShapeArguments>(sa);
}
return out;
}
namespace {
bool isListOfInts(const TypePtr& type) {
@ -244,8 +256,6 @@ c10::SymbolicShape extractListShape(
return c10::SymbolicShape(output_shape);
}
} // namespace
// Symbolic Shape Analysis works through iteratively partially evaluating
// a TorchScript shape compute graph by inputting properties from input
// Tensors. We can substitute in properties like `len(x)` and `x[1]`
@ -260,17 +270,6 @@ c10::SymbolicShape extractListShape(
// means that we do not know its concrete value statically but we can assign sets
// of tensor dimensions which must be equal at runtime.
using SSArgument = c10::variant<ShapeArguments, IValue>;
std::ostream& operator<<(std::ostream& out, const SSArgument& sa) {
if (const IValue* iv = c10::get_if<IValue>(&sa)) {
out << *iv;
} else {
out << c10::get<ShapeArguments>(sa);
}
return out;
}
struct SymbolicShapeOpAnalyzer {
std::shared_ptr<Graph> shape_compute_graph_;
const FunctionSchema* schema_;
@ -1058,6 +1057,7 @@ void PropagateShapesOnBlock(Block* b, const AliasDb& db) {
}
}
}
} // namespace
void PropagateShapesOnGraph(std::shared_ptr<Graph>& graph) {
AliasDb db(graph);
@ -1076,6 +1076,16 @@ TORCH_API c10::optional<std::vector<c10::SymbolicShape>>
calculateSymbolicShapesOnOp(
const FunctionSchema* schema,
const std::vector<SSAInput>& inputs) {
if (shapeComputeGraphForSchema(*schema) == c10::nullopt) {
// Avoid doing all this work for functions that don't have a
// supported schema
return c10::nullopt;
}
if (auto cached_ret_vec = get_cached_shape_function(schema, inputs)) {
return cached_ret_vec;
}
std::vector<SSArgument> ssa_args;
for (auto& arg : inputs) {
if (const IValue* ival = c10::get_if<IValue>(&arg)) {
@ -1087,7 +1097,11 @@ calculateSymbolicShapesOnOp(
}
auto op_analyzer = SymbolicShapeOpAnalyzer(schema);
return op_analyzer.run(ssa_args);
auto res = op_analyzer.run(ssa_args);
if (res.has_value()) {
cache_shape_function(schema, inputs, res.value());
}
return res;
}
} // namespace jit

View File

@ -53,66 +53,5 @@ TORCH_API c10::optional<std::vector<c10::SymbolicShape>>
calculateSymbolicShapesOnOp(
const FunctionSchema* schema,
const std::vector<SSAInput>& inputs);
// Canonical form of a c10::SymbolicShape.
//
// Static dimensions keep their size. Each distinct symbolic dimension is
// replaced by a negative index (-1, -2, ...) assigned in order of first
// appearance and recorded in a symbol map, so shapes that differ only in
// which symbols they use canonicalize to the same values.
struct TORCH_API CanonicalizedSymbolicShape {
  // Canonicalizes `orig_shape` against an existing symbol map. `ss_map` maps
  // original symbol value -> canonical index and is extended with any symbol
  // seen here for the first time.
  CanonicalizedSymbolicShape(
      c10::SymbolicShape& orig_shape,
      std::unordered_map<int64_t, int64_t>& ss_map) {
    init(orig_shape, ss_map);
  }

  // Canonicalizes `orig_shape` on its own, starting from an empty symbol map.
  CanonicalizedSymbolicShape(c10::SymbolicShape& orig_shape) {
    std::unordered_map<int64_t, int64_t> new_ssmap;
    init(orig_shape, new_ssmap);
  }

 private:
  // nullopt when the original shape has no sizes; otherwise one entry per
  // dimension (static size, or negative canonical index).
  c10::optional<std::vector<int64_t>> values_;
  // Parallel to `values_`: true where the dimension is symbolic.
  std::vector<bool> is_symbolic_;

  void init(
      c10::SymbolicShape& orig_shape,
      std::unordered_map<int64_t, int64_t>& ss_map) {
    auto sizes = orig_shape.sizes();
    if (!sizes) {
      values_ = c10::nullopt;
      return;
    }
    values_ = std::vector<int64_t>();
    // Next unused canonical index: continues after the entries already in
    // the map so indices stay unique across shapes sharing one map.
    int64_t cur_symbolic_index = -static_cast<int64_t>(ss_map.size()) - 1;
    for (auto& cur_shape : *sizes) {
      if (cur_shape.is_static()) {
        is_symbolic_.emplace_back(false);
        values_->push_back(cur_shape.static_size());
      } else {
        // Check for aliasing
        is_symbolic_.emplace_back(true);
        auto it = ss_map.find(cur_shape.value());
        if (it == ss_map.end()) {
          values_->push_back(cur_symbolic_index);
          ss_map.insert({cur_shape.value(), cur_symbolic_index});
          cur_symbolic_index--;
        } else {
          values_->push_back(it->second);
        }
      }
    }
  }

  // Equal when both are unranked, or when both have identical canonical
  // values and identical symbolic/static markers.
  friend bool operator==(
      const CanonicalizedSymbolicShape& a,
      const CanonicalizedSymbolicShape& b) {
    if (a.values_.has_value() != b.values_.has_value()) {
      return false;
    }
    if (!a.values_.has_value()) {
      return true;
    }
    return (
        a.values_.value() == b.values_.value() &&
        a.is_symbolic_ == b.is_symbolic_);
  }
};
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,208 @@
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_cache.h>
#include <torch/csrc/lazy/core/cache.h>
// SHAPE CACHING CODE
namespace torch {
namespace jit {
namespace {
// One canonicalized operator argument: either a canonicalized shape or a
// plain IValue.
using CanonicalArg = c10::variant<CanonicalizedSymbolicShape, IValue>;
// Canonicalized argument list of one operator call.
using CanonicalArgVec = std::vector<CanonicalArg>;
// Canonicalized output shapes of one operator call.
using CanonicalRet = std::vector<CanonicalizedSymbolicShape>;
// Cache key: the operator's name together with its canonicalized arguments.
using ShapeCacheKey = std::tuple<c10::OperatorName, CanonicalArgVec>;
// Converts an operator's inputs into their canonical form.
//
// IValue inputs are carried over (deep-copied when `deep_copy` is set,
// otherwise copied as-is). Symbolic-shape inputs are canonicalized against
// `ss_map`, which is extended with any symbol seen for the first time.
CanonicalArgVec cannonicalizeVec(
    const std::vector<SSAInput>& arg_vec,
    std::unordered_map<int64_t, int64_t>& ss_map,
    bool deep_copy = true) {
  CanonicalArgVec result;
  result.reserve(arg_vec.size());
  for (const auto& input : arg_vec) {
    const IValue* as_ivalue = c10::get_if<IValue>(&input);
    if (as_ivalue == nullptr) {
      const auto& shape = c10::get<at::SymbolicShape>(input);
      result.emplace_back(CanonicalizedSymbolicShape(shape, ss_map));
      continue;
    }
    if (deep_copy) {
      result.push_back(as_ivalue->deepcopy());
    } else {
      result.push_back(*as_ivalue);
    }
  }
  return result;
}
// Canonicalizes a list of output shapes against `ss_map`, in order. The map
// is shared across all shapes and is extended with any symbol that is not
// already present.
std::vector<CanonicalizedSymbolicShape> cannonicalizeVec(
    const std::vector<at::SymbolicShape>& ret_vec,
    std::unordered_map<int64_t, int64_t>& ss_map) {
  std::vector<CanonicalizedSymbolicShape> result;
  result.reserve(ret_vec.size());
  for (const auto& shape : ret_vec) {
    result.emplace_back(shape, ss_map);
  }
  return result;
}
// Hash functor for ShapeCacheKey: combines the hash of the operator name,
// the argument count, and a per-argument hash.
struct ArgumentsHasher {
  size_t operator()(const ShapeCacheKey& cacheKey) const {
    // TODO: ignore arguments that are not used in shape function (not needed
    // initially)
    const auto& op_name = std::get<0>(cacheKey);
    const auto& arg_vec = std::get<1>(cacheKey);

    size_t hash_val = c10::hash<c10::OperatorName>()(op_name);

    hash_val = at::hash_combine(std::hash<size_t>{}(arg_vec.size()), hash_val);
    for (const CanonicalArg& arg : arg_vec) {
      size_t cur_arg = 0;
      if (const IValue* ival = c10::get_if<IValue>(&arg)) {
        // IValue doesn't hash List (as Python doesn't), so we will do a custom
        // list hash
        if (ival->isList()) {
          TORCH_INTERNAL_ASSERT(ival->isIntList(), "Unexpected Args in List");
          cur_arg = ival->toListRef().size();
          for (const IValue& elem_ival : ival->toListRef()) {
            cur_arg = at::hash_combine(cur_arg, IValue::hash(elem_ival));
          }
        } else {
          // Hash the pointed-to value; `ival` itself is only a pointer, and
          // hashing it would not reflect the argument's contents.
          cur_arg = IValue::hash(*ival);
        }
      } else {
        cur_arg = c10::get<CanonicalizedSymbolicShape>(arg).hash();
      }
      hash_val = at::hash_combine(hash_val, cur_arg);
    }
    return hash_val;
  }
};
// Cache type: maps a ShapeCacheKey to the canonicalized output shapes of that
// call, hashed with ArgumentsHasher.
using ShapeCache = lazy::Cache<
ShapeCacheKey,
std::vector<CanonicalizedSymbolicShape>,
ArgumentsHasher>;
// Value passed to the cache constructor below.
// NOTE(review): presumably the maximum number of entries — confirm against
// lazy::Cache.
constexpr size_t kShapeCacheSize = 1024;
// The single cache instance used by every function in this file.
ShapeCache shapeCache(kShapeCacheSize);
// Builds the cache key for calling `schema` with `arg_vec`: the operator name
// paired with the canonicalized arguments.
//
// `ss_map` receives the symbol -> canonical-index assignments made while
// canonicalizing, so the caller can canonicalize or de-canonicalize the
// outputs consistently. `deep_copy` is forwarded to cannonicalizeVec and
// controls whether IValue arguments are deep-copied into the key.
ShapeCacheKey get_cache_key(
    const FunctionSchema* schema,
    const std::vector<SSAInput>& arg_vec,
    std::unordered_map<int64_t, int64_t>& ss_map,
    bool deep_copy = true) {
  // Hand the canonicalized vector to make_tuple as a temporary so it is moved
  // into the key rather than copied a second time from a named local.
  return std::make_tuple(
      schema->operator_name(), cannonicalizeVec(arg_vec, ss_map, deep_copy));
}
} // namespace
// Stores the output shapes `ret_vec` computed for calling `schema` with
// `arg_vec`. Arguments and outputs are canonicalized with one shared symbol
// map, arguments first, so a symbol that appears in both receives the same
// canonical index.
TORCH_API void cache_shape_function(
    const FunctionSchema* schema,
    const std::vector<SSAInput>& arg_vec,
    const std::vector<at::SymbolicShape>& ret_vec) {
  // TODO: compare perf using std::vector<std::tuple<int64_t, int64_t>>
  std::unordered_map<int64_t, int64_t> symbol_to_canonical;
  // The key owns deep copies of the IValue arguments.
  const auto key = get_cache_key(
      schema, arg_vec, symbol_to_canonical, /* deep_copy */ true);
  auto stored_rets =
      std::make_shared<std::vector<CanonicalizedSymbolicShape>>(
          cannonicalizeVec(ret_vec, symbol_to_canonical));
  shapeCache.Add(key, stored_rets);
}
// Looks up previously cached output shapes for calling `schema` with
// `arg_vec`. Returns c10::nullopt when nothing is cached for the
// canonicalized arguments; otherwise returns the cached shapes with their
// canonical symbol indices mapped back to the caller's symbols.
TORCH_API c10::optional<std::vector<at::SymbolicShape>>
get_cached_shape_function(
    const FunctionSchema* schema,
    const std::vector<SSAInput>& arg_vec) {
  // TODO: compare perf using std::vector<std::tuple<int64_t, int64_t>> for both
  // ss_map and inverse_ss_map
  std::unordered_map<int64_t, int64_t> symbol_to_canonical;
  // The key is only used for this lookup, so IValues are not deep-copied.
  const auto key = get_cache_key(
      schema, arg_vec, symbol_to_canonical, /* deep_copy */ false);
  const auto hit = shapeCache.Get(key);
  if (!hit) {
    return c10::nullopt;
  }

  // Decanonicalize the return values: invert the map built from the inputs
  // so canonical indices translate back to this call's symbol values.
  std::unordered_map<int64_t, int64_t> canonical_to_symbol;
  for (const auto& entry : symbol_to_canonical) {
    canonical_to_symbol[entry.second] = entry.first;
  }

  std::vector<at::SymbolicShape> shapes;
  for (const auto& canonical_shape : *hit) {
    shapes.emplace_back(canonical_shape.toSymbolicShape(canonical_to_symbol));
  }
  return shapes;
}
// Function only to access the cache, used for testing
// Clears the file-level shape cache by calling its Clear() method.
TORCH_API void clear_shape_cache() {
shapeCache.Clear();
}
// Returns the value reported by the shape cache's Numel(); the tests use it
// as the number of cached entries.
TORCH_API size_t get_shape_cache_size() {
return shapeCache.Numel();
}
// Fills `values_` from `orig_shape`.
//
// A shape without sizes leaves `values_` empty (nullopt). Otherwise each
// static dimension is stored as its size, and each symbolic dimension as a
// negative canonical index looked up in, or newly added to, `ss_map`
// (original symbol value -> canonical index).
void CanonicalizedSymbolicShape::init(
    const c10::SymbolicShape& orig_shape,
    std::unordered_map<int64_t, int64_t>& ss_map) {
  const auto dims = orig_shape.sizes();
  if (!dims.has_value()) {
    values_ = c10::nullopt;
    return;
  }

  values_ = std::vector<int64_t>();
  // Next unused canonical index; continues after whatever is already mapped.
  int64_t next_index = -static_cast<int64_t>(ss_map.size()) - 1;
  for (const auto& dim : *dims) {
    if (!dim.is_static()) {
      // Check for aliasing: a symbol seen before reuses its index.
      const auto known = ss_map.find(dim.value());
      if (known != ss_map.end()) {
        values_->push_back(known->second);
        continue;
      }
      values_->push_back(next_index);
      ss_map.insert({dim.value(), next_index});
      --next_index;
      continue;
    }
    values_->push_back(dim.static_size());
  }
}
// Converts this canonical shape back into a c10::SymbolicShape.
//
// Non-negative values become static sizes. A negative canonical index is
// translated through `inverse_ss_map` (canonical index -> symbol value); an
// index missing from the map gets a freshly created symbol, which is then
// recorded in the map so later occurrences of the same index reuse it.
// Returns a default-constructed SymbolicShape when no values are stored.
c10::SymbolicShape CanonicalizedSymbolicShape::toSymbolicShape(
    std::unordered_map<int64_t, int64_t>& inverse_ss_map) const {
  if (!values_.has_value()) {
    return c10::SymbolicShape();
  }
  std::vector<at::ShapeSymbol> sizes;
  sizes.reserve(values_->size());
  // `values_` holds int64_t, so iterate with the same type.
  for (const int64_t cur_val : *values_) {
    if (cur_val >= 0) {
      sizes.push_back(at::ShapeSymbol::fromStaticSize(cur_val));
      continue;
    }
    auto res = inverse_ss_map.find(cur_val);
    if (res != inverse_ss_map.end()) {
      // NOTE(review): res->second is a symbol value taken from the map, not
      // necessarily a static size; fromStaticSize is used here to rebuild a
      // ShapeSymbol carrying that value — confirm this is the intended
      // factory.
      sizes.push_back(at::ShapeSymbol::fromStaticSize(res->second));
    } else {
      auto new_symbol = at::ShapeSymbol::newSymbol();
      inverse_ss_map.insert({cur_val, new_symbol.value()});
      sizes.push_back(new_symbol);
    }
  }
  return c10::SymbolicShape(std::move(sizes));
}
// Hash of the canonical values; a shape with no stored values hashes to a
// fixed constant.
size_t CanonicalizedSymbolicShape::hash() const {
  if (values_.has_value()) {
    return c10::hash<std::vector<int64_t>>()(*values_);
  }
  return 0x8cc80c80; // random value to prevent hash collisions
}
// Two canonical shapes are equal when their optional value lists are equal
// (both empty, or both holding identical values).
bool operator==(
    const CanonicalizedSymbolicShape& a,
    const CanonicalizedSymbolicShape& b) {
  return a.values_ == b.values_;
}
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,57 @@
#pragma once
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
namespace torch {
namespace jit {
struct TORCH_API CanonicalizedSymbolicShape {
// TODO: Consider in the future if it is reasonable to
// merge code with SymbolicShape or VaryingShape while keeping
// the two not implicitly convertable (and cause bugs).
CanonicalizedSymbolicShape(
const c10::SymbolicShape& orig_shape,
std::unordered_map<int64_t, int64_t>& ss_map) {
init(orig_shape, ss_map);
}
CanonicalizedSymbolicShape(c10::SymbolicShape& orig_shape) {
std::unordered_map<int64_t, int64_t> new_ssmap;
init(orig_shape, new_ssmap);
}
size_t hash() const;
c10::SymbolicShape toSymbolicShape(
std::unordered_map<int64_t, int64_t>& inverse_ss_map) const;
TORCH_API friend bool operator==(
const CanonicalizedSymbolicShape& a,
const CanonicalizedSymbolicShape& b);
private:
c10::optional<std::vector<int64_t>> values_;
void init(
const c10::SymbolicShape& orig_shape,
std::unordered_map<int64_t, int64_t>& ss_map);
};
// SHAPE CACHE API

// Returns the cached output shapes for calling `schema` with `arg_vec`, or
// c10::nullopt when nothing is cached for those arguments.
TORCH_API c10::optional<std::vector<at::SymbolicShape>>
get_cached_shape_function(
const FunctionSchema* schema,
const std::vector<SSAInput>& arg_vec);
// Records `ret_vec` as the output shapes for calling `schema` with `arg_vec`.
TORCH_API void cache_shape_function(
const FunctionSchema* schema,
const std::vector<SSAInput>& arg_vec,
const std::vector<at::SymbolicShape>& ret_vec);
// For use in test code
// clear_shape_cache empties the cache; get_shape_cache_size reports how many
// entries it holds.
TORCH_API void clear_shape_cache();
TORCH_API size_t get_shape_cache_size();
} // namespace jit
} // namespace torch