pytorch/test/cpp/jit/test_schema_info.cpp
goldenxuett aa61fdb667 [JIT] Add argumentValue functions and is_mutable checks to SchemaInfo (#80972)
- Created addArgumentValue/s methods in SchemaInfo to pass argument values into the subclass. These are used for more accurate mutation, aliasing and determinism checks which include special cases.
- Added input_alias_map_ to keep track of which inputs alias each other. This is updated with the method generateAliasMap.
- Implemented is_mutable methods in SchemaInfo which also give information based on argument values. For instance, if two inputs alias and one is mutable by the schema, then the other will also be mutable.
- Tested Schema Info is_mutable implementation where inputs alias as mentioned above.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80972
Approved by: https://github.com/davidberard98
2022-07-13 00:16:41 +00:00

91 lines
3.5 KiB
C++

#include <gtest/gtest.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/utils/schema_info.h>
namespace torch {
namespace utils {
// Exercises c10::FunctionSchema::is_mutable, queried both by positional index
// and by argument name, on the in-place `sub_` schema. Only `self`, the one
// argument written with the `(a!)` annotation, is expected to be mutable.
TEST(FunctionSchemaIsMutableTest, Basic) {
  const char* const kSubSignature =
      "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))";
  auto sub_schema = torch::jit::parseSchema(kSubSignature);

  // Argument 0 / "self".
  ASSERT_TRUE(sub_schema.is_mutable(0));
  ASSERT_TRUE(sub_schema.is_mutable("self"));

  // Argument 1 / "other".
  ASSERT_FALSE(sub_schema.is_mutable(1));
  ASSERT_FALSE(sub_schema.is_mutable("other"));

  // Argument 2 / "alpha" (keyword-only scalar).
  ASSERT_FALSE(sub_schema.is_mutable(2));
  ASSERT_FALSE(sub_schema.is_mutable("alpha"));
}
// Verifies that c10::FunctionSchema::is_mutable raises c10::Error when asked
// about an argument the schema does not declare: the schema below lists three
// arguments (self, other, alpha), so index 4 and the name "named_argument"
// do not refer to any of them.
TEST(FunctionSchemaIsMutableTest, InvalidArgument) {
  const char* const kSubSignature =
      "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))";
  auto sub_schema = torch::jit::parseSchema(kSubSignature);

  // Lookup by an index past the declared arguments.
  ASSERT_THROW(sub_schema.is_mutable(4), c10::Error);
  // Lookup by a name that is not declared.
  ASSERT_THROW(sub_schema.is_mutable("named_argument"), c10::Error);
}
// Same expectations as the FunctionSchema variant, but queried through
// SchemaInfo constructed directly from the signature string and with no
// argument values supplied: only `self` is reported as mutable.
TEST(SchemaInfoIsMutableTest, Basic) {
  const char* const kSubSignature =
      "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))";
  SchemaInfo sub_info(kSubSignature);

  // Argument 0 / "self".
  ASSERT_TRUE(sub_info.is_mutable(0));
  ASSERT_TRUE(sub_info.is_mutable("self"));

  // Argument 1 / "other".
  ASSERT_FALSE(sub_info.is_mutable(1));
  ASSERT_FALSE(sub_info.is_mutable("other"));

  // Argument 2 / "alpha" (keyword-only scalar).
  ASSERT_FALSE(sub_info.is_mutable(2));
  ASSERT_FALSE(sub_info.is_mutable("alpha"));
}
// Verifies that SchemaInfo::is_mutable raises c10::Error for an index or a
// name that does not match any of the three arguments declared by the schema
// (self, other, alpha).
TEST(SchemaInfoIsMutableTest, InvalidArgument) {
  const char* const kSubSignature =
      "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))";
  SchemaInfo sub_info(kSubSignature);

  // Lookup by an index past the declared arguments.
  ASSERT_THROW(sub_info.is_mutable(4), c10::Error);
  // Lookup by a name that is not declared.
  ASSERT_THROW(sub_info.is_mutable("named_argument"), c10::Error);
}
// Checks that SchemaInfo takes supplied argument values into account:
// `other` is not mutable going by the schema alone, but once the same tensor
// is registered as the value of both `self` and `other`, `other` is expected
// to be reported as mutable too.
TEST(SchemaInfoIsMutableTest, AliasingInputs) {
  const char* const kSubSignature =
      "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))";
  SchemaInfo sub_info(kSubSignature);

  // Baseline, before any argument values are known.
  ASSERT_TRUE(sub_info.is_mutable(0));
  ASSERT_TRUE(sub_info.is_mutable("self"));
  ASSERT_FALSE(sub_info.is_mutable(1));
  ASSERT_FALSE(sub_info.is_mutable("other"));

  // Register one and the same tensor for both tensor arguments.
  at::Tensor shared_tensor = at::randn({3, 3});
  sub_info.addArgumentValue("self", shared_tensor);
  sub_info.addArgumentValue("other", shared_tensor);

  // `other` now shares its value with the mutable `self`.
  ASSERT_TRUE(sub_info.is_mutable(1));
  ASSERT_TRUE(sub_info.is_mutable("other"));
}
// Exercises c10::FunctionSchema::may_alias on the in-place `sub_` schema:
// input 0 (`self`) and output 0 both carry the `(a!)` annotation and are
// expected to alias, while input 1 (`other`) is expected to alias neither the
// output nor input 0.
TEST(FunctionSchemaMayAliasTest, Basic) {
  using c10::SchemaArgType;

  const char* const kSubSignature =
      "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))";
  auto sub_schema = torch::jit::parseSchema(kSubSignature);

  // self <-> returned tensor.
  ASSERT_TRUE(sub_schema.may_alias(
      {SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
  // other <-> returned tensor.
  ASSERT_FALSE(sub_schema.may_alias(
      {SchemaArgType::input, 1}, {SchemaArgType::output, 0}));
  // other <-> self.
  ASSERT_FALSE(sub_schema.may_alias(
      {SchemaArgType::input, 1}, {SchemaArgType::input, 0}));
}
// Verifies that c10::FunctionSchema::may_alias raises c10::Error when either
// side of the query uses index 15, which is beyond the three inputs and the
// single output declared by the schema.
TEST(FunctionSchemaMayAliasTest, InvalidArgument) {
  using c10::SchemaArgType;

  const char* const kSubSignature =
      "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))";
  auto sub_schema = torch::jit::parseSchema(kSubSignature);

  // Out-of-range index on the input side.
  ASSERT_THROW(
      sub_schema.may_alias(
          {SchemaArgType::input, 15}, {SchemaArgType::output, 0}),
      c10::Error);
  // Out-of-range index on the output side.
  ASSERT_THROW(
      sub_schema.may_alias(
          {SchemaArgType::input, 0}, {SchemaArgType::output, 15}),
      c10::Error);
}
// Exercises c10::FunctionSchema::may_alias with wildcard annotations: output 0
// is annotated `(*)` like the input and is expected to alias it, whereas
// output 1 has no alias annotation and is expected not to.
TEST(FunctionSchemaMayAliasTest, Wildcard) {
  using c10::SchemaArgType;

  const char* const kWildcardSignature =
      "aten::test.Tensor(Tensor(*) self) -> (Tensor(*), Tensor)";
  auto wildcard_schema = torch::jit::parseSchema(kWildcardSignature);

  // Wildcard output vs. wildcard input.
  ASSERT_TRUE(wildcard_schema.may_alias(
      {SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
  // Unannotated output vs. wildcard input.
  ASSERT_FALSE(wildcard_schema.may_alias(
      {SchemaArgType::output, 1}, {SchemaArgType::input, 0}));
}
} // namespace utils
} // namespace torch