pytorch/benchmarks/static_runtime/test_utils.h
Mike Iovine 238dded10f [SR] Graph pass to create owned refs of special IValues (#69835)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69835

`StaticRuntimeBlockRunner` moves its outputs to the return value at the end of `run_impl`. However, there's a corner case where this can cause problems. If we return a constant, then the only reference in the `constants_` array can be destroyed by this move. We could add special logic to handle this in `run_impl`. But since this is a relatively rare corner case, it's simpler to just add an op that does nothing but create an owned reference to its input. This owned reference can be safely moved out of `StaticRuntimeBlockRunner`.

Note that this also applies to returned values in sub-blocks that are from outer scopes.
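The corner case above can be illustrated with a minimal toy model, assuming a refcounted value type. Every name here is hypothetical and none of it is the Static Runtime code: `std::shared_ptr<int>` stands in for a refcounted `IValue`, `ToyBlockRunner` for `StaticRuntimeBlockRunner`, and each `owned_ref_nodes` entry for the new op that creates an owned reference. When an output slot aliases a constant, moving it out empties the constant. Copying the constant into an owned slot first (a refcount bump) makes the move safe.

```cpp
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for a refcounted IValue; not the real type.
using ToyValue = std::shared_ptr<int>;

// Hypothetical toy model of a block runner; not StaticRuntimeBlockRunner.
struct ToyBlockRunner {
  std::vector<ToyValue> constants;  // plays the role of constants_
  std::vector<ToyValue> owned;      // slots written by the block's own nodes
  std::vector<ToyValue*> outputs;   // output slots; may alias constants
  // Each (c, o) pair is a hypothetical "create owned ref" node that copies
  // constants[c] into owned[o]: a refcount bump, not an alias.
  std::vector<std::pair<std::size_t, std::size_t>> owned_ref_nodes;

  std::vector<ToyValue> run() {
    // Execute the owned-ref nodes as part of the block.
    for (const auto& node : owned_ref_nodes) {
      owned[node.second] = constants[node.first];
    }
    // Like the end of run_impl: move every output into the return value.
    // A moved-from shared_ptr is guaranteed to be left empty.
    std::vector<ToyValue> result;
    for (ToyValue* out : outputs) {
      result.push_back(std::move(*out));
    }
    return result;
  }
};
```

With `outputs` aliasing `constants[0]`, the first `run()` returns 42 but leaves `constants[0]` empty, so a second run returns an empty pointer. Routing the output through an `owned_ref_nodes` entry keeps `constants[0]` intact across runs. Reading `constants` as values owned by an enclosing block gives the sub-block case from the note above.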
ghstack-source-id: 148186452

Test Plan:
`buck test caffe2/benchmarks/static_runtime/...`

Added a new unit test with a graph that simply returns a constant.

Tests with sub-blocks are at the top of the stack.

Reviewed By: d1jang

Differential Revision: D33047519

fbshipit-source-id: 22b6058f0d1da8a6d1d61a6f2866bc518bff482b
(cherry picked from commit a8f89a12ee)
2022-02-02 19:30:50 +00:00


// (c) Facebook, Inc. and its affiliates. Confidential and proprietary.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/static/impl.h>
namespace c10 {
struct IValue;
}
namespace torch {
namespace jit {
struct Node;
class StaticModule;
namespace test {
// Given a model/function in JIT script or IR, run it with both the JIT
// interpreter and Static Runtime, and compare the results.
void testStaticRuntime(
    const std::string& source,
    const std::vector<c10::IValue>& args,
    const std::vector<c10::IValue>& args2 = {},
    const bool use_allclose = false,
    const bool use_equalnan = false,
    const bool check_resize = true);

// Build a graph from TorchScript source.
std::shared_ptr<Graph> getGraphFromScript(const std::string& jit_script);

// Parse a textual IR string into a graph.
std::shared_ptr<Graph> getGraphFromIR(const std::string& ir);

// Return true if `smodule` has a processed node named `name`.
bool hasProcessedNodeWithName(
    torch::jit::StaticModule& smodule,
    const char* name);

// Extract an at::Tensor from `ival`.
at::Tensor getTensor(const at::IValue& ival);

// Look up a node whose kind matches `kind` (e.g. "aten::add").
Node* getNodeWithKind(const StaticModule& smodule, const std::string& kind);
Node* getNodeWithKind(std::shared_ptr<Graph>& graph, const std::string& kind);

// Return true if a node whose kind matches `kind` exists.
bool hasNodeWithKind(const StaticModule& smodule, const std::string& kind);
bool hasNodeWithKind(std::shared_ptr<Graph>& graph, const std::string& kind);

// Compare the outputs of `runtime` on `args` against the outputs of the
// JIT interpreter running `graph` on the same inputs.
void compareResultsWithJIT(
    StaticRuntime& runtime,
    const std::shared_ptr<Graph>& graph,
    const std::vector<c10::IValue>& args,
    const bool use_allclose = false,
    const bool use_equalnan = false);
} // namespace test
} // namespace jit
} // namespace torch
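The header only declares the comparison helpers. Their behavior lives in the corresponding source file, which is not shown here. The hypothetical functions `elementsMatch` and `resultsMatch` below are a rough sketch of what the `use_allclose` and `use_equalnan` flags of `testStaticRuntime` and `compareResultsWithJIT` plausibly control, and they compare plain `double` vectors instead of tensors. The sketch makes three assumptions. Without `use_allclose`, outputs must match exactly. With it, it applies `torch.allclose`-style tolerances (defaults `rtol = 1e-05`, `atol = 1e-08`). NaNs match only when `use_equalnan` is also set.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper, not the actual test_utils implementation. With
// use_allclose it mirrors torch.allclose(actual, expected) defaults:
// |actual - expected| <= atol + rtol * |expected|.
bool elementsMatch(double expected, double actual, bool use_allclose,
                   bool use_equalnan, double rtol = 1e-05,
                   double atol = 1e-08) {
  if (!use_allclose) {
    return expected == actual;  // exact match; NaN never equals NaN
  }
  if (expected == actual) {
    return true;  // identical values, including matching infinities
  }
  if (std::isnan(expected) || std::isnan(actual)) {
    // Assumption: use_equalnan only takes effect together with use_allclose.
    return use_equalnan && std::isnan(expected) && std::isnan(actual);
  }
  if (std::isinf(expected) || std::isinf(actual)) {
    return false;  // a non-finite value only matches itself
  }
  return std::fabs(actual - expected) <= atol + rtol * std::fabs(expected);
}

// Compare the JIT interpreter's outputs with Static Runtime's, element-wise.
bool resultsMatch(const std::vector<double>& jit_out,
                  const std::vector<double>& sr_out, bool use_allclose,
                  bool use_equalnan) {
  if (jit_out.size() != sr_out.size()) {
    return false;
  }
  for (std::size_t i = 0; i < jit_out.size(); ++i) {
    if (!elementsMatch(jit_out[i], sr_out[i], use_allclose, use_equalnan)) {
      return false;
    }
  }
  return true;
}
```

For outputs `{1.0, 2.0, NaN}` versus `{1.0, 2.0 + 1e-9, NaN}`, the exact comparison fails. Tolerance alone also fails because of the NaNs. Enabling both flags makes the outputs match.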