mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62442 For PythonMethodWrapper::setArgumentNames, make sure to use the method specified by method_name_ rather than the parent model_ object, which is itself callable but whose call signature is not the one to extract. For Python vs. Script, unify the behavior to omit the 'self' parameter, so that only the names of the unbound arguments are listed, which is what is needed in practice. Test Plan: updated the unit test; it passes. Reviewed By: alanwaketan Differential Revision: D29965283 fbshipit-source-id: a4e6a1d0f393f2a41c3afac32285548832da3fb4
50 lines
1.7 KiB
C++
50 lines
1.7 KiB
C++
// (c) Facebook, Inc. and its affiliates. Confidential and proprietary.
|
|
|
|
#include <gtest/gtest.h>
|
|
#include <torch/csrc/deploy/deploy.h>
|
|
#include <torch/script.h>
|
|
#include <torch/torch.h>
|
|
|
|
using namespace ::testing;
|
|
using namespace caffe2;
|
|
|
|
// TODO(T96218435): Enable the following tests in OSS.
|
|
// Loads a model both as a TorchScript module (path from $SIMPLE_JIT) and as a
// pickled model inside a torch::deploy package (path from $SIMPLE), calls
// `forward` on each with a 10x20 tensor of ones, and checks that both calls
// return tensors that are equal and contain 200 elements.
TEST(IMethodTest, CallMethod) {
  // getenv returns nullptr when the variable is unset; building the std::string
  // path argument from a null pointer is undefined behavior, so fail fast with
  // a readable message instead.
  const char* jit_path = getenv("SIMPLE_JIT");
  ASSERT_NE(jit_path, nullptr) << "SIMPLE_JIT environment variable is not set";
  const char* package_path = getenv("SIMPLE");
  ASSERT_NE(package_path, nullptr) << "SIMPLE environment variable is not set";

  auto script_model = torch::jit::load(jit_path);
  auto script_method = script_model.get_method("forward");

  torch::deploy::InterpreterManager manager(3);
  torch::deploy::Package p = manager.load_package(package_path);
  auto py_model = p.load_pickle("model", "model.pkl");
  torch::deploy::PythonMethodWrapper py_method(py_model, "forward");

  auto input = torch::ones({10, 20});
  auto output_py = py_method({input});
  auto output_script = script_method({input});

  // Fatal assertions: stop the test before calling toTensor() on a value that
  // is not a tensor (EXPECT_* would log the failure and keep going).
  ASSERT_TRUE(output_py.isTensor());
  ASSERT_TRUE(output_script.isTensor());
  auto output_py_tensor = output_py.toTensor();
  auto output_script_tensor = output_script.toTensor();

  EXPECT_TRUE(output_py_tensor.equal(output_script_tensor));
  EXPECT_EQ(output_py_tensor.numel(), 200);
}
|
|
|
|
// Checks that getArgumentNames() on `forward` reports exactly one argument
// named "input" (no "self" entry) for both the TorchScript method loaded from
// $SIMPLE_JIT and the torch::deploy PythonMethodWrapper loaded from $SIMPLE.
TEST(IMethodTest, GetArgumentNames) {
  // getenv returns nullptr when the variable is unset; building the std::string
  // path argument from a null pointer is undefined behavior, so fail fast with
  // a readable message instead.
  const char* jit_path = getenv("SIMPLE_JIT");
  ASSERT_NE(jit_path, nullptr) << "SIMPLE_JIT environment variable is not set";
  const char* package_path = getenv("SIMPLE");
  ASSERT_NE(package_path, nullptr) << "SIMPLE environment variable is not set";

  auto scriptModel = torch::jit::load(jit_path);
  auto scriptMethod = scriptModel.get_method("forward");

  const auto& scriptNames = scriptMethod.getArgumentNames();
  // Fatal size check so that scriptNames[0] below can never index an empty
  // container. `1u` avoids a signed/unsigned comparison against size().
  ASSERT_EQ(scriptNames.size(), 1u);
  EXPECT_EQ(scriptNames[0], "input");

  torch::deploy::InterpreterManager manager(3);
  torch::deploy::Package package = manager.load_package(package_path);
  auto pyModel = package.load_pickle("model", "model.pkl");
  torch::deploy::PythonMethodWrapper pyMethod(pyModel, "forward");

  const auto& pyNames = pyMethod.getArgumentNames();
  // Same guard as above before indexing pyNames[0].
  ASSERT_EQ(pyNames.size(), 1u);
  EXPECT_EQ(pyNames[0], "input");
}
|