pytorch/benchmarks/static_runtime/test_utils.h
Donald Dong f7294cd865 [Static Runtime] Skip ReplaceWithCopy when inputs have writers (#69819)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69819

We should skip ReplaceWithCopy if an operator's inputs can be updated during inference. For a set of tensors that share data, ReplaceWithCopy must not be applied to any of them if any one of them is updated.

The existing check misses some cases (for example, when updates exist but the number of uses is <= 1). This diff covers those cases by querying AliasDB.
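The rule above can be sketched as follows. This is a minimal toy model, not Static Runtime's implementation: `ToyAliasInfo`, `aliasSetHasWriters`, and `canReplaceWithCopy` are hypothetical names, and the alias sets are supplied by hand rather than computed by AliasDB (whose API this does not reproduce). It assumes ReplaceWithCopy swaps an op whose output may alias its input for a variant that copies, so the swap is only safe when nothing writes through any alias of that input: a view would observe such a write, while a copy would not.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// HYPOTHETICAL toy model, not the torch::jit::AliasDb API: values that may
// share storage are grouped into alias sets, and `written` records every
// value that some op mutates in place. Assumes each value belongs to at most
// one alias set.
struct ToyAliasInfo {
  std::vector<std::set<std::string>> alias_sets;
  std::set<std::string> written;

  // True if `value`, or any value that may alias it, is written to.
  bool aliasSetHasWriters(const std::string& value) const {
    for (const auto& alias_set : alias_sets) {
      if (alias_set.count(value) == 0) {
        continue;
      }
      for (const auto& member : alias_set) {
        if (written.count(member) != 0) {
          return true;
        }
      }
      return false;
    }
    // Not in any recorded alias set: only the value itself can be written.
    return written.count(value) != 0;
  }
};

// Swapping a possibly-aliasing op for a copying one is only safe if no input
// (or alias of an input) is mutated anywhere in the graph.
bool canReplaceWithCopy(
    const ToyAliasInfo& info,
    const std::vector<std::string>& inputs) {
  for (const auto& input : inputs) {
    if (info.aliasSetHasWriters(input)) {
      return false;
    }
  }
  return true;
}
```

For example, with the alias set {a, a_alias} and `a_alias` written in place, `canReplaceWithCopy` rejects input `a` even though `a` itself may have only one use, which is the kind of case a use-count check alone can miss.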

Test Plan:
- Added test cases, including one that was problematic before this diff
- CI

Reviewed By: mikeiovine

Differential Revision: D33052562

fbshipit-source-id: 61f87e471805f41d071a28212f2f457e8c6785e7
2021-12-14 09:39:49 -08:00


// (c) Facebook, Inc. and its affiliates. Confidential and proprietary.

#pragma once

#include <string>
#include <vector>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/static/impl.h>

namespace c10 {
struct IValue;
}

namespace torch {
namespace jit {

struct Node;
class StaticModule;

namespace test {

// Given a model/function as TorchScript source or IR text, run it with both
// the JIT interpreter and Static Runtime, and compare the results.
void testStaticRuntime(
    const std::string& source,
    const std::vector<c10::IValue>& args,
    const std::vector<c10::IValue>& args2 = {},
    const bool use_allclose = false,
    const bool use_equalnan = false,
    const bool check_resize = true);

std::shared_ptr<Graph> getGraphFromScript(const std::string& jit_script);

std::shared_ptr<Graph> getGraphFromIR(const std::string& ir);

bool hasProcessedNodeWithName(
    torch::jit::StaticModule& smodule,
    const char* name);

at::Tensor getTensor(const at::IValue& ival);

Node* getNodeWithKind(const StaticModule& smodule, const std::string& kind);

bool hasNodeWithKind(const StaticModule& smodule, const std::string& kind);

void compareResultsWithJIT(
    StaticRuntime& runtime,
    const std::shared_ptr<Graph>& graph,
    const std::vector<c10::IValue>& args,
    const bool use_allclose = false,
    const bool use_equalnan = false);

} // namespace test
} // namespace jit
} // namespace torch
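`testStaticRuntime` and `compareResultsWithJIT` both take `use_allclose` and `use_equalnan` flags. The sketch below is a hypothetical illustration of what such flags commonly control, written against plain `std::vector<double>` instead of `at::Tensor`; `resultsMatch` is not part of this header, and honoring `use_equalnan` in the exact comparison is an assumption of the sketch. The tolerance test mirrors `torch.allclose`, |actual - expected| <= atol + rtol * |expected|, with the same defaults rtol = 1e-5 and atol = 1e-8.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// HYPOTHETICAL helper: compares two flat result buffers the way the
// use_allclose / use_equalnan flags suggest. Not part of test_utils.h.
bool resultsMatch(
    const std::vector<double>& expected,
    const std::vector<double>& actual,
    bool use_allclose = false,
    bool use_equalnan = false,
    double rtol = 1e-5,
    double atol = 1e-8) {
  assert(rtol >= 0.0 && atol >= 0.0);  // tolerances must be non-negative
  if (expected.size() != actual.size()) {
    return false;
  }
  for (std::size_t i = 0; i < expected.size(); ++i) {
    const double e = expected[i];
    const double a = actual[i];
    if (std::isnan(e) || std::isnan(a)) {
      // NaN only matches when NaN-equality is requested and both sides are NaN.
      if (!(use_equalnan && std::isnan(e) && std::isnan(a))) {
        return false;
      }
      continue;
    }
    if (use_allclose) {
      // Same rule as torch.allclose: |a - e| <= atol + rtol * |e|.
      if (std::fabs(a - e) > atol + rtol * std::fabs(e)) {
        return false;
      }
    } else if (a != e) {
      return false;
    }
  }
  return true;
}
```

Exact comparison is the stricter default; the allclose mode tolerates small floating-point differences that can arise when two executors evaluate the same graph with different kernels.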