#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/csrc/utils/schema_info.h>

namespace torch::utils {

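// A minimal usage sketch (hypothetical call site, not part of this file):
// construct a SchemaInfo from a parsed schema, bind runtime values, then
// query mutability/aliasing. Names like self_tensor are illustrative.
//
//   SchemaInfo info(torch::jit::parseSchema(
//       "aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)"));
//   info.addArgumentValue("self", at::IValue(self_tensor));
//   bool mutates_self = info.is_mutable("self"); // true: self is Tensor(a!)
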
void SchemaInfo::addArgumentValue(
    const std::string& name,
    const at::IValue& value) {
  std::optional<int> index = schema_.argumentIndexWithName(name);
  TORCH_INTERNAL_ASSERT(
      index != std::nullopt, "Schema has no argument named ", name);
  value_map_[name] = value;
  alias_maps_current_ = false;
}

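// Positional variant: value_list[i], when present, is bound to the i-th
// schema argument; empty optionals are skipped.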
void SchemaInfo::addArgumentValues(
    const std::vector<std::optional<at::IValue>>& value_list) {
  TORCH_INTERNAL_ASSERT(
      value_list.size() <= schema_.arguments().size(),
      "Schema does not have enough arguments for value list");

  for (size_t i = 0; i < value_list.size(); i++) {
    if (value_list[i].has_value()) {
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      value_map_[schema_.arguments()[i].name()] = *value_list[i];
      alias_maps_current_ = false;
    }
  }
}

void SchemaInfo::addArgumentValues(
    const std::unordered_map<std::string, at::IValue>& values) {
  for (const auto& key_pair : values) {
    addArgumentValue(key_pair.first, key_pair.second);
  }
}

bool SchemaInfo::hasInputArgumentNamed(const std::string& name) const {
  return std::any_of(
      schema_.arguments().begin(),
      schema_.arguments().end(),
      [&name](const c10::Argument& arg) { return arg.name() == name; });
}

bool SchemaInfo::is_mutable() {
  for (size_t i = 0; i < schema_.arguments().size(); i++) {
    if (is_mutable({c10::SchemaArgType::input, i})) {
      return true;
    }
  }
  return false;
}

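// An argument is treated as mutable if any argument it may alias is
// mutable. For the ops returned by getTrainingOps(), the listed arguments
// (e.g. running_mean/running_var) only count as mutable when the relevant
// training flag is true or unbound.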
bool SchemaInfo::is_mutable(const c10::SchemaArgument& argument) {
  TORCH_INTERNAL_ASSERT(
      argument.index < schema_.getCorrectList(argument.type).size(),
      "Invalid index for schema.");
  if (!alias_maps_current_) {
    generateAliasMaps();
  }
  static const std::vector<SchemaSpecialCasePair> training_ops =
      getTrainingOps();
  const auto& correct_map = (argument.type == c10::SchemaArgType::input)
      ? input_alias_map_
      : output_alias_map_;
  // Note that the training_op checks depend on the index because in some
  // cases running_mean or running_var aliases another input argument,
  // which changes its alias status.
  return std::any_of(
      correct_map[argument.index].begin(),
      correct_map[argument.index].end(),
      [this](size_t aliasing_index) {
        const auto is_training_op = std::find_if(
            training_ops.begin(),
            training_ops.end(),
            [this](const auto& training_op) {
              return this->schema_ == training_op.first;
            });

        bool special_case = (is_training_op != training_ops.end()) &&
            is_training_op->second.count(
                this->schema_.arguments()[aliasing_index].name());
        if (special_case) {
          bool has_training = (hasInputArgumentNamed("training") &&
                               !value_map_.count("training")) ||
              (value_map_.count("training") &&
               value_map_.at("training").toBool());
          bool has_train =
              (hasInputArgumentNamed("train") && !value_map_.count("train")) ||
              (value_map_.count("train") && value_map_.at("train").toBool());
          bool has_use_input_stats =
              (hasInputArgumentNamed("use_input_stats") &&
               !value_map_.count("use_input_stats")) ||
              (value_map_.count("use_input_stats") &&
               value_map_.at("use_input_stats").toBool());
          return has_training || has_train || has_use_input_stats;
        } else {
          return this->schema_.is_mutable(
              {c10::SchemaArgType::input, aliasing_index});
        }
      });
}

bool SchemaInfo::has_argument(std::string_view name) {
  return schema_.argumentIndexWithName(name) != std::nullopt;
}

bool SchemaInfo::is_mutable(std::string_view name) {
  std::optional<int> index = schema_.argumentIndexWithName(name);
  TORCH_INTERNAL_ASSERT(
      index.has_value(), "Schema has no argument named ", name);

  return is_mutable({c10::SchemaArgType::input, static_cast<size_t>(*index)});
}

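// aten::dropout with train=False is deterministic. Otherwise, mobile
// builds compare against the hard-coded list from getNonDeterministicOps(),
// while other builds ask the dispatcher for the nondeterministic_seeded tag.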
bool SchemaInfo::is_nondeterministic() const {
  static const c10::FunctionSchema dropout_schema = torch::jit::parseSchema(
      "aten::dropout(Tensor input, float p, bool train) -> Tensor");
  if (dropout_schema == schema_ && value_map_.count("train") &&
      !value_map_.at("train").toBool()) {
    return false;
  }

#if defined C10_MOBILE
  static const std::vector<c10::FunctionSchema> nondeterministic_ops =
      getNonDeterministicOps();
  return std::any_of(
      nondeterministic_ops.begin(),
      nondeterministic_ops.end(),
      [this](const c10::FunctionSchema& nondeterministic_op) {
        return nondeterministic_op == this->schema_;
      });
#else
  const auto& op = c10::Dispatcher::singleton().findOp(
      c10::OperatorName(schema_.name(), schema_.overload_name()));
  return op && op->hasTag(at::Tag::nondeterministic_seeded);
#endif
}

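// Aliasing is reported if the schema itself says the arguments may alias,
// if both arguments are wildcards of compatible types, or if the IValues
// bound via addArgumentValue(s) are observed to alias at runtime.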
bool SchemaInfo::may_alias(
    const c10::SchemaArgument& lhs,
    const c10::SchemaArgument& rhs) {
  bool basic_check = schema_.may_alias(lhs, rhs);
  if (basic_check) {
    return true;
  }
  std::optional<c10::AliasTypeSet> lhsAliasTypeSet =
      schema_.mapTypeToAliasTypeSet(
          schema_.getCorrectList(lhs.type)[lhs.index].type());
  std::optional<c10::AliasTypeSet> rhsAliasTypeSet =
      schema_.mapTypeToAliasTypeSet(
          schema_.getCorrectList(rhs.type)[rhs.index].type());
  bool types_can_alias =
      schema_.canAliasTypeSetsAlias(lhsAliasTypeSet, rhsAliasTypeSet);
  if (!types_can_alias) {
    return false;
  }

  if (!alias_maps_current_) {
    generateAliasMaps();
  }
  bool wildcard_alias_check =
      wildcardSet().count(lhs) && wildcardSet().count(rhs);
  if (wildcard_alias_check) {
    return true;
  }

  if (lhs.type == c10::SchemaArgType::input &&
      rhs.type == c10::SchemaArgType::input) {
    return input_alias_map_[lhs.index].count(rhs.index);
  } else if (
      lhs.type == c10::SchemaArgType::output &&
      rhs.type == c10::SchemaArgType::output) {
    for (size_t lhs_alias_input : output_alias_map_[lhs.index]) {
      if (output_alias_map_[rhs.index].count(lhs_alias_input)) {
        return true;
      }
    }
    return false;
  } else if (lhs.type == c10::SchemaArgType::output) {
    return output_alias_map_[lhs.index].count(rhs.index);
  } else {
    return output_alias_map_[rhs.index].count(lhs.index);
  }
}

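// Containment-based aliasing: lhs may contain rhs when lhs is a container
// whose contained types can alias rhs's type and rhs is a wildcard.
// With bidirectional=true, the roles are also checked swapped.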
bool SchemaInfo::may_contain_alias(
    const c10::SchemaArgument& lhs,
    const c10::SchemaArgument& rhs,
    bool bidirectional) {
  bool basic_check = schema_.may_contain_alias(lhs, rhs) || may_alias(lhs, rhs);
  if (basic_check) {
    return true;
  }
  if (!alias_maps_current_) {
    generateAliasMaps();
  }
  if (bidirectional) {
    return mayContainAliasImpl(lhs, rhs) || mayContainAliasImpl(rhs, lhs);
  } else {
    return mayContainAliasImpl(lhs, rhs);
  }
}

bool SchemaInfo::mayContainAliasImpl(
    const c10::SchemaArgument& lhs,
    const c10::SchemaArgument& rhs) {
  std::optional<c10::AliasTypeSet> lhsContainedAliasTypeSet =
      schema_.getAliasTypeSetContainedTypes(schema_.mapTypeToAliasTypeSet(
          schema_.getCorrectList(lhs.type)[lhs.index].type()));
  std::optional<c10::AliasTypeSet> rhsAliasTypeSet =
      schema_.mapTypeToAliasTypeSet(
          schema_.getCorrectList(rhs.type)[rhs.index].type());
  bool types_can_alias =
      schema_.canAliasTypeSetsAlias(lhsContainedAliasTypeSet, rhsAliasTypeSet);
  return types_can_alias && containerSet().count(lhs) &&
      wildcardSet().count(rhs);
}

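// Any argument tagged with an alias set that appears more than once in an
// argument list is promoted to a wildcard so aliasing queries stay
// conservative.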
void SchemaInfo::ensureConservativity(
    const std::unordered_set<at::Symbol>& duplicates,
    const std::vector<c10::Argument>& arguments_list,
    c10::SchemaArgType type) {
  for (size_t i = 0; i < arguments_list.size(); i++) {
    if (arguments_list[i].alias_info()) {
      for (const auto& set : arguments_list[i].alias_info()->afterSets()) {
        if (duplicates.count(set)) {
          wildcard_set_.insert({type, i});
        }
      }
    }
  }
}

std::vector<c10::FunctionSchema> SchemaInfo::getNonDeterministicOps() {
  // This list of nondeterministic ops is copied from JIT ir.cpp.
  static const std::vector<std::string> nondeterministic_op_strings = {
      "aten::dropout(Tensor input, float p, bool train) -> Tensor",
      "aten::_fused_dropout(Tensor self, float p, Generator? generator) -> (Tensor, Tensor)",
      "aten::_standard_gamma(Tensor self, Generator? generator) -> Tensor",
      "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
      "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
      "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator? generator) -> Tensor",
      "aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)",
      "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
      "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
      "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
      "aten::poisson(Tensor self, Generator? generator) -> Tensor",
      "aten::binomial(Tensor count, Tensor prob, Generator? generator=None) -> Tensor",
      "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
      "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
      "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randperm(int n, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor"};

  std::vector<c10::FunctionSchema> nondeterministic_ops;
  nondeterministic_ops.reserve(nondeterministic_op_strings.size());
  for (const std::string& signature : nondeterministic_op_strings) {
    nondeterministic_ops.emplace_back(torch::jit::parseSchema(signature));
  }

  return nondeterministic_ops;
}

std::vector<SchemaSpecialCasePair> SchemaInfo::getTrainingOps() {
  // This is a list of pairs mapping ops to unordered sets of argument
  // names, where a boolean flag (either "training", "train", or
  // "use_input_stats") determines the mutability of the named arguments.
  static const std::vector<
      std::pair<std::string, std::unordered_set<std::string>>>
      training_op_pairs = {
          {"aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
           {"running_mean", "running_var"}},
          {"aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor",
           {"running_mean", "running_var"}},
          {"aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)",
           {"running_mean", "running_var"}},
          {"aten::cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor)",
           {"running_mean", "running_var"}},
          {"aten::miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor)",
           {"running_mean", "running_var"}},
          {"aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
           {"running_mean", "running_var"}},
          {"aten::native_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!))",
           {"running_mean", "running_var"}},
          {"aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor",
           {"noise"}},
          {"aten::rrelu_with_noise.out(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)",
           {"noise"}},
          {"aten::rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)",
           {"noise"}}};

  std::vector<SchemaSpecialCasePair> training_ops;
  training_ops.reserve(training_op_pairs.size());
  for (const auto& signature : training_op_pairs) {
    training_ops.emplace_back(
        torch::jit::parseSchema(signature.first), signature.second);
  }

  return training_ops;
}

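// Lazy one-time classification of arguments and returns: records wildcard
// arguments, flags duplicated alias sets (handled conservatively via
// ensureConservativity), and records container arguments whose contained
// types participate in aliasing.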
void SchemaInfo::initSchemaInfo() {
  if (has_init_) {
    return;
  }
  has_init_ = true;

  std::unordered_set<at::Symbol> duplicates;
  auto init_schema_arguments = [this, &duplicates](
                                   const std::vector<c10::Argument>&
                                       arguments_list,
                                   c10::SchemaArgType type) {
    std::unordered_set<at::Symbol> seen;
    for (size_t i = 0; i < arguments_list.size(); i++) {
      const c10::Argument& argument = arguments_list[i];
      if (argument.alias_info()) {
        if (argument.alias_info()->isWildcardAfter()) {
          wildcard_set_.insert({type, i});
        } else {
          // This check ensures that the FunctionSchema is accurately
          // represented when calling may_alias and may_contain_alias
          // on schemas where more than one argument in arguments_list
          // shares an alias set.
          for (const auto& set : argument.alias_info()->afterSets()) {
            if (seen.count(set)) {
              TORCH_WARN(
                  set.toQualString(),
                  " appears twice in same argument list which will make aliasing checks more conservative.");
              duplicates.insert(set);
            } else {
              seen.insert(set);
            }
          }
        }
      }
      std::optional<c10::AliasTypeSet> contained_types =
          schema_.getAliasTypeSetContainedTypes(
              schema_.mapTypeToAliasTypeSet(argument.type()));
      if (contained_types && !contained_types->empty()) {
        container_set_.insert({type, i});
      }
    }
  };

  init_schema_arguments(schema_.arguments(), c10::SchemaArgType::input);
  init_schema_arguments(schema_.returns(), c10::SchemaArgType::output);
  ensureConservativity(
      duplicates, schema_.arguments(), c10::SchemaArgType::input);
  ensureConservativity(
      duplicates, schema_.returns(), c10::SchemaArgType::output);
}

const std::unordered_set<c10::SchemaArgument>& SchemaInfo::wildcardSet() {
  initSchemaInfo();
  return wildcard_set_;
}

const std::unordered_set<c10::SchemaArgument>& SchemaInfo::containerSet() {
  initSchemaInfo();
  return container_set_;
}

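// Rebuilds the alias maps from the currently bound argument values:
// inputs alias when their IValues alias, containment promotes contained
// values to wildcards, and each output inherits the alias sets of any
// input the schema says it may alias.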
void SchemaInfo::generateAliasMaps() {
  initSchemaInfo();

  alias_maps_current_ = true;
  input_alias_map_ = std::vector<std::unordered_set<size_t>>(
      schema_.arguments().size(), std::unordered_set<size_t>());
  output_alias_map_ = std::vector<std::unordered_set<size_t>>(
      schema_.returns().size(), std::unordered_set<size_t>());

  // Fills input_alias_map_
  for (size_t i = 0; i < schema_.arguments().size(); i++) {
    for (size_t j = i; j < schema_.arguments().size(); j++) {
      if (i == j) {
        input_alias_map_[i].insert(i);
      } else if (
          value_map_.count(schema_.arguments()[i].name()) &&
          value_map_.count(schema_.arguments()[j].name())) {
        if (value_map_[schema_.arguments()[i].name()].isAliasOf(
                value_map_[schema_.arguments()[j].name()])) {
          input_alias_map_[i].insert(j);
          input_alias_map_[j].insert(i);
          if (wildcard_set_.count({c10::SchemaArgType::input, i})) {
            wildcard_set_.insert({c10::SchemaArgType::input, j});
          } else if (wildcard_set_.count({c10::SchemaArgType::input, j})) {
            wildcard_set_.insert({c10::SchemaArgType::input, i});
          }
        }
      }
    }
  }

  // Fills wildcard_set_ with container-created wildcards.
  // For instance, given the schema:
  //   test(Tensor a, Tensor(*) b, Tensor[] c) -> Tensor
  // if value(a) is contained in value(c), then a is added to the
  // wildcard set, where it can now alias b.
  for (size_t i = 0; i < schema_.arguments().size(); i++) {
    for (size_t j = 0; j < schema_.arguments().size(); j++) {
      // If they already alias, neither can contain the other.
      if (!input_alias_map_[i].count(j) &&
          value_map_.count(schema_.arguments()[i].name()) &&
          value_map_.count(schema_.arguments()[j].name())) {
        c10::IValue::HashAliasedIValues subValues;
        value_map_[schema_.arguments()[i].name()].getSubValues(subValues);
        if (subValues.count(value_map_[schema_.arguments()[j].name()])) {
          wildcard_set_.insert({c10::SchemaArgType::input, j});
        }
      }
    }
  }

  // Fills output_alias_map_
  for (size_t i = 0; i < schema_.arguments().size(); i++) {
    for (size_t j = 0; j < schema_.returns().size(); j++) {
      if (schema_.may_alias(
              {c10::SchemaArgType::input, i},
              {c10::SchemaArgType::output, j})) {
        if (wildcard_set_.count({c10::SchemaArgType::input, i})) {
          wildcard_set_.insert({c10::SchemaArgType::output, j});
        }
        output_alias_map_[j].insert(
            input_alias_map_[i].begin(), input_alias_map_[i].end());
      }
    }
  }
}

} // namespace torch::utils