- Created a `c10::AliasTypeSet` typedef of `vector<TypePtr>` to match alias_analysis.cpp formatting and improve readability.
- Moved `canAliasTypeSetsAlias`, `mapTypeToAliasTypeSet`, `getAliasTypeSetContainedTypes`, and `getCorrectList` to public in function_schema.h for use in the `SchemaInfo` class. **In the future it might be better to find a different home for most of these functions, since they don't depend on `FunctionSchema`.**
- Created a hash function for `SchemaArgument`.
- Added an assert to ensure that there is only one input and one output with each alias set (excluding wildcards).
- Fixed a double-wildcard-input edge case for `may_alias`: given a schema of the form `(Tensor(a) a, Tensor(*) b, Tensor(*) c) -> Tensor`, if the argument values for 'a' and 'b' cause them to alias, then 'a' may also alias 'c'.
- Added tests for the double-wildcard case in `may_alias`, mismatching types in `may_alias`, and the uniqueness internal assert.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81439
Approved by: https://github.com/davidberard98
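The "hash function for `SchemaArgument`" bullet is not visible in this test file. As a rough illustration only, a specialization along the following lines would let a `SchemaArgument`-like struct be used as an unordered-container key; the member names (`type`, `index`) and the stand-in types are assumptions inferred from the brace-initializations in the tests below (e.g. `{SchemaArgType::input, 0}`), not the actual upstream definition.

```cpp
// Hypothetical sketch only -- not the upstream c10 implementation.
#include <cstddef>
#include <functional>

namespace sketch {
enum class SchemaArgType { input, output };

// Stand-in for c10::SchemaArgument: an argument kind plus its position.
struct SchemaArgument {
  SchemaArgType type;
  size_t index;
  bool operator==(const SchemaArgument& other) const {
    return type == other.type && index == other.index;
  }
};
} // namespace sketch

namespace std {
template <>
struct hash<sketch::SchemaArgument> {
  size_t operator()(const sketch::SchemaArgument& arg) const {
    // Combine the kind and the index with a boost-style hash mixer.
    size_t h = std::hash<size_t>()(static_cast<size_t>(arg.type));
    return h ^
        (std::hash<size_t>()(arg.index) + 0x9e3779b9 + (h << 6) + (h >> 2));
  }
};
} // namespace std
```

Any standard hash-combine would do here; the point is only that both fields participate in the hash so distinct (kind, index) pairs map to distinct buckets with high probability.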
#include <gtest/gtest.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/utils/schema_info.h>

namespace torch {
namespace utils {
using c10::SchemaArgType;

TEST(SchemaInfoHasSideEffectsTest, Basic) {
  SchemaInfo no_side_effects_schema_info(
      "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
  SchemaInfo side_effects_schema_info(
      "aten::warn(str message, int stacklevel=2) -> ()");
  ASSERT_TRUE(side_effects_schema_info.has_side_effects());
  ASSERT_FALSE(no_side_effects_schema_info.has_side_effects());
}

TEST(FunctionSchemaIsMutableTest, Basic) {
  c10::FunctionSchema schema = torch::jit::parseSchema(
      "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
  ASSERT_TRUE(schema.is_mutable(0));
  ASSERT_TRUE(schema.is_mutable("self"));
  ASSERT_FALSE(schema.is_mutable(1));
  ASSERT_FALSE(schema.is_mutable("other"));
  ASSERT_FALSE(schema.is_mutable(2));
  ASSERT_FALSE(schema.is_mutable("alpha"));
}

TEST(FunctionSchemaIsMutableTest, InvalidArgument) {
  c10::FunctionSchema schema = torch::jit::parseSchema(
      "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
  ASSERT_THROW(schema.is_mutable(4), c10::Error);
  ASSERT_THROW(schema.is_mutable("named_argument"), c10::Error);
}

TEST(SchemaInfoIsMutableTest, Basic) {
  SchemaInfo schema(
      "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
  ASSERT_TRUE(schema.is_mutable(0));
  ASSERT_TRUE(schema.is_mutable("self"));
  ASSERT_FALSE(schema.is_mutable(1));
  ASSERT_FALSE(schema.is_mutable("other"));
  ASSERT_FALSE(schema.is_mutable(2));
  ASSERT_FALSE(schema.is_mutable("alpha"));
}

TEST(SchemaInfoIsMutableTest, InvalidArgument) {
  SchemaInfo schema(
      "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
  ASSERT_THROW(schema.is_mutable(4), c10::Error);
  ASSERT_THROW(schema.is_mutable("named_argument"), c10::Error);
}

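// Once concrete argument values show that 'self' and 'other' alias, the
// otherwise-immutable 'other' must also be reported as mutable.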
TEST(SchemaInfoIsMutableTest, AliasingInputs) {
  SchemaInfo schema(
      "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
  ASSERT_TRUE(schema.is_mutable(0));
  ASSERT_TRUE(schema.is_mutable("self"));
  ASSERT_FALSE(schema.is_mutable(1));
  ASSERT_FALSE(schema.is_mutable("other"));
  at::Tensor input = at::randn({3, 3});
  schema.addArgumentValue("self", input);
  schema.addArgumentValue("other", input);
  ASSERT_TRUE(schema.is_mutable(1));
  ASSERT_TRUE(schema.is_mutable("other"));
}

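// instance_norm only writes to the running stats when use_input_stats is
// true, so mutability of running_mean/running_var depends on that value.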
TEST(SchemaInfoIsMutableTest, InstanceNorm) {
  SchemaInfo schema_info(
      "aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor");
  ASSERT_FALSE(schema_info.is_mutable("running_mean"));
  ASSERT_FALSE(schema_info.is_mutable("running_var"));
  schema_info.addArgumentValue("use_input_stats", true);
  ASSERT_TRUE(schema_info.is_mutable("running_mean"));
  ASSERT_TRUE(schema_info.is_mutable("running_var"));
}

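// Same idea as InstanceNorm: running_mean/running_var become mutable only
// once 'training' is known to be true.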
TEST(SchemaInfoIsMutableTest, BatchNorm) {
  SchemaInfo schema_info(
      "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor");
  ASSERT_FALSE(schema_info.is_mutable("running_mean"));
  ASSERT_FALSE(schema_info.is_mutable("running_var"));
  schema_info.addArgumentValue("training", true);
  ASSERT_TRUE(schema_info.is_mutable("running_mean"));
  ASSERT_TRUE(schema_info.is_mutable("running_var"));
}

TEST(SchemaInfoIsNonDeterministicTest, Basic) {
  SchemaInfo deterministic_schema_info(
      "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
  SchemaInfo nondeterministic_schema_info(
      "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor");
  ASSERT_FALSE(deterministic_schema_info.is_nondeterministic());
  ASSERT_TRUE(nondeterministic_schema_info.is_nondeterministic());
}

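// dropout is nondeterministic in general, but becomes deterministic once
// 'train' is known to be false.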
TEST(SchemaInfoIsNonDeterministicTest, Dropout) {
  SchemaInfo dropout_schema_info(
      "aten::dropout(Tensor input, float p, bool train) -> Tensor");
  ASSERT_TRUE(dropout_schema_info.is_nondeterministic());
  dropout_schema_info.addArgumentValue("train", false);
  ASSERT_FALSE(dropout_schema_info.is_nondeterministic());
}

TEST(FunctionSchemaMayAliasTest, Basic) {
  c10::FunctionSchema schema = torch::jit::parseSchema(
      "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
  ASSERT_TRUE(
      schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
  ASSERT_FALSE(
      schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0}));
  ASSERT_FALSE(
      schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::input, 0}));
}

TEST(FunctionSchemaMayAliasTest, InvalidArgument) {
  c10::FunctionSchema schema = torch::jit::parseSchema(
      "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
  ASSERT_THROW(
      schema.may_alias({SchemaArgType::input, 15}, {SchemaArgType::output, 0}),
      c10::Error);
  ASSERT_THROW(
      schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 15}),
      c10::Error);
}

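// The wildcard-annotated output may alias the wildcard input, while the
// unannotated second output may not.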
TEST(FunctionSchemaMayAliasTest, Wildcard) {
  c10::FunctionSchema schema = torch::jit::parseSchema(
      "aten::test.Tensor(Tensor(*) self) -> (Tensor(*), Tensor)");
  ASSERT_TRUE(
      schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
  ASSERT_FALSE(
      schema.may_alias({SchemaArgType::output, 1}, {SchemaArgType::input, 0}));
}

TEST(SchemaInfoMayAliasTest, AliasingInputs) {
  SchemaInfo schema(
      "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor");
  ASSERT_FALSE(
      schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
  at::Tensor input = at::randn({3, 3});
  schema.addArgumentValue("self", input);
  schema.addArgumentValue("other", input);
  ASSERT_TRUE(
      schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
}

TEST(SchemaInfoMayAliasTest, AliasingOutputs) {
  SchemaInfo schema(
      "aten::aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max)");
  ASSERT_FALSE(
      schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::output, 1}));
  at::Tensor input = at::randn({3, 3});
  schema.addArgumentValue("min", input);
  schema.addArgumentValue("max", input);
  ASSERT_TRUE(
      schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::output, 1}));
}

TEST(SchemaInfoMayAliasTest, AliasingInputOutput) {
  SchemaInfo schema(
      "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
  ASSERT_TRUE(
      schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
  ASSERT_FALSE(
      schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0}));
  at::Tensor input = at::randn({3, 3});
  schema.addArgumentValue("self", input);
  schema.addArgumentValue("other", input);
  ASSERT_TRUE(
      schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
  ASSERT_TRUE(
      schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0}));
}

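// Double-wildcard edge case from the PR description: for a schema of the form
// (Tensor(a) a, Tensor(*) b, Tensor(*) c), once the argument values of 'a'
// and 'b' are known to alias, 'a' may also alias 'c' and the wildcard output.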
TEST(SchemaInfoMayAliasTest, MultipleWildcardInputs) {
  SchemaInfo schema(
      "aten::test.Tensor(Tensor(a) a, Tensor(*) b, Tensor(*) c) -> (Tensor(a), Tensor(*))");
  ASSERT_TRUE(
      schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
  ASSERT_TRUE(
      schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 1}));
  ASSERT_TRUE(
      schema.may_alias({SchemaArgType::input, 2}, {SchemaArgType::output, 1}));
  ASSERT_FALSE(
      schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
  ASSERT_FALSE(
      schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 2}));
  ASSERT_FALSE(
      schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 1}));
  ASSERT_FALSE(
      schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0}));
  at::Tensor input = at::randn({3, 3});
  schema.addArgumentValue("a", input);
  schema.addArgumentValue("b", input);
  ASSERT_TRUE(
      schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
  ASSERT_TRUE(
      schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 1}));
  ASSERT_TRUE(
      schema.may_alias({SchemaArgType::input, 2}, {SchemaArgType::output, 1}));
  ASSERT_TRUE(
      schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
  ASSERT_TRUE(
      schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 2}));
  ASSERT_TRUE(
      schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 1}));
  ASSERT_TRUE(
      schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0}));
}

TEST(SchemaInfoMayAliasTest, MultipleNonWildcardInputs) {
  SchemaInfo schema(
      "aten::test.Tensor(Tensor(a) a, Tensor(a) b, Tensor(*) c, Tensor(b) d) -> (Tensor(a), Tensor(*))");
  ASSERT_TRUE(
      schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
  ASSERT_TRUE(
      schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 2}));
  ASSERT_TRUE(
      schema.may_alias({SchemaArgType::input, 2}, {SchemaArgType::input, 1}));
  ASSERT_TRUE(
      schema.may_alias({SchemaArgType::input, 2}, {SchemaArgType::output, 0}));
}

TEST(SchemaInfoMayAliasTest, MultipleNonWildcardOutputs) {
  SchemaInfo schema(
      "aten::test.Tensor(Tensor(a) a, Tensor(*) b) -> (Tensor(a), Tensor(a))");
  ASSERT_TRUE(
      schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
  ASSERT_TRUE(
      schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::output, 1}));
  ASSERT_TRUE(
      schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::input, 1}));
}

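// Values whose types can never alias (here Tensor vs. int) must not be
// reported as aliasing, even though they share the 'a' annotation.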
TEST(SchemaInfoMayAliasTest, MismatchingTypes) {
  SchemaInfo schema("aten::test.Tensor(Tensor(a) a) -> int(a)");
  ASSERT_FALSE(
      schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
}

TEST(FunctionSchemaMayContainAliasTest, Basic) {
  c10::FunctionSchema schema = torch::jit::parseSchema(
      "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
  ASSERT_TRUE(schema.may_contain_alias(
      {SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
  ASSERT_FALSE(schema.may_contain_alias(
      {SchemaArgType::input, 1}, {SchemaArgType::output, 0}));
  ASSERT_FALSE(schema.may_contain_alias(
      {SchemaArgType::input, 1}, {SchemaArgType::input, 0}));
}

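// may_contain_alias is directional when bidirectional=false: the Tensor[]
// output may contain the wildcard input, but the non-container Tensor input
// cannot contain the list output.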
TEST(FunctionSchemaMayContainAliasTest, Wildcard) {
  c10::FunctionSchema schema = torch::jit::parseSchema(
      "aten::test.Tensor(Tensor(*) self) -> (Tensor[], Tensor)");
  ASSERT_FALSE(
      schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
  ASSERT_TRUE(schema.may_contain_alias(
      {SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
  ASSERT_TRUE(schema.may_contain_alias(
      {SchemaArgType::output, 0}, {SchemaArgType::input, 0}, false));
  ASSERT_FALSE(schema.may_contain_alias(
      {SchemaArgType::input, 0}, {SchemaArgType::output, 0}, false));
  ASSERT_FALSE(
      schema.may_alias({SchemaArgType::output, 1}, {SchemaArgType::input, 0}));
}

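// With Tensor[] on both sides, containment is possible in either direction,
// so may_contain_alias holds regardless of argument order even with
// bidirectional=false.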
TEST(FunctionSchemaMayContainAliasTest, InputAndOutputContainers) {
  c10::FunctionSchema schema =
      torch::jit::parseSchema("aten::test.Tensor(Tensor[] self) -> Tensor[]");
  ASSERT_FALSE(
      schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
  ASSERT_TRUE(schema.may_contain_alias(
      {SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
  ASSERT_TRUE(schema.may_contain_alias(
      {SchemaArgType::output, 0}, {SchemaArgType::input, 0}, false));
  ASSERT_TRUE(schema.may_contain_alias(
      {SchemaArgType::input, 0}, {SchemaArgType::output, 0}, false));
}
} // namespace utils
} // namespace torch