mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67941 I just found out that, because Tensor storage sizes are rounded up to multiples of 64 bytes, resizing is not actually triggered for many of our unit tests (23 OSS, 16 internal). They should now all be fixed. I also moved a number of tests to `test_static_module.cc` so that `test_static_runtime.cc` now contains only operator tests. From now on, by default, if `args2` is passed to `test_static_runtime`, then at the end of the second iteration it checks that the managed buffer's size is larger than the previous size and enforces this. You can bypass the check for ops with constant output sizes, such as `aten::sum` without `dim` passed in. Test Plan: Facebook: `buck test //caffe2/benchmarks/static_runtime:static_runtime_cpptest` and `buck test //caffe2/benchmarks/static_runtime/fb:test_fb_operators` Reviewed By: swolchok Differential Revision: D32196204 fbshipit-source-id: 8425d9efe6b9a1c1e3807e576b1143efd7561c71
48 lines
1.1 KiB
C++
48 lines
1.1 KiB
C++
// (c) Facebook, Inc. and its affiliates. Confidential and proprietary.
|
|
|
|
#pragma once
|
|
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include <torch/csrc/jit/ir/ir.h>
|
|
#include <torch/csrc/jit/runtime/static/impl.h>
|
|
|
|
namespace c10 {
// Forward declaration: the declarations in this header only take IValue by
// reference (inside std::vector / const&), so the complete type is not needed.
struct IValue;
} // namespace c10
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
struct Node;
|
|
class StaticModule;
|
|
|
|
namespace test {
|
|
|
|
// Given a model/function in jit or IR script, run the model/function
|
|
// with the jit interpreter and static runtime, and compare the results
|
|
void testStaticRuntime(
|
|
const std::string& source,
|
|
const std::vector<c10::IValue>& args,
|
|
const std::vector<c10::IValue>& args2 = {},
|
|
const bool use_allclose = false,
|
|
const bool use_equalnan = false,
|
|
const bool check_resize = true);
|
|
|
|
std::shared_ptr<Graph> getGraphFromIR(const std::string& ir);
|
|
|
|
bool hasProcessedNodeWithName(
|
|
torch::jit::StaticModule& smodule,
|
|
const char* name);
|
|
|
|
at::Tensor getTensor(const at::IValue& ival);
|
|
|
|
Node* getNodeWithKind(const StaticModule& smodule, const std::string& kind);
|
|
|
|
bool hasNodeWithKind(const StaticModule& smodule, const std::string& kind);
|
|
|
|
} // namespace test
|
|
} // namespace jit
|
|
} // namespace torch
|