mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
- Created addArgumentValue/s methods in SchemaInfo to pass argument values into the subclass. These are used for more accurate mutation, aliasing and determinism checks which include special cases. - Added input_alias_map_ to keep track of which inputs alias each other. This is updated with the method generateAliasMap. - Implemented is_mutable methods in SchemaInfo which also give information based on argument values. For instance, if two inputs alias and one is mutable by the schema, then the other will also be mutable. - Tested Schema Info is_mutable implementation where inputs alias as mentioned above. Pull Request resolved: https://github.com/pytorch/pytorch/pull/80972 Approved by: https://github.com/davidberard98
91 lines
3.5 KiB
C++
91 lines
3.5 KiB
C++
#include <gtest/gtest.h>
|
|
#include <torch/csrc/autograd/generated/variable_factories.h>
|
|
#include <torch/csrc/utils/schema_info.h>
|
|
|
|
namespace torch {
|
|
namespace utils {
|
|
TEST(FunctionSchemaIsMutableTest, Basic) {
|
|
c10::FunctionSchema schema = torch::jit::parseSchema(
|
|
"aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
|
|
ASSERT_TRUE(schema.is_mutable(0));
|
|
ASSERT_TRUE(schema.is_mutable("self"));
|
|
ASSERT_FALSE(schema.is_mutable(1));
|
|
ASSERT_FALSE(schema.is_mutable("other"));
|
|
ASSERT_FALSE(schema.is_mutable(2));
|
|
ASSERT_FALSE(schema.is_mutable("alpha"));
|
|
}
|
|
|
|
TEST(FunctionSchemaIsMutableTest, InvalidArgument) {
|
|
c10::FunctionSchema schema = torch::jit::parseSchema(
|
|
"aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
|
|
ASSERT_THROW(schema.is_mutable(4), c10::Error);
|
|
ASSERT_THROW(schema.is_mutable("named_argument"), c10::Error);
|
|
}
|
|
|
|
TEST(SchemaInfoIsMutableTest, Basic) {
|
|
SchemaInfo schema(
|
|
"aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
|
|
ASSERT_TRUE(schema.is_mutable(0));
|
|
ASSERT_TRUE(schema.is_mutable("self"));
|
|
ASSERT_FALSE(schema.is_mutable(1));
|
|
ASSERT_FALSE(schema.is_mutable("other"));
|
|
ASSERT_FALSE(schema.is_mutable(2));
|
|
ASSERT_FALSE(schema.is_mutable("alpha"));
|
|
}
|
|
|
|
TEST(SchemaInfoIsMutableTest, InvalidArgument) {
|
|
SchemaInfo schema(
|
|
"aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
|
|
ASSERT_THROW(schema.is_mutable(4), c10::Error);
|
|
ASSERT_THROW(schema.is_mutable("named_argument"), c10::Error);
|
|
}
|
|
|
|
TEST(SchemaInfoIsMutableTest, AliasingInputs) {
|
|
SchemaInfo schema(
|
|
"aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
|
|
ASSERT_TRUE(schema.is_mutable(0));
|
|
ASSERT_TRUE(schema.is_mutable("self"));
|
|
ASSERT_FALSE(schema.is_mutable(1));
|
|
ASSERT_FALSE(schema.is_mutable("other"));
|
|
at::Tensor input = at::randn({3, 3});
|
|
schema.addArgumentValue("self", input);
|
|
schema.addArgumentValue("other", input);
|
|
ASSERT_TRUE(schema.is_mutable(1));
|
|
ASSERT_TRUE(schema.is_mutable("other"));
|
|
}
|
|
|
|
TEST(FunctionSchemaMayAliasTest, Basic) {
|
|
c10::FunctionSchema schema = torch::jit::parseSchema(
|
|
"aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
|
|
ASSERT_TRUE(schema.may_alias(
|
|
{c10::SchemaArgType::input, 0}, {c10::SchemaArgType::output, 0}));
|
|
ASSERT_FALSE(schema.may_alias(
|
|
{c10::SchemaArgType::input, 1}, {c10::SchemaArgType::output, 0}));
|
|
ASSERT_FALSE(schema.may_alias(
|
|
{c10::SchemaArgType::input, 1}, {c10::SchemaArgType::input, 0}));
|
|
}
|
|
|
|
TEST(FunctionSchemaMayAliasTest, InvalidArgument) {
|
|
c10::FunctionSchema schema = torch::jit::parseSchema(
|
|
"aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
|
|
ASSERT_THROW(
|
|
schema.may_alias(
|
|
{c10::SchemaArgType::input, 15}, {c10::SchemaArgType::output, 0}),
|
|
c10::Error);
|
|
ASSERT_THROW(
|
|
schema.may_alias(
|
|
{c10::SchemaArgType::input, 0}, {c10::SchemaArgType::output, 15}),
|
|
c10::Error);
|
|
}
|
|
|
|
TEST(FunctionSchemaMayAliasTest, Wildcard) {
|
|
c10::FunctionSchema schema = torch::jit::parseSchema(
|
|
"aten::test.Tensor(Tensor(*) self) -> (Tensor(*), Tensor)");
|
|
ASSERT_TRUE(schema.may_alias(
|
|
{c10::SchemaArgType::output, 0}, {c10::SchemaArgType::input, 0}));
|
|
ASSERT_FALSE(schema.may_alias(
|
|
{c10::SchemaArgType::output, 1}, {c10::SchemaArgType::input, 0}));
|
|
}
|
|
} // namespace utils
|
|
} // namespace torch
|