pytorch/test/cpp/nativert/test_graph_signature.cpp
Yiming Zhou aeb734f519 [nativert] Move GraphSignature to pytorch core (#152969)
Summary:
Torch Native Runtime RFC: https://github.com/pytorch/rfcs/pull/72

Added an in-memory representation of a graph's input and output specs. The GraphSignature class models the input and output specs of a graph exported by torch.export; it holds the graph information deserialized from the pt2 archive package. The runtime relies on GraphSignature for weight-name lookup and weight loading.

The serialization schema is defined in torch/_export/serde/schema.py
See more at: https://docs.pytorch.org/docs/stable/export.html#torch.export.ExportGraphSignature

Test Plan: Added tests under `test/cpp/nativert/test_graph_signature.cpp`

Differential Revision: D73895378

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152969
Approved by: https://github.com/swolchok
2025-05-20 21:49:56 +00:00

78 lines
2.8 KiB
C++

#include <gtest/gtest.h>
#include <torch/nativert/graph/GraphSignature.h>
namespace torch::nativert {
class GraphSignatureTest : public ::testing::Test {
protected:
// Member to hold the GraphSignature object
GraphSignature graph_sig;
void SetUp() override {
torch::_export::TensorArgument param_tensor_arg;
param_tensor_arg.set_name("param");
torch::_export::InputToParameterSpec param_input_spec;
param_input_spec.set_arg(param_tensor_arg);
param_input_spec.set_parameter_name("param");
torch::_export::InputSpec input_spec_0;
input_spec_0.set_parameter(param_input_spec);
torch::_export::TensorArgument input_tensor_arg;
input_tensor_arg.set_name("input");
torch::_export::Argument input_arg;
input_arg.set_as_tensor(input_tensor_arg);
torch::_export::UserInputSpec user_input_spec;
user_input_spec.set_arg(input_arg);
torch::_export::InputSpec input_spec_1;
input_spec_1.set_user_input(user_input_spec);
torch::_export::TensorArgument loss_tensor_arg;
loss_tensor_arg.set_name("loss");
torch::_export::LossOutputSpec loss_output_spec;
loss_output_spec.set_arg(loss_tensor_arg);
torch::_export::OutputSpec output_spec_0;
output_spec_0.set_loss_output(loss_output_spec);
torch::_export::TensorArgument output_tensor_arg;
output_tensor_arg.set_name("output");
torch::_export::Argument output_arg;
output_arg.set_as_tensor(output_tensor_arg);
torch::_export::UserOutputSpec user_output_spec;
user_output_spec.set_arg(output_arg);
torch::_export::OutputSpec output_spec_1;
output_spec_1.set_user_output(user_output_spec);
torch::_export::GraphSignature mock_storage;
mock_storage.set_input_specs({input_spec_0, input_spec_1});
mock_storage.set_output_specs({output_spec_0, output_spec_1});
// Initialize the GraphSignature object
graph_sig = GraphSignature(mock_storage);
}
};
// Constructing from the serialized signature should route each spec to the
// matching accessor: the parameter, the user input, the loss output and the
// user output. Whole-container equality is checked, which also pins sizes,
// so no separate (signed/unsigned-mixing) size assertions are needed.
TEST_F(GraphSignatureTest, ConstructorTest) {
  const std::vector<std::string_view> expected_params = {"param"};
  EXPECT_EQ(graph_sig.parameters(), expected_params);

  const std::vector<std::string> expected_inputs = {"input"};
  EXPECT_EQ(graph_sig.userInputs(), expected_inputs);

  EXPECT_EQ(graph_sig.lossOutput(), "loss");

  // The loss spec was serialized as an output too, but userOutputs() is
  // expected to contain only the user output.
  const std::vector<std::optional<std::string>> expected_outputs = {"output"};
  EXPECT_EQ(graph_sig.userOutputs(), expected_outputs);
}
// replaceAllUses("output", "new_output") should make the user output report
// the new name, while specs naming other values keep their original names.
TEST_F(GraphSignatureTest, ReplaceAllUsesTest) {
  graph_sig.replaceAllUses("output", "new_output");

  const std::vector<std::optional<std::string>> expected_outputs = {
      "new_output"};
  EXPECT_EQ(graph_sig.userOutputs(), expected_outputs);

  // Values not named "output" must be unaffected by the rename.
  const std::vector<std::string_view> expected_params = {"param"};
  EXPECT_EQ(graph_sig.parameters(), expected_params);
  const std::vector<std::string> expected_inputs = {"input"};
  EXPECT_EQ(graph_sig.userInputs(), expected_inputs);
  EXPECT_EQ(graph_sig.lossOutput(), "loss");
}
} // namespace torch::nativert