mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
- Created the may_alias method in FunctionSchema to expose aliasing information about the inputs and outputs of a schema.
- Tested may_alias for basic functionality, exceptions, and wildcard functionality.

**Cases where elements of a container alias another argument will be handled by a new may_contain_alias method, which will be added in a later PR.**

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80918
Approved by: https://github.com/davidberard98
58 lines
2.2 KiB
C++
58 lines
2.2 KiB
C++
#include <gtest/gtest.h>
|
|
#include <torch/csrc/utils/schema_info.h>
|
|
|
|
namespace torch {
|
|
namespace utils {
|
|
// Exercises is_mutable() on the parsed `aten::sub_.Tensor` schema, querying
// each of its three arguments both by position and by name.
TEST(FunctionSchemaIsMutableTest, Basic) {
  const char* const schema_text =
      "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))";
  c10::FunctionSchema sub_schema = torch::jit::parseSchema(schema_text);

  // `self` is the only argument written with the `(a!)` annotation, and the
  // test expects it to be reported as mutable.
  ASSERT_TRUE(sub_schema.is_mutable(0));
  ASSERT_TRUE(sub_schema.is_mutable("self"));

  // `other` carries no annotation in the schema text.
  ASSERT_FALSE(sub_schema.is_mutable(1));
  ASSERT_FALSE(sub_schema.is_mutable("other"));

  // Neither does the keyword-only `alpha`.
  ASSERT_FALSE(sub_schema.is_mutable(2));
  ASSERT_FALSE(sub_schema.is_mutable("alpha"));
}
|
|
|
|
// Verifies that is_mutable() raises c10::Error when asked about an argument
// the schema does not declare, whether addressed by index or by name.
TEST(FunctionSchemaIsMutableTest, InvalidArgument) {
  const char* const schema_text =
      "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))";
  c10::FunctionSchema sub_schema = torch::jit::parseSchema(schema_text);

  // The schema text lists only three arguments, so index 4 does not exist.
  ASSERT_THROW(sub_schema.is_mutable(4), c10::Error);

  // No argument in the schema text is called "named_argument".
  ASSERT_THROW(sub_schema.is_mutable("named_argument"), c10::Error);
}
|
|
|
|
// Exercises may_alias() on the parsed `aten::sub_.Tensor` schema for
// input/output and input/input pairs.
TEST(FunctionSchemaMayAliasTest, Basic) {
  const char* const schema_text =
      "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))";
  c10::FunctionSchema sub_schema = torch::jit::parseSchema(schema_text);

  // Input 0 (`self`) and output 0 are both annotated `(a!)` in the schema text.
  ASSERT_TRUE(sub_schema.may_alias(
      {c10::SchemaArgType::input, 0}, {c10::SchemaArgType::output, 0}));

  // Input 1 (`other`) has no annotation, so the test expects it to alias
  // neither the output...
  ASSERT_FALSE(sub_schema.may_alias(
      {c10::SchemaArgType::input, 1}, {c10::SchemaArgType::output, 0}));

  // ...nor the `self` input.
  ASSERT_FALSE(sub_schema.may_alias(
      {c10::SchemaArgType::input, 1}, {c10::SchemaArgType::input, 0}));
}
|
|
|
|
// Verifies that may_alias() raises c10::Error when either side of the query
// names an index the schema does not have.
TEST(FunctionSchemaMayAliasTest, InvalidArgument) {
  const char* const schema_text =
      "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))";
  c10::FunctionSchema sub_schema = torch::jit::parseSchema(schema_text);

  // Input index 15 is beyond the three arguments in the schema text.
  ASSERT_THROW(
      sub_schema.may_alias(
          {c10::SchemaArgType::input, 15}, {c10::SchemaArgType::output, 0}),
      c10::Error);

  // Output index 15 is beyond the single return in the schema text.
  ASSERT_THROW(
      sub_schema.may_alias(
          {c10::SchemaArgType::input, 0}, {c10::SchemaArgType::output, 15}),
      c10::Error);
}
|
|
|
|
// Exercises may_alias() on a schema whose input and first output use the
// `(*)` annotation while the second output is unannotated.
TEST(FunctionSchemaMayAliasTest, Wildcard) {
  const char* const schema_text =
      "aten::test.Tensor(Tensor(*) self) -> (Tensor(*), Tensor)";
  c10::FunctionSchema wildcard_schema = torch::jit::parseSchema(schema_text);

  // Output 0 and input 0 are both written as `Tensor(*)`.
  ASSERT_TRUE(wildcard_schema.may_alias(
      {c10::SchemaArgType::output, 0}, {c10::SchemaArgType::input, 0}));

  // Output 1 is a plain `Tensor`; the test expects no aliasing with input 0.
  ASSERT_FALSE(wildcard_schema.may_alias(
      {c10::SchemaArgType::output, 1}, {c10::SchemaArgType::input, 0}));
}
|
|
} // namespace utils
|
|
} // namespace torch
|