mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Torch Native Runtime RFC: https://github.com/pytorch/rfcs/pull/72 Added an in-memory representation for input and output specs of a graph. The GraphSignature class models the input and output specs of an exported graph produced by torch.export, which holds the graph information deserialized from the pt2 archive package. Runtime relies on the GraphSignature for weight name lookup and weight loading. The serialization schema is defined in torch/_export/serde/schema.py See more at: https://docs.pytorch.org/docs/stable/export.html#torch.export.ExportGraphSignature Test Plan: Added tests under `test/cpp/nativert/test_graph_signature.cpp` Differential Revision: D73895378 Pull Request resolved: https://github.com/pytorch/pytorch/pull/152969 Approved by: https://github.com/swolchok
78 lines
2.8 KiB
C++
78 lines
2.8 KiB
C++
#include <gtest/gtest.h>
|
|
#include <torch/nativert/graph/GraphSignature.h>
|
|
|
|
namespace torch::nativert {
|
|
|
|
class GraphSignatureTest : public ::testing::Test {
|
|
protected:
|
|
// Member to hold the GraphSignature object
|
|
GraphSignature graph_sig;
|
|
|
|
void SetUp() override {
|
|
torch::_export::TensorArgument param_tensor_arg;
|
|
param_tensor_arg.set_name("param");
|
|
torch::_export::InputToParameterSpec param_input_spec;
|
|
param_input_spec.set_arg(param_tensor_arg);
|
|
param_input_spec.set_parameter_name("param");
|
|
torch::_export::InputSpec input_spec_0;
|
|
input_spec_0.set_parameter(param_input_spec);
|
|
|
|
torch::_export::TensorArgument input_tensor_arg;
|
|
input_tensor_arg.set_name("input");
|
|
torch::_export::Argument input_arg;
|
|
input_arg.set_as_tensor(input_tensor_arg);
|
|
torch::_export::UserInputSpec user_input_spec;
|
|
user_input_spec.set_arg(input_arg);
|
|
torch::_export::InputSpec input_spec_1;
|
|
input_spec_1.set_user_input(user_input_spec);
|
|
|
|
torch::_export::TensorArgument loss_tensor_arg;
|
|
loss_tensor_arg.set_name("loss");
|
|
torch::_export::LossOutputSpec loss_output_spec;
|
|
loss_output_spec.set_arg(loss_tensor_arg);
|
|
torch::_export::OutputSpec output_spec_0;
|
|
output_spec_0.set_loss_output(loss_output_spec);
|
|
|
|
torch::_export::TensorArgument output_tensor_arg;
|
|
output_tensor_arg.set_name("output");
|
|
torch::_export::Argument output_arg;
|
|
output_arg.set_as_tensor(output_tensor_arg);
|
|
torch::_export::UserOutputSpec user_output_spec;
|
|
user_output_spec.set_arg(output_arg);
|
|
torch::_export::OutputSpec output_spec_1;
|
|
output_spec_1.set_user_output(user_output_spec);
|
|
|
|
torch::_export::GraphSignature mock_storage;
|
|
mock_storage.set_input_specs({input_spec_0, input_spec_1});
|
|
mock_storage.set_output_specs({output_spec_0, output_spec_1});
|
|
|
|
// Initialize the GraphSignature object
|
|
graph_sig = GraphSignature(mock_storage);
|
|
}
|
|
};
|
|
|
|
// Test the constructor with a simple GraphSignature
|
|
TEST_F(GraphSignatureTest, ConstructorTest) {
|
|
std::vector<std::string_view> expected_params = {"param"};
|
|
EXPECT_EQ(graph_sig.parameters(), expected_params);
|
|
|
|
std::vector<std::string> expected_inputs = {"input"};
|
|
EXPECT_EQ(graph_sig.userInputs(), expected_inputs);
|
|
|
|
EXPECT_EQ(graph_sig.userInputs().size(), 1);
|
|
EXPECT_EQ(graph_sig.parameters().size(), 1);
|
|
EXPECT_EQ(graph_sig.lossOutput(), "loss");
|
|
|
|
std::vector<std::optional<std::string>> expected_outputs = {"output"};
|
|
EXPECT_EQ(graph_sig.userOutputs(), expected_outputs);
|
|
}
|
|
|
|
// Test the replaceAllUses method
|
|
TEST_F(GraphSignatureTest, ReplaceAllUsesTest) {
|
|
graph_sig.replaceAllUses("output", "new_output");
|
|
std::vector<std::optional<std::string>> expected_outputs = {"new_output"};
|
|
EXPECT_EQ(graph_sig.userOutputs(), expected_outputs);
|
|
}
|
|
|
|
} // namespace torch::nativert
|