pytorch/test/cpp/jit/test_misc.cpp
Edward Z. Yang 5c6f5439b7 Implement SymBool (#92149)
We have known for a while that we should in principle support SymBool as a separate concept from SymInt and SymFloat (in particular, every distinct numeric type should get its own API). However, recent work with unbacked SymInts in, e.g., https://github.com/pytorch/pytorch/pull/90985 has made this a priority to implement. The essential problem is that our logic for computing the contiguity of tensors branches on the passed-in input sizes, and this causes us to require guards when constructing tensors from unbacked SymInts. Morally, this should not be a big deal, because we only really care about the regular (non-channels-last) contiguity of the tensor, which should be guaranteed since most people aren't calling `empty_strided` on the tensor; however, because we store a bool (not a SymBool, which does not exist prior to this PR) on TensorImpl, we are forced to *immediately* compute these values, even if they end up not being used at all. In particular, even when a user allocates a contiguous tensor, we still must compute channels-last contiguity (as some contiguous tensors are also channels-last contiguous, but others are not).
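
To make the guard problem concrete, here is a rough sketch of the kind of size/stride walk that computes contiguity. This is illustrative only (the real logic lives in TensorImpl's `compute_` helpers and handles more cases, e.g. zero-size dims); the point is that every comparison is a data-dependent branch, so if `sizes`/`strides` held unbacked SymInts each branch would force a guard:

```cpp
#include <cstdint>
#include <vector>

// Simplified, illustrative contiguity check; not the actual TensorImpl code.
bool compute_contiguous_sketch(
    const std::vector<int64_t>& sizes,
    const std::vector<int64_t>& strides) {
  int64_t expected_stride = 1;
  for (int64_t d = static_cast<int64_t>(sizes.size()) - 1; d >= 0; --d) {
    if (sizes[d] == 1) {
      continue; // size-1 dims never affect contiguity
    }
    if (strides[d] != expected_stride) {
      return false; // data-dependent branch on sizes/strides
    }
    expected_stride *= sizes[d];
  }
  return true;
}
```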

This PR implements SymBool, and makes TensorImpl use SymBool to store the contiguity information in ExtraMeta. There are a number of knock-on effects, which I discuss below.

* I introduce a new C++ type SymBool, analogous to SymInt and SymFloat. This type supports logical and, logical or, and logical negation. I support the bitwise operations on this class (but not the conventional logic operators) to make it clear that logical operations on SymBool are NOT short-circuiting. I also, for now, do NOT support implicit conversion of SymBool to bool (which would create a guard). This does not matter too much in practice, as in this PR I did not modify the equality operations (e.g., `==` on SymInt) to return SymBool, so all preexisting implicit guards did not need to be changed. I also introduced symbolic comparison functions `sym_eq`, etc. on SymInt to make it possible to create SymBool (see the short sketch after this list). The current implementation of the comparison functions unfortunately makes it easy to accidentally introduce guards when you do not mean to (as both `s0 == s1` and `s0.sym_eq(s1)` are valid spellings of the equality operation); in the short term, I intend to prevent excess guarding in this situation via unit tests; in the long term, making the equality operators return SymBool is probably the correct fix.
* ~~I modify TensorImpl to store SymBool for the `is_contiguous` fields and friends on `ExtraMeta`. In practice, this essentially meant reverting most of the changes from https://github.com/pytorch/pytorch/pull/85936 . In particular, the fields on ExtraMeta are no longer strongly typed; at the time I was particularly concerned about the giant lambda I was using as the setter getting a desynchronized argument order, but now that I have individual setters for each field the only "big list" of boolean arguments is in the constructor of ExtraMeta, which seems like an acceptable risk. The semantics of TensorImpl are now that we guard only when you actually attempt to access the contiguity of the tensor via, e.g., `is_contiguous`. By and large, the contiguity calculation in the implementations now needs to be duplicated (as the boolean version can short-circuit, but the SymBool version cannot); you should carefully review the duplicated new implementations. I typically use the `identity` template to disambiguate which version of the function I need, and rely on overloading to allow for implementation sharing. The changes to the `compute_` functions are particularly interesting; for most of the functions, I preserved their original non-symbolic implementation, and then introduced a new symbolic implementation that is branch-free (making use of our new SymBool operations). However, `compute_non_overlapping_and_dense` is special; see the next bullet.~~ This appears to cause performance problems, so I am leaving it to a follow-up PR.
* (Update: the Python-side pieces for this are still in this PR, but they are not wired up until later PRs.) While the contiguity calculations are relatively easy to write in a branch-free way, `compute_non_overlapping_and_dense` is not: it involves a sort on the strides (a simplified sketch of this eager computation follows this list). While in principle we could still make it go through by using a data-oblivious sorting network, this seems like too much complication for a field that is likely never used (because typically it will be obvious that a tensor is non-overlapping and dense, because the tensor is contiguous). So we take a different approach: instead of trying to trace through the logic that computes non-overlapping and dense, we introduce a new opaque operator IsNonOverlappingAndDenseIndicator which represents all of the compute that would have been done here. For technical reasons (Sympy does not easily allow defining custom functions that return booleans), this function returns an integer: 0 if `is_non_overlapping_and_dense` would have returned `False`, and 1 otherwise. The function only knows how to evaluate itself if all of its arguments are integers; otherwise it is left unevaluated. This means we can always guard on it (as `size_hint` will always be able to evaluate through it), but otherwise its insides are left a black box. We typically do NOT expect this custom function to show up in actual boolean expressions, because we will typically short-circuit it due to the tensor being contiguous. It's possible we should apply this treatment to all of the other `compute_` operations; more investigation is necessary. As a technical note, because this operator takes a pair of lists of SymInts, we need to support converting `ArrayRef<SymNode>` to Python, and I also unpack the pair of lists into a single flat list because I don't know whether Sympy operations can validly take lists of Sympy expressions as inputs. See, for example, `_make_node_sizes_strides`.
* On the Python side, we also introduce a SymBool class, and update SymNode to track bool as a valid pytype. There is some subtlety here: bool is a subclass of int, so one has to be careful about `isinstance` checks (in fact, in most cases I replaced `isinstance(x, int)` with `type(x) is int` for expressly this reason). Additionally, unlike C++, I do NOT define bitwise inverse on SymBool, because it does not do the correct thing when run on booleans, e.g., `~True` is `-2`. (For that matter, it doesn't do the right thing in C++ either, but at least in principle the compiler can warn you about it with `-Wbool-operation`, and so the rule in C++ is simple: only use the logical operations if the types are statically known to be SymBool.) Alas, logical negation is not overridable, so we have to introduce `sym_not`, which must be used in place of `not` whenever a SymBool can turn up. To avoid confusion with `__not__`, which might imply that `operator.__not__` is acceptable to use (it isn't), our magic method is called `__sym_not__`. The other bitwise operators, `&` and `|`, do the right thing with booleans and are acceptable to use.
* There is some annoyance working with booleans in Sympy. Unlike int and float, booleans live in their own algebra and support fewer operations than regular numbers. In particular, `sympy.expand` does not work on them. To get around this, I introduce `safe_expand`, which only calls `expand` on operations that are known to be expandable.
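
A minimal sketch of how the new C++ pieces are intended to compose, per the description above; the spellings here (`sym_eq`, the bitwise operators on SymBool) follow that description and may not match every detail of the final API:

```cpp
#include <c10/core/SymBool.h>
#include <c10/core/SymInt.h>

void symbool_sketch() {
  c10::SymInt a(5);
  c10::SymInt b(5);
  // sym_eq returns a SymBool rather than guarding immediately the way
  // `a == b` (which still returns bool in this PR) would on symbolic ints.
  c10::SymBool eq = a.sym_eq(b);
  // Bitwise operators spell the non-short-circuiting logical operations.
  c10::SymBool conj = eq & b.sym_eq(a);
  (void)conj;
}
```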

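For reference, a simplified sketch (not the exact eager implementation) of why `compute_non_overlapping_and_dense` resists a branch-free rewrite: it sorts the dimensions by stride and then branches on the sorted result:

```cpp
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Illustrative only; the real implementation differs in details
// (e.g. handling of zero-size dims).
bool non_overlapping_and_dense_sketch(
    const std::vector<int64_t>& sizes,
    const std::vector<int64_t>& strides) {
  std::vector<size_t> perm(sizes.size());
  std::iota(perm.begin(), perm.end(), 0);
  // Data-dependent sort: this is what a symbolic, branch-free version
  // would have to encode, e.g. as a data-oblivious sorting network.
  std::sort(perm.begin(), perm.end(), [&](size_t i, size_t j) {
    return strides[i] < strides[j];
  });
  int64_t expected = 1;
  for (size_t i : perm) {
    if (sizes[i] == 1) {
      continue; // size-1 dims don't constrain density
    }
    if (strides[i] != expected) {
      return false;
    }
    expected *= sizes[i];
  }
  return true;
}
```
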
TODO: this PR appears to greatly regress performance of symbolic reasoning. In particular, `python test/functorch/test_aotdispatch.py -k max_pool2d` performs really poorly with these changes. Need to investigate.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92149
Approved by: https://github.com/albanD, https://github.com/Skylion007
2023-01-21 02:21:56 +00:00

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type_base.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/attributes.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/ir/scope.h>
#include <torch/csrc/jit/ir/type_hashing.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/bailout_graph.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/passes/guard_elimination.h>
#include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/insert_guards.h>
#include <torch/csrc/jit/passes/liveness.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/lower_grad_of.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <torch/csrc/jit/passes/requires_grad_analysis.h>
#include <torch/csrc/jit/passes/restore_mutation.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/decomposition_registry.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/runtime/jit_trace.h>
#include <torch/csrc/jit/runtime/profiling_record.h>
#include <torch/csrc/jit/runtime/symbolic_script.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/jit.h>
#include <torch/script.h>
#include <onnx/onnx_pb.h>
#include <c10/util/Exception.h>
#include <c10/util/ThreadLocalDebugInfo.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <algorithm>
#include <cstddef>
#include <functional>
#include <iostream>
#include <memory>
#include <set>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace torch {
namespace jit {
inline c10::AliasAnalysisKind aliasAnalysisFromSchema() {
return c10::AliasAnalysisKind::FROM_SCHEMA;
}
template <typename T>
std::ostream& operator<<(std::ostream& out, const std::vector<T>& list) {
size_t i = 0;
out << "{";
for (auto&& e : list) {
if (i++ > 0)
out << ", ";
out << e;
}
out << "}";
return out;
}
TEST(InternedStringsTest, Basic) {
ASSERT_EQ(prim::Param, Symbol::prim("Param"));
ASSERT_EQ(prim::Return, Symbol::prim("Return"));
ASSERT_EQ(prim::Return.toUnqualString(), std::string("Return"));
ASSERT_EQ(prim::Return.toQualString(), std::string("prim::Return"));
Symbol newsym = Symbol::aten("__NEW_SYMBOL");
size_t symstart = newsym;
ASSERT_EQ(newsym.toQualString(), std::string("aten::__NEW_SYMBOL"));
// TODO: This test is a bit too close to the implementation details.
ASSERT_EQ(Symbol::aten("What"), symstart + 1);
ASSERT_EQ(Symbol::aten("What2"), symstart + 2);
ASSERT_EQ(Symbol::aten("What"), symstart + 1);
ASSERT_EQ(Symbol::aten("What2"), symstart + 2);
ASSERT_EQ(Symbol(symstart + 2).toUnqualString(), std::string("What2"));
}
TEST(FromQualStringTest, Basic) {
ASSERT_EQ(Symbol::fromQualString("prim::Param"), Symbol::prim("Param"));
ASSERT_EQ(Symbol::fromQualString("aten::mm"), Symbol::aten("mm"));
ASSERT_EQ(Symbol::fromQualString("onnx::LSTM"), Symbol::onnx("LSTM"));
ASSERT_EQ(Symbol::fromQualString("attr::value"), Symbol::attr("value"));
ASSERT_EQ(Symbol::fromQualString("scope::"), Symbol::scope(""));
ASSERT_EQ(Symbol::fromQualString("::").toUnqualString(), std::string(""));
ASSERT_EQ(
Symbol::fromQualString("::").ns().toQualString(),
std::string("namespaces::"));
ASSERT_EQ(
Symbol::fromQualString("new_ns::param").toUnqualString(),
std::string("param"));
ASSERT_EQ(
Symbol::fromQualString("new_ns::param").ns().toUnqualString(),
std::string("new_ns"));
ASSERT_EQ(
Symbol::fromQualString("new_ns::param").ns(),
Symbol::fromQualString("namespaces::new_ns"));
auto bad_inputs = {"scope", ":", ""};
for (auto input : bad_inputs) {
try {
Symbol::fromQualString(input);
ASSERT_TRUE(0);
} catch (const std::exception& c) {
}
}
}
TEST(THNNConvTest, Basic) {
std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W
std::vector<int64_t> kernel_size = {3, 5};
std::vector<int64_t> stride = {1, 2};
std::vector<int64_t> padding = {2, 1};
constexpr int out_channels = 5;
// make inputs
at::Tensor input = torch::randn(input_size);
at::Tensor weight = torch::randn(
{out_channels, input_size[1], kernel_size[0], kernel_size[1]});
at::Tensor bias = torch::randn({out_channels});
// run forward eagerly
at::Tensor output = at::_slow_conv2d_forward(
input, weight, kernel_size, bias, stride, padding);
// make grad_outputs
at::Tensor grad_output =
torch::randn_like(output, at::MemoryFormat::Preserve);
// run backward eagerly
at::Tensor grad_input, grad_weight, grad_bias;
std::tie(grad_input, grad_weight, grad_bias) = at::_slow_conv2d_backward(
grad_output,
input,
weight,
kernel_size,
stride,
padding,
{true, true, true});
// make JIT graph
auto graph = std::make_shared<Graph>();
auto ksz_val = graph->insertConstant(kernel_size);
auto kst_val = graph->insertConstant(stride);
auto pad_val = graph->insertConstant(padding);
auto inputg = graph->addInput("self");
auto weightg = graph->addInput("weight");
auto biasg = graph->addInput("bias");
Value* conv = graph->insert(
aten::_slow_conv2d_forward,
{inputg, weightg, ksz_val, biasg, kst_val, pad_val});
auto outputs = conv->node()->outputs();
for (auto output : outputs) {
graph->registerOutput(output);
}
LowerAllTuples(graph);
graph->lint();
// differentiate JIT graph
EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
ConstantPropagation(graph);
auto grad_spec = differentiate(graph);
LowerGradOf(*grad_spec.df);
// prepare JIT inputs / gradients
tensor_list tensors_in;
tensors_in.push_back(input);
tensors_in.push_back(weight);
tensors_in.push_back(bias);
tensor_list tensor_grads_in;
tensor_grads_in.push_back(grad_output);
// Get outputs from the interpreter
tensor_list tensors_out, tensor_grads_out;
std::tie(tensors_out, tensor_grads_out) =
runGradient(grad_spec, tensors_in, tensor_grads_in);
// prepare expected structs
tensor_list expected_tensors_out, expected_tensor_grads_out;
expected_tensors_out.push_back(output);
expected_tensor_grads_out.push_back(grad_input);
expected_tensor_grads_out.push_back(grad_weight);
expected_tensor_grads_out.push_back(grad_bias);
// Compare results
assertAllClose(tensors_out, expected_tensors_out);
assertAllClose(tensor_grads_out, expected_tensor_grads_out);
}
TEST(ATenNativeBatchNormTest, Basic) {
// aten::native_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor
// running_mean, Tensor running_var, bool training, float momentum, float eps)
// -> (Tensor, Tensor, Tensor)
std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W
bool training = true;
float momentum = 0.9;
float eps = 1e-5;
// make inputs
at::Tensor input = torch::randn(input_size);
at::Tensor weight = torch::randn({input_size[1]});
at::Tensor bias = torch::randn({input_size[1]});
at::Tensor running_mean = torch::randn({input_size[1]});
at::Tensor running_var = torch::randn({input_size[1]});
// running_mean and running_var are changed in-place, so clone and send them
at::Tensor running_mean_eager = running_mean.clone();
at::Tensor running_var_eager = running_var.clone();
at::Tensor running_mean_jit = running_mean.clone();
at::Tensor running_var_jit = running_var.clone();
// run forward eagerly
at::Tensor output, savemean, saveinvstd;
std::tie(output, savemean, saveinvstd) = at::native_batch_norm(
input,
weight,
bias,
running_mean_eager,
running_var_eager,
training,
momentum,
eps);
// make grad_outputs
at::Tensor grad_output =
torch::randn_like(output, at::MemoryFormat::Preserve);
at::Tensor grad_savemean =
torch::zeros_like(savemean, at::MemoryFormat::Preserve);
at::Tensor grad_saveinvstd =
torch::zeros_like(saveinvstd, at::MemoryFormat::Preserve);
// run backward eagerly
at::Tensor grad_input, grad_weight, grad_bias;
// aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor
// weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor
// save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor,
// Tensor, Tensor)
std::tie(grad_input, grad_weight, grad_bias) = at::native_batch_norm_backward(
grad_output,
input,
weight,
running_mean_eager,
running_var_eager,
savemean,
saveinvstd,
training,
eps,
{true, true, true});
// make JIT graph
auto graph = std::make_shared<Graph>();
auto training_val = graph->insertConstant(IValue(training));
auto momentum_val = graph->insertConstant(IValue(momentum));
auto eps_val = graph->insertConstant(IValue(eps));
auto inputg = graph->addInput("self");
auto weightg = graph->addInput("weight");
auto biasg = graph->addInput("bias");
auto running_meang = graph->addInput("running_mean");
auto running_varg = graph->addInput("running_var");
Value* bn = graph->insert(
aten::native_batch_norm,
{inputg,
weightg,
biasg,
running_meang,
running_varg,
training_val,
momentum_val,
eps_val});
auto outputs = bn->node()->outputs();
for (auto output : outputs) {
graph->registerOutput(output);
}
LowerAllTuples(graph);
graph->lint();
// differentiate JIT graph
EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
ConstantPropagation(graph);
auto grad_spec = differentiate(graph);
LowerGradOf(*grad_spec.df);
// prepare JIT inputs / gradients
tensor_list tensors_in;
tensors_in.push_back(input);
tensors_in.push_back(weight);
tensors_in.push_back(bias);
tensors_in.push_back(running_mean_jit);
tensors_in.push_back(running_var_jit);
tensor_list tensor_grads_in;
tensor_grads_in.push_back(grad_output);
tensor_grads_in.push_back(grad_savemean);
tensor_grads_in.push_back(grad_saveinvstd);
// Get outputs from the interpreter
tensor_list tensors_out, tensor_grads_out;
std::tie(tensors_out, tensor_grads_out) =
runGradient(grad_spec, tensors_in, tensor_grads_in);
// prepare expected structs
tensor_list expected_tensors_out, expected_tensor_grads_out;
expected_tensors_out.push_back(output);
expected_tensors_out.push_back(savemean);
expected_tensors_out.push_back(saveinvstd);
expected_tensors_out.push_back(running_mean_eager);
expected_tensors_out.push_back(running_var_eager);
expected_tensor_grads_out.push_back(grad_input);
expected_tensor_grads_out.push_back(grad_weight);
expected_tensor_grads_out.push_back(grad_bias);
tensors_out.push_back(running_mean_jit);
tensors_out.push_back(running_var_jit);
// Compare results
assertAllClose(tensors_out, expected_tensors_out);
assertAllClose(tensor_grads_out, expected_tensor_grads_out);
}
TEST(CustomFusionTest, Basic) {
#if defined(FBCODE_CAFFE2)
return;
#endif
auto graph_string = R"IR(
    graph(%0 : Float(2, 3, 4),
          %1 : Float(2, 3, 4)):
      %2 : Tensor = aten::mul(%0, %1)
      %3 : Tensor = aten::mul(%2, %0)
      return (%3))IR";
auto g = std::make_shared<Graph>();
torch::jit::parseIR(graph_string, g.get());
torch::jit::overrideCanFuseOnCPU(true);
CustomFuseGraph(
g,
[](Node* n) { return n->kind() != prim::Param; },
Symbol::fromQualString("prim::FusionGroup"));
torch::jit::overrideCanFuseOnCPU(false);
const auto& nodes = g->nodes();
auto fusion_group =
std::find_if(nodes.begin(), nodes.end(), [](const Node* node) {
return node->kind() == Symbol::fromQualString("prim::FusionGroup");
});
AT_ASSERT(fusion_group != nodes.end());
auto subgraph = fusion_group->g(attr::Subgraph);
auto hits = 0;
// two multiplications
for (const auto& n : subgraph->nodes()) {
(void)n;
hits++;
}
AT_ASSERT(hits == 2);
}
TEST(CustomFusionTest, NestedBlocks) {
#if defined(FBCODE_CAFFE2)
return;
#endif
auto graph_string = R"IR(
    graph(%0 : Float(2, 3, 4),
          %1 : Float(2, 3, 4),
          %2 : Float(2, 3, 4)):
      %3 : int = prim::Constant[value=1]()
      %4 : Tensor = prim::If(%2)
        block0():
          %5 : Tensor = aten::mul(%0, %2)
          %6 : Tensor = aten::mul(%5, %1)
          -> (%6)
        block1():
          %7 : Tensor = aten::add(%0, %2, %3)
          %8 : Tensor = aten::add(%7, %1, %3)
          -> (%8)
      %9 : Tensor = aten::add(%4, %2, %3)
      return (%4))IR";
auto g = std::make_shared<Graph>();
torch::jit::parseIR(graph_string, g.get());
CustomFuseGraph(
g,
[](Node* n) { return n->kind() == aten::mul; },
Symbol::fromQualString("prim::FusionGroup"));
// Could be done in more efficient ways, but this is only a test.
std::function<bool(const Block*, Symbol)> dfs = [&](const Block* b,
Symbol s) {
for (auto node : b->nodes()) {
if (node->kind() == s)
return true;
for (auto nested_b : node->blocks())
if (dfs(nested_b, s))
return true;
}
return false;
};
AT_ASSERT(dfs(g->block(), Symbol::fromQualString("prim::FusionGroup")));
}
static const auto cf_examples = R"JIT(
  def if_test(a, b):
    # FIXME: use 0 instead of a.
    # c = 0
    c = a
    if bool(a < b):
      c = b
    else:
      c = a
    return c
  def if_one(a, b):
    c = b
    if bool(a < b):
      c = a
    return c
  def while_test(a, i):
    while bool(i < 3):
      a *= a
      i += 1
    return a
)JIT";
TEST(ControlFlowTest, Basic) {
auto cu = compile(cf_examples);
auto run = [&](const std::string& name, std::vector<IValue> stack) {
auto graph = toGraphFunction(cu->get_function(name)).graph();
Code code(graph, "");
InterpreterState interp(code);
interp.run(stack);
return stack;
};
auto L = [](int64_t l) { return IValue(scalar_to_tensor(at::Scalar(l))); };
auto V = [](IValue t) { return std::move(t).toTensor().item<int64_t>(); };
auto run_binary = [&](const std::string& name, int64_t a, int64_t b) {
return V(run(name, {L(a), L(b)})[0]);
};
ASSERT_EQ(2, run_binary("if_test", 1, 2));
ASSERT_EQ(3, run_binary("if_test", 3, 2));
ASSERT_EQ(2, run_binary("if_one", 2, 3));
ASSERT_EQ(2, run_binary("if_one", 3, 2));
ASSERT_EQ(256, run_binary("while_test", 2, 0));
}
#if defined(__has_feature)
#if __has_feature(address_sanitizer)
#define HAS_ASANUBSAN 1
#endif
#endif
#ifndef HAS_ASANUBSAN
// This test fails vptr UBSAN checks
TEST(ProtoTest, Basic) {
::ONNX_NAMESPACE::ModelProto proto;
proto.set_producer_name("foo");
}
#endif
// test a few features that are not directly used in schemas yet
TEST(SchemaParserTest, NestedArrays) {
// nested arrays
auto s = parseSchema("at::what(int[][4] foo) -> ()");
ASSERT_TRUE(s.arguments().at(0).N() == 4);
ASSERT_TRUE(IntType::get()->isSubtypeOf(*s.arguments()
.at(0)
.type()
->expectRef<ListType>()
.getElementType()
->expectRef<ListType>()
.getElementType()));
auto s2 = parseSchema("at::what(int[][] foo) -> ()");
ASSERT_TRUE(IntType::get()->isSubtypeOf(*s2.arguments()
.at(0)
.type()
->expectRef<ListType>()
.getElementType()
->expectRef<ListType>()
.getElementType()));
}
TEST(SchemaParserTest, OutVariant) {
auto schema_with_out = parseSchema(
"at::foo(Tensor self, *, Tensor(a!) f, Tensor(b!) l) -> (Tensor(a!) f, Tensor(b!) l)");
ASSERT_TRUE(schema_with_out.arguments().at(1).is_out());
ASSERT_TRUE(schema_with_out.arguments().at(2).is_out());
auto schema_without_out =
parseSchema("at::foo(Tensor self, *, int scalar) -> (int)");
for (const auto& arg : schema_without_out.arguments()) {
ASSERT_TRUE(!arg.is_out());
}
auto schema_with_is_write = parseSchema(
"aten::ne_.Scalar(Tensor(a!) self, Scalar other) -> (Tensor(a!))");
for (const auto& arg : schema_with_is_write.arguments()) {
ASSERT_TRUE(!arg.is_out());
}
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(SchemaParserTest, NamedReturns) {
// named returns
parseSchema("at::what(Tensor! i_will_be_written_to) -> ()");
auto s3 =
parseSchema("at::what() -> (Tensor the_return, Tensor the_return2)");
ASSERT_TRUE(s3.returns().at(0).name() == "the_return");
ASSERT_TRUE(s3.returns().at(1).name() == "the_return2");
}
TEST(SchemaParserTest, Futures) {
// futures
auto s4 = parseSchema("at::what(Future(int) foo) -> ()");
ASSERT_TRUE(IntType::get()->isSubtypeOf(
*s4.arguments().at(0).type()->expectRef<FutureType>().getElementType()));
}
TEST(SchemaParserTest, AnnotatedAliasSets) {
// test tensor with annotated alias sets
parseSchema("at::what(Tensor(a) foo) -> (Tensor(a))");
}
TEST(SchemaParserTest, TensorListAnnotatedAliasSets) {
const auto s = parseSchema(
"at::foo(Tensor(a!) self, Tensor(b!)[] out)"
" -> ()");
const AliasInfo* selfAliasInfo = s.arguments().at(0).alias_info();
const AliasInfo* outAliasInfo = s.arguments().at(1).alias_info();
ASSERT_TRUE(
selfAliasInfo->beforeSets() ==
std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
ASSERT_TRUE(selfAliasInfo->isWrite());
ASSERT_TRUE(outAliasInfo->isWrite());
ASSERT_TRUE(outAliasInfo->beforeSets().empty());
ASSERT_EQ(outAliasInfo->containedTypes().size(), 1);
auto containedType = outAliasInfo->containedTypes()[0];
ASSERT_TRUE(containedType.isWrite());
ASSERT_TRUE(
containedType.beforeSets() ==
std::unordered_set<Symbol>{Symbol::fromQualString("alias::b")});
}
TEST(SchemaParserTest, AnnotatedAliasWithoutBeforeSet) {
EXPECT_THAT(
[]() { parseSchema("at::foo(Tensor(!) self) -> Tensor"); },
::testing::Throws<std::runtime_error>(::testing::Property(
&std::runtime_error::what,
::testing::HasSubstr("expected ident but found '!' here"))));
}
TEST(SchemaParserTest, BeforeAfterSets) {
const auto s = parseSchema(
"at::what(Tensor(b|c)[](a!) list, Tensor(c) element)"
" -> (Tensor(b|c)[](a!))");
// The list itself is annotated with `a`
const AliasInfo* aliasInfo = s.arguments().at(0).alias_info();
ASSERT_NE(aliasInfo, nullptr);
ASSERT_TRUE(
aliasInfo->beforeSets() ==
std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
ASSERT_TRUE(aliasInfo->isWrite());
// Check the contained types
ASSERT_TRUE(!aliasInfo->containedTypes().empty());
const auto& containedAliasInfo = aliasInfo->containedTypes()[0];
const auto expected = std::unordered_set<Symbol>{
Symbol::fromQualString("alias::b"),
Symbol::fromQualString("alias::c"),
};
ASSERT_TRUE(containedAliasInfo.beforeSets() == expected);
ASSERT_TRUE(containedAliasInfo.afterSets() == expected);
ASSERT_FALSE(containedAliasInfo.isWrite());
}
TEST(SchemaParserTest, BeforeAfterSets2) {
const auto s = parseSchema(
"at::what(Tensor(b -> b|c)[](a!) list, Tensor(c) element)"
" -> (Tensor(b|c)[](a!))");
// The list itself is annotated with `a`
const AliasInfo* aliasInfo = s.arguments().at(0).alias_info();
ASSERT_NE(aliasInfo, nullptr);
ASSERT_EQ(
aliasInfo->beforeSets(),
std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
ASSERT_EQ(
aliasInfo->afterSets(),
std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
ASSERT_TRUE(aliasInfo->isWrite());
ASSERT_EQ(aliasInfo->containedTypes().size(), 1);
// Check the contained types
ASSERT_TRUE(!aliasInfo->containedTypes().empty());
const auto& containedAliasInfo = aliasInfo->containedTypes()[0];
const auto expectedBefore = std::unordered_set<Symbol>{
Symbol::fromQualString("alias::b"),
};
const auto expectedAfter = std::unordered_set<Symbol>{
Symbol::fromQualString("alias::b"), Symbol::fromQualString("alias::c")};
ASSERT_TRUE(containedAliasInfo.beforeSets() == expectedBefore);
ASSERT_TRUE(containedAliasInfo.afterSets() == expectedAfter);
ASSERT_FALSE(containedAliasInfo.isWrite());
}
TEST(TopologicalIndexTest, Basic) {
Graph graph;
auto node1 = graph.create(prim::AutogradZero);
auto node2 = graph.create(prim::AutogradZero);
auto node3 = graph.create(prim::AutogradZero);
auto node4 = graph.create(prim::AutogradZero);
graph.appendNode(node4);
graph.prependNode(node1);
node2->insertAfter(node1);
node3->insertBefore(node4);
// nodes should be in numerical order
ASSERT_TRUE(node1->isBefore(node2));
ASSERT_TRUE(node1->isBefore(node3));
ASSERT_TRUE(node1->isBefore(node4));
ASSERT_TRUE(node2->isAfter(node1));
ASSERT_TRUE(node2->isBefore(node3));
ASSERT_TRUE(node2->isBefore(node4));
ASSERT_FALSE(node3->isBefore(node1));
ASSERT_FALSE(node3->isBefore(node2));
ASSERT_FALSE(node3->isAfter(node4));
// Built up a block structure
// node3
// /\ ...
// A B block1
// \ ...
// C block2
auto block1 = node3->addBlock();
auto A = graph.create(prim::AutogradZero);
block1->appendNode(A);
auto B = graph.create(prim::AutogradZero);
block1->appendNode(B);
auto block2 = B->addBlock();
auto C = graph.create(prim::AutogradZero);
block2->appendNode(C);
// Check isAfter on different block levels
ASSERT_TRUE(node1->isBefore(A));
ASSERT_TRUE(A->isBefore(B));
ASSERT_TRUE(A->isBefore(C));
// make sure things don't blow up on deletions
node2->destroy();
auto node2p = graph.create(prim::AutogradZero);
node2p->insertAfter(node1);
ASSERT_TRUE(node1->isBefore(node2p));
ASSERT_TRUE(node2p->isBefore(node3));
}
TEST(TopologicalIndexTest, Reindex) {
// Induce reindexing to test that path
Graph graph;
std::map<size_t, Node*> nodes;
auto anchor = graph.create(prim::AutogradZero);
graph.appendNode(anchor);
// Inserting to the same place a lot will trigger reindexing
for (auto i = 0; i < 100; ++i) {
auto n = graph.create(prim::AutogradZero);
n->insertAfter(anchor);
nodes[i] = n;
}
// Nodes should be in reverse order
for (auto i = 0; i < 100; ++i) {
for (auto j = i + 1; j < 100; ++j) {
ASSERT_TRUE(nodes[i]->isAfter(nodes[j]));
}
}
}
at::Tensor invokeTestRecordFunction(at::Tensor& t) {
RECORD_FUNCTION("test", std::vector<c10::IValue>({t}));
auto t2 = t.pow(2);
return t2;
}
static const auto invokeTestRecordFunction_JIT = R"JIT(
  def foo(self, t):
    t2 = t.pow(2)
    return t2
  def forward(self, t):
    return self.foo(t)
)JIT";
at::Tensor invokeTestRecordFunctionJIT(at::Tensor& t) {
RECORD_FUNCTION("test", std::vector<c10::IValue>({t}));
auto module = std::make_shared<script::Module>(
"RecordFunctionTestModule", std::make_shared<script::CompilationUnit>());
module->define(invokeTestRecordFunction_JIT);
return module->forward({t}).toTensor();
}
using TracedTestValues =
std::vector<std::tuple<std::string, std::vector<std::vector<int64_t>>>>;
void checkTracedInputs(const TracedTestValues& inputs) {
bool found_test = false;
bool found_pow = false;
bool found_mul = false;
for (const auto& input : inputs) {
const auto& fn = std::get<0>(input);
const auto& sizes = std::get<1>(input);
if (fn == "test") {
found_test = true;
TORCH_CHECK(sizes.size() == 1);
TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
} else if (fn == "aten::pow") {
found_pow = true;
TORCH_CHECK(sizes.size() == 2);
TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
TORCH_CHECK(sizes[1].empty());
} else if (fn == "aten::mul") {
found_mul = true;
TORCH_CHECK(sizes.size() > 1);
TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
}
}
TORCH_CHECK(found_test);
TORCH_CHECK(found_pow);
TORCH_CHECK(found_mul);
}
void checkTracedOutputs(const TracedTestValues& outputs) {
bool found_test = false;
bool found_pow = false;
bool found_mul = false;
for (const auto& output : outputs) {
const auto& fn = std::get<0>(output);
const auto& sizes = std::get<1>(output);
if (fn == "test") {
found_test = true;
TORCH_CHECK(sizes.empty());
} else if (fn == "aten::pow") {
found_pow = true;
TORCH_CHECK(sizes.size() == 1);
TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
} else if (fn == "aten::mul") {
found_mul = true;
TORCH_CHECK(sizes.size() == 1);
TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
}
}
TORCH_CHECK(found_test);
TORCH_CHECK(found_pow);
TORCH_CHECK(found_mul);
}
static bool bad_scope = false;
template <RecordScope scope, size_t* cnt>
std::unique_ptr<at::ObserverContext> checkScopeCallback(
const at::RecordFunction& fn) {
if (fn.scope() == scope) {
++(*cnt);
} else {
bad_scope = true;
}
return nullptr;
}
template <RecordScope scope, size_t* cnt>
void pushScopedCallback() {
at::addGlobalCallback(
at::RecordFunctionCallback(checkScopeCallback<scope, cnt>)
.scopes({scope}));
}
// These cannot be function-local because that would prohibit them
// from being used as template arguments prior to C++17.
static size_t fun_cnt;
static size_t ts_fun_cnt;
static size_t user_scope_cnt;
void checkScopeCallbacks() {
static bool found_function_scope;
static bool found_method_scope;
static bool found_user_scope;
found_function_scope = false;
found_method_scope = false;
found_user_scope = false;
at::addGlobalCallback(at::RecordFunctionCallback(
[](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
if (fn.scope() == at::RecordScope::FUNCTION &&
std::string(fn.name()) == "test_function") {
found_function_scope = true;
}
if (fn.scope() == at::RecordScope::TORCHSCRIPT_FUNCTION &&
std::string(fn.name()) == "test_method") {
found_method_scope = true;
}
if (fn.scope() == at::RecordScope::USER_SCOPE &&
std::string(fn.name()) == "test_user_scope") {
found_user_scope = true;
}
return nullptr;
}));
bad_scope = false;
fun_cnt = 0;
pushScopedCallback<at::RecordScope::FUNCTION, &fun_cnt>();
ts_fun_cnt = 0;
pushScopedCallback<at::RecordScope::TORCHSCRIPT_FUNCTION, &ts_fun_cnt>();
user_scope_cnt = 0;
pushScopedCallback<at::RecordScope::USER_SCOPE, &user_scope_cnt>();
TORCH_CHECK(at::hasCallbacks());
{
RECORD_TORCHSCRIPT_FUNCTION("test_method", {});
{ RECORD_FUNCTION("test_function", {}); }
{ RECORD_USER_SCOPE("test_user_scope"); }
}
TORCH_CHECK(!bad_scope);
TORCH_CHECK(fun_cnt == 1);
TORCH_CHECK(ts_fun_cnt == 1);
TORCH_CHECK(user_scope_cnt == 1);
TORCH_CHECK(found_function_scope);
TORCH_CHECK(found_method_scope);
TORCH_CHECK(found_user_scope);
}
static TracedTestValues traced_inputs;
static TracedTestValues traced_outputs;
static std::unordered_set<std::string> ts_input_names;
static std::unordered_set<std::string> ts_output_names;
std::unique_ptr<at::ObserverContext> tracedInputsCallback(
const RecordFunction& fn) {
if (fn.scope() == RecordScope::FUNCTION) {
auto inputs = fn.inputs();
std::vector<std::vector<int64_t>> sizes;
for (const auto& input : inputs) {
if (input.isTensor()) {
sizes.push_back(input.toTensor().sizes().vec());
} else if (input.isScalar()) {
// NOLINTNEXTLINE(modernize-use-emplace)
sizes.push_back(std::vector<int64_t>());
}
}
traced_inputs.push_back(std::make_tuple(fn.name(), sizes));
} else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) {
ts_input_names.insert(fn.name());
}
return nullptr;
}
void tracedOutputsCallback(const RecordFunction& fn, ObserverContext* ctx_ptr) {
if (fn.scope() == RecordScope::FUNCTION) {
auto outputs = fn.outputs();
std::vector<std::vector<int64_t>> sizes;
for (const auto& output : outputs) {
if (output.isTensor()) {
sizes.push_back(output.toTensor().sizes().vec());
} else if (output.isScalar()) {
sizes.emplace_back();
}
}
traced_outputs.push_back(std::make_tuple(fn.name(), sizes));
} else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) {
ts_output_names.insert(fn.name());
}
}
TEST(RecordFunctionTest, TracedTestInputsOutputs) {
// disabling the inlining of method calls
GraphOptimizerEnabledGuard opt_guard(false);
// [(fn, [[sizes], [sizes], ...]), ...]
addGlobalCallback(
RecordFunctionCallback(tracedInputsCallback, tracedOutputsCallback)
.needsInputs(true)
.needsOutputs(true));
TracedTestValues eager_inputs, eager_outputs, jit_inputs, jit_outputs;
{
auto t = torch::randn({1, 2, 3}, at::kCPU);
t.set_requires_grad(true);
auto t2 = invokeTestRecordFunction(t);
t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
eager_inputs = traced_inputs;
eager_outputs = traced_outputs;
traced_inputs.clear();
traced_outputs.clear();
TORCH_CHECK(ts_input_names.empty());
TORCH_CHECK(ts_output_names.empty());
t = torch::randn({1, 2, 3}, at::kCPU);
t.set_requires_grad(true);
t2 = invokeTestRecordFunctionJIT(t);
t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
jit_inputs = traced_inputs;
jit_outputs = traced_outputs;
traced_inputs.clear();
traced_outputs.clear();
}
TORCH_CHECK(ts_input_names.find("forward") != ts_input_names.end());
TORCH_CHECK(ts_input_names.find("foo") != ts_input_names.end());
TORCH_CHECK(ts_output_names.find("forward") != ts_output_names.end());
TORCH_CHECK(ts_output_names.find("foo") != ts_output_names.end());
checkTracedInputs(eager_inputs);
checkTracedOutputs(eager_outputs);
checkTracedInputs(jit_inputs);
checkTracedOutputs(jit_outputs);
at::clearCallbacks();
}
static int sampled_cb_ctr = 0;
std::unique_ptr<ObserverContext> sampledCallback(const RecordFunction& fn) {
if (std::string(fn.name()) == "test") {
++sampled_cb_ctr;
}
return nullptr;
}
static int non_sampled_cb_ctr = 0;
std::unique_ptr<ObserverContext> nonSampledCallback(const RecordFunction& fn) {
if (std::string(fn.name()) == "test") {
++non_sampled_cb_ctr;
}
return nullptr;
}
TEST(RecordFunctionTest, SampledCallbacks) {
// disabling the inlining of method calls
GraphOptimizerEnabledGuard opt_guard(false);
// test sampled callbacks
sampled_cb_ctr = 0;
auto setup_sampled_callback = [](double sampling_prob) {
return addGlobalCallback(
RecordFunctionCallback(sampledCallback).samplingProb(sampling_prob));
};
addGlobalCallback(RecordFunctionCallback(nonSampledCallback));
auto handle = setup_sampled_callback(0.5);
auto run_test_function = []() {
auto t = torch::randn({1, 2, 3}, at::kCPU);
for (auto k = 0; k < 1000; k++) {
invokeTestRecordFunction(t);
}
};
run_test_function();
TORCH_CHECK(non_sampled_cb_ctr == 1000);
TORCH_CHECK(sampled_cb_ctr > 0 && sampled_cb_ctr < 1000);
sampled_cb_ctr = 0;
removeCallback(handle);
handle = setup_sampled_callback(0.0);
run_test_function();
TORCH_CHECK(non_sampled_cb_ctr == 2000);
TORCH_CHECK(sampled_cb_ctr == 0);
sampled_cb_ctr = 0;
removeCallback(handle);
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
handle = setup_sampled_callback(1.0);
run_test_function();
TORCH_CHECK(non_sampled_cb_ctr == 3000);
TORCH_CHECK(sampled_cb_ctr == 1000);
clearCallbacks();
// test the scope of the callbacks
checkScopeCallbacks();
clearCallbacks();
}
TEST(RecordFunctionTest, RecordFunctionGuard) {
// disabling the inlining of method calls
GraphOptimizerEnabledGuard opt_guard(false);
static std::vector<std::string> fn_names;
static std::mutex guard_mtx;
// check record function guard
addGlobalCallback(RecordFunctionCallback(
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
std::lock_guard<std::mutex> lock(guard_mtx);
// NOLINTNEXTLINE(modernize-use-emplace)
fn_names.push_back(fn.name());
return nullptr;
}));
{
RecordFunctionGuard g1(false);
{
RECORD_USER_SCOPE("A");
{
RecordFunctionGuard g2(true);
RECORD_USER_SCOPE("B");
{
DisableRecordFunctionGuard g3;
RECORD_USER_SCOPE("C");
}
}
{ RECORD_USER_SCOPE("D"); }
}
}
TORCH_CHECK(fn_names.size() == 1);
TORCH_CHECK(fn_names[0] == "B");
clearCallbacks();
}
static std::vector<size_t> ids;
template <size_t id>
auto add_remove_test_add_cb() {
return addGlobalCallback(RecordFunctionCallback(
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
ids.push_back(id);
return nullptr;
}));
}
TEST(RecordFunctionTest, Callbacks) {
// disabling the inlining of method calls
GraphOptimizerEnabledGuard opt_guard(false);
auto h1 = add_remove_test_add_cb<1>();
add_remove_test_add_cb<2>();
auto h3 = add_remove_test_add_cb<3>();
{ RECORD_USER_SCOPE("test"); }
TORCH_CHECK(ids.size() == 3);
TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
TORCH_CHECK(std::find(ids.begin(), ids.end(), 3) != ids.end());
ids.clear();
removeCallback(h1);
{ RECORD_USER_SCOPE("test"); }
TORCH_CHECK(ids.size() == 2);
TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
TORCH_CHECK(std::find(ids.begin(), ids.end(), 3) != ids.end());
ids.clear();
removeCallback(h3);
{ RECORD_USER_SCOPE("test"); }
TORCH_CHECK(ids.size() == 1);
TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
clearCallbacks();
// thread local / global callbacks
ids.clear();
add_remove_test_add_cb<1>();
{ RECORD_USER_SCOPE("test"); }
TORCH_CHECK(ids.size() == 1);
TORCH_CHECK(ids[0] == 1);
ids.clear();
auto th = std::thread([]() {
addThreadLocalCallback(RecordFunctionCallback(
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
ids.push_back(2);
return nullptr;
}));
{ RECORD_USER_SCOPE("test_thread"); }
});
th.join();
TORCH_CHECK(ids.size() == 2);
TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
ids.clear();
{ RECORD_USER_SCOPE("test"); }
TORCH_CHECK(ids.size() == 1);
TORCH_CHECK(ids[0] == 1);
ids.clear();
clearCallbacks();
// START: thread local / global context check callbacks
struct TestContext : public ObserverContext {
int a{0};
std::string b;
};
ids.clear();
{ // START: global test
addGlobalCallback(RecordFunctionCallback(
[](const RecordFunction&
/* unused */) -> std::unique_ptr<at::ObserverContext> {
auto ctx = std::make_unique<TestContext>();
ctx->a = 123;
ctx->b = "test_str";
ids.push_back(1);
return ctx;
},
[](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
TORCH_CHECK(ctx_ptr != nullptr);
TORCH_CHECK(ctx->a == 123);
TORCH_CHECK(ctx->b == "test_str");
}));
{ RECORD_USER_SCOPE("test"); }
TORCH_CHECK(ids.size() == 1);
TORCH_CHECK(ids[0] == 1);
ids.clear();
} // END: global test
{ // START: thread local test
auto ctx_th = std::thread([]() {
const std::string test_str = "test thread str";
addThreadLocalCallback(RecordFunctionCallback(
[](const RecordFunction&
/* unused */) -> std::unique_ptr<at::ObserverContext> {
auto ctx = std::make_unique<TestContext>();
ctx->a = 234;
ctx->b = "test_thread_str";
ids.push_back(2);
return ctx;
},
[](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
TORCH_CHECK(ctx_ptr != nullptr);
TORCH_CHECK(ctx->a == 234);
TORCH_CHECK(ctx->b == "test_thread_str");
}));
// Will call both global and thread local callbacks.
{ RECORD_USER_SCOPE("test_thread"); }
});
ctx_th.join();
TORCH_CHECK(ids.size() == 2);
TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
ids.clear();
} // END: thread local test
clearCallbacks();
}
TEST(RecordFunctionTest, ShouldRun) {
// disabling the inlining of method calls
GraphOptimizerEnabledGuard opt_guard(false);
static bool ran = false;
auto handle = addGlobalCallback(RecordFunctionCallback(
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
ran = true;
return nullptr;
}));
{ RECORD_USER_SCOPE("test"); }
EXPECT_TRUE(ran) << "first run didn't happen";
ran = false;
disableCallback(handle);
{ RECORD_USER_SCOPE("test"); }
EXPECT_FALSE(ran) << "second run happened but shouldn't have";
ran = false;
reenableCallback(handle);
{ RECORD_USER_SCOPE("test"); }
EXPECT_TRUE(ran) << "run after re-enable didn't happen";
ran = false;
clearCallbacks();
}
TEST(RecordFunctionTest, Basic) {
// disabling the inlining of method calls
GraphOptimizerEnabledGuard opt_guard(false);
static std::string recorded_op;
static bool has_ids = false;
// test propagation of TLS callbacks
std::thread t([]() {
RecordFunctionGuard enable_rec_fn;
auto handle = addThreadLocalCallback(RecordFunctionCallback(
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
recorded_op = fn.name();
return nullptr;
}));
ThreadLocalState state;
std::thread t_child([state]() {
ThreadLocalStateGuard g_tls(state);
RECORD_USER_SCOPE("test_in_thread");
});
t_child.join();
EXPECT_EQ(recorded_op, "test_in_thread");
removeCallback(handle);
});
t.join();
clearCallbacks();
// test set ids
addGlobalCallback(
RecordFunctionCallback(
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
has_ids = fn.handle() > 0;
return nullptr;
})
.needsIds(true));
{ RECORD_USER_SCOPE("test"); }
TORCH_CHECK(has_ids);
clearCallbacks();
has_ids = false;
addGlobalCallback(RecordFunctionCallback(
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
has_ids = fn.handle() > 0;
return nullptr;
}));
{ RECORD_USER_SCOPE("test"); }
TORCH_CHECK(!has_ids);
clearCallbacks();
}
TEST(RecordFunctionTest, OperatorNameOverload) {
static std::set<std::string> operator_names;
at::addGlobalCallback(at::RecordFunctionCallback(
[](const at::RecordFunction& fn)
-> std::unique_ptr<at::ObserverContext> {
c10::optional<c10::OperatorName> op_name =
fn.operator_name();
if (op_name.has_value()) {
operator_names.insert(c10::toString(*op_name));
} else {
operator_names.insert("No Operator Name");
}
return nullptr;
})
.scopes({at::RecordScope::FUNCTION}));
auto t = torch::randn({1, 2, 3}, at::kCPU);
t.set_requires_grad(false);
auto t2 = t.pow(2);
at::clearCallbacks();
EXPECT_TRUE(operator_names.count("No Operator Name") == 0)
<< "Expected that all traced operators had an associated OperatorName object";
EXPECT_TRUE(operator_names.count("aten::randn") == 1)
<< "Expected aten::randn to have been called and recorded, but it was not";
EXPECT_TRUE(operator_names.count("aten::pow.Tensor_Scalar") == 1)
<< "Expected aten::pow.Tensor_Scalar to have been called and recorded, but it was not";
}
class TestThreadLocalDebugInfo : public c10::DebugInfoBase {
public:
int getModelId() const {
return model_id_;
}
void setModelId(int model_id) {
model_id_ = model_id;
}
// NOLINTNEXTLINE(modernize-use-override,modernize-use-equals-default)
virtual ~TestThreadLocalDebugInfo() {}
private:
int model_id_ = 0;
};
void checkDebugInfo(c10::DebugInfoKind kind, int model_id) {
auto* debug_info = c10::ThreadLocalDebugInfo::get(kind);
TORCH_CHECK(debug_info != nullptr);
auto* test_debug_info = dynamic_cast<TestThreadLocalDebugInfo*>(debug_info);
TORCH_CHECK(test_debug_info != nullptr);
TORCH_CHECK(test_debug_info->getModelId() == model_id);
}
TEST(ThreadLocalDebugInfoTest, Basic) {
static std::atomic<bool> done{false};
TORCH_CHECK(
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
auto debug_info = std::make_shared<TestThreadLocalDebugInfo>();
debug_info->setModelId(42);
{
c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
}
// check that thread local debug info is propagated through fork calls
TORCH_CHECK(
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
{
c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
at::launch([]() {
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
done = true;
});
}
while (!done) {
}
// check that thread local debug info is propagated through backward pass
TORCH_CHECK(
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
done = false;
auto handle = addGlobalCallback(RecordFunctionCallback(
[](const RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
done = true;
return nullptr;
}));
{
c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
auto t = torch::randn({1, 2, 3}, at::kCPU);
t.set_requires_grad(true);
auto t2 = t.pow(2);
t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
}
removeCallback(handle);
TORCH_CHECK(done);
// check nested debug info
TORCH_CHECK(
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
{
c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
{
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
{
auto debug_info = std::make_shared<TestThreadLocalDebugInfo>();
debug_info->setModelId(314);
c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO_2, debug_info);
{
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314);
done = false;
at::launch([]() {
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314);
done = true;
});
while (!done) {
}
}
}
}
}
}
TEST(TestSymIntArrayRef, BasicConversion) {
const size_t X = 2, Y = 4, Z = 5;
std::vector<int64_t> tgt_size_v{2, 4, 5};
std::vector<c10::SymInt> tgt_size({SymInt(X), SymInt(Y), SymInt(Z)});
auto a = at::randn({1, 4, 1}, at::kCPU);
auto b = a.expand_symint(tgt_size);
auto c = a.expand(tgt_size_v);
ASSERT_TRUE(torch::allclose(b, c));
}
TEST(TestSymInt, NarrowCopyWithSymbolicInt) {
static const size_t LENGTH = 5;
auto a = at::randn({10}, at::kCPU);
c10::SymInt si(LENGTH);
auto b = a.narrow_copy_symint(0, 0, si);
auto c = a.narrow(0, 0, LENGTH);
ASSERT_TRUE(torch::allclose(b, c));
}
TEST(TestSymInt, NarrowCopy) {
static const size_t LENGTH = 5;
auto a = at::randn({10}, at::kCPU);
auto b = a.narrow_copy(0, 0, LENGTH);
auto c = a.narrow(0, 0, LENGTH);
ASSERT_TRUE(torch::allclose(b, c));
}
TEST(TestSymInt, AddSymbolicInt) {
c10::SymInt a(5);
c10::SymInt b(3);
ASSERT_TRUE((a + b).expect_int() == 8);
}
TEST(FallbackGraphsTest, Basic) {
auto x = at::randn({1}, at::kCPU);
auto y = at::randn({1}, at::kCPU);
auto stack = createStack({x.clone(), y.clone()});
auto graph_string = R"IR(
    graph(%0 : Float(1),
          %1 : Float(1)):
      %2 : Tensor = aten::mul(%0, %1)
      %3 : Tensor = aten::mul(%2, %0)
      return (%3))IR";
auto graph = std::make_shared<Graph>();
torch::jit::parseIR(graph_string, graph.get());
{
Code code(graph, "");
InterpreterState interpreter{code};
interpreter.run(stack);
}
at::Tensor et;
pop(stack, et);
float ef = et.item<float>();
{
EnableProfilingGuard epg;
GraphFunction f("fallbackGraphs", graph, nullptr);
for (size_t i = 0; i < getNumProfiledRuns() + 1; i++) {
stack.emplace_back(x.clone());
stack.emplace_back(y.clone());
if (i == getNumProfiledRuns()) {
// we will be modifying a profiled graph
// before ProfilingGraphExecutor
// will optimize it in the next iteration
auto opt_graph = lastExecutedOptimizedGraph();
// this is safe to do since we are done profiling
ProfilingRecord::removeProfileCounter(opt_graph->block());
replaceBlockWithFallbackGraph(opt_graph->block(), opt_graph->inputs());
auto it = opt_graph->block()->nodes().begin();
ASSERT_EQ(it->kind(), prim::FallbackGraph);
auto fallback = *it++;
ASSERT_EQ(it, opt_graph->block()->nodes().end());
ASSERT_TRUE(fallback->hasAttribute(attr::Subgraph));
testing::FileCheck()
.check("Tensor = aten::mul")
->check("Tensor = aten::mul")
->run(*fallback->g(attr::Subgraph));
}
f.run(stack);
at::Tensor at;
pop(stack, at);
float af = at.item<float>();
ASSERT_EQ(af, ef);
}
auto opt_graph = lastExecutedOptimizedGraph();
testing::FileCheck()
.check("(Tensor) = prim::CallFunction")
->run(*opt_graph);
}
}
// TODO this test wasn't running and is broken.
// TEST(AutogradProfilerTest, Basic) {
// constexpr int batch_size = 4;
// constexpr int input_size = 256;
// constexpr int seq_len = 32;
// int hidden_size = 2 * input_size;
// auto input = torch::randn({seq_len, batch_size, input_size}, at::kCPU);
// auto hx = torch::randn({batch_size, hidden_size}, at::kCPU);
// auto cx = torch::randn({batch_size, hidden_size}, at::kCPU);
// auto w_ih = t_def(torch::randn({4 * hidden_size, input_size}, at::kCPU));
// auto w_hh = t_def(torch::randn({4 * hidden_size, hidden_size}, at::kCPU));
// std::stringstream ss;
// {
// RecordProfile guard(ss);
// for (size_t i = 0; i < 100; ++i) {
// std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
// }
// }
// std::string result = ss.str();
// size_t count = 0;
// for (size_t pos = 0; (pos = result.find("tanh", pos)) != std::string::npos;
// count++, pos++) {
// }
// ASSERT_EQ((count, 200);
// }
TEST(NoneSchemaMatchTest, Basic) {
RegisterOperators reg({
Operator(
"prim::test_none() -> int?",
[](Stack& stack) { push(stack, IValue()); },
aliasAnalysisFromSchema()),
Operator(
"prim::is_none(int? a) -> bool",
[](Stack& stack) {
IValue a = pop(stack);
if (a.isNone()) {
push(stack, true);
} else {
push(stack, false);
}
},
aliasAnalysisFromSchema()),
});
// Constant propagation will run test_none and produce a None,
// testing that its type is set appropriately and schema matching doesn't
// fail when running is_none
auto r = std::make_shared<Graph>();
auto& g = *r;
auto opt_int = g.insert(Symbol::fromQualString("prim::test_none"), {});
auto out_bool = g.insert(Symbol::fromQualString("prim::is_none"), {opt_int});
g.registerOutput(out_bool);
ConstantPropagation(r);
auto nodes = r->block()->nodes();
// checking that constant propagation ran wo/failure
AT_ASSERT(std::distance(nodes.begin(), nodes.end()) == 1);
}
static int testPassValue = 0;
void fakePass(std::shared_ptr<Graph>& g) {
testPassValue++;
return;
}
RegisterPass p(fakePass);
TEST(PassManagementTest, Basic) {
std::shared_ptr<Graph> graph = std::make_shared<Graph>();
parseIR(
R"IR(
graph(%a):
return (%a))IR",
&*graph);
std::vector<IValue> stack = {IValue(torch::randn({22}, at::kCPU))};
auto run = [&](std::shared_ptr<Graph>& graph, std::vector<IValue> stack) {
GraphExecutor executor(graph, "");
executor.run(stack);
return stack;
};
run(graph, stack);
// we will not run fusion in simple mode
if (!getExecutorMode()) {
AT_ASSERT(testPassValue);
}
}
static void checkShape(TypePtr typ, std::vector<int64_t> expected) {
auto ptp = typ->expect<TensorType>();
ASSERT_EQ(ptp->sizes().concrete_sizes().value(), expected);
}
static void checkShape(
Node* n,
std::vector<int64_t> expected,
bool prev = true) {
auto profile = (prev) ? n->inputs().at(0)->node() : n;
checkShape(profile->output()->type(), expected);
}
void count_(
Block* block,
const std::function<bool(Node* n)>& pred,
size_t& count) {
for (Node* n : block->nodes()) {
if (pred(n)) {
count++;
}
for (Block* ib : n->blocks()) {
count_(ib, pred, count);
}
}
}
size_t countNodes(
const std::shared_ptr<Graph>& graph,
const std::function<bool(Node* n)>& pred) {
size_t count = 0;
count_(graph->block(), pred, count);
return count;
}
bool true_pred(Node* n) {
return true;
};
bool is_loop(Node* n) {
return n->kind() == prim::Loop;
};
TEST(LoopPeelerTest, NoInductionVariableUse) {
// do not use an induction variable explicitly
static const auto str_func_def = R"JIT(
  def test_peel_n_times():
    sum = 0
    for i in range(10):
      sum += 2
    return sum
)JIT";
auto cu = compile(str_func_def);
auto& f = toGraphFunction(cu->get_function("test_peel_n_times"));
auto stack = createStack({});
// peeling loop once
{
LoopsPeeler peeler(true_pred, 1);
auto copy = f.graph()->copy();
peeler.run(copy);
int num_loops =
std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
ASSERT_EQ(num_loops, 2);
Code code(copy, "");
InterpreterState interpreter{code};
interpreter.run(stack);
ASSERT_EQ(stack.back().toInt(), 20);
}
// test peeling more than one iteration
{
LoopsPeeler peeler(true_pred, 3);
auto copy = f.graph()->copy();
peeler.run(copy);
int num_loops =
std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
ASSERT_EQ(num_loops, 2);
Code code(copy, "");
InterpreterState interpreter{code};
interpreter.run(stack);
ASSERT_EQ(stack.back().toInt(), 20);
}
}
TEST(LoopPeelerTest, YesInductionVariableUse) {
// uses the induction variable
static const auto str_func_def = R"JIT(
  def test_peel_n_times():
    sum = 0
    for i in range(10):
      sum += i
    return sum
)JIT";
auto cu = compile(str_func_def);
auto& f = toGraphFunction(cu->get_function("test_peel_n_times"));
auto stack = createStack({});
// peeling loop once
{
LoopsPeeler peeler(true_pred, 1);
auto copy = f.graph()->copy();
peeler.run(copy);
int num_loops =
std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
ASSERT_EQ(num_loops, 2);
Code code(copy, "");
InterpreterState interpreter{code};
interpreter.run(stack);
ASSERT_EQ(stack.back().toInt(), 45);
}
// test peeling more than one iteration
{
LoopsPeeler peeler(true_pred, 3);
auto copy = f.graph()->copy();
peeler.run(copy);
int num_loops =
std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
ASSERT_EQ(num_loops, 2);
Code code(copy, "");
InterpreterState interpreter{code};
interpreter.run(stack);
ASSERT_EQ(stack.back().toInt(), 45);
}
}
TEST(LoopPeelerTest, LoopWithTerminationCondition) {
// tests with explicit termination conditions
static const auto str_func_def = R"JIT(
  def test_with_cond_times():
    sum = 0
    i = 0
    while (sum < 2):
      sum += i
      i += 1
    return sum
)JIT";
// the peel changes the termination condition to false
// so the original loop doesn't run
auto cu = compile(str_func_def);
auto& f = toGraphFunction(cu->get_function("test_with_cond_times"));
auto stack = createStack({});
// peeling 5 iterations should update the termination
// condition to false
{
LoopsPeeler peeler(true_pred, 5);
auto copy = f.graph()->copy();
peeler.run(copy);
int num_loops =
std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
ASSERT_EQ(num_loops, 2);
Code code(copy, "");
InterpreterState interpreter{code};
interpreter.run(stack);
ASSERT_EQ(stack.back().toInt(), 3);
}
// the termination condition remains true
{
LoopsPeeler peeler(true_pred, 1);
auto copy = f.graph()->copy();
peeler.run(copy);
int num_loops =
std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
ASSERT_EQ(num_loops, 2);
Code code(copy, "");
InterpreterState interpreter{code};
interpreter.run(stack);
ASSERT_EQ(stack.back().toInt(), 3);
}
}
// tests simple nested loops
TEST(LoopPeelerTest, SimpleNestedLoops) {
static const auto str_func_def = R"JIT(
def test_nested_loops():
sum = 0
i = 0
for i in range(10):
for j in range(10):
sum += i + j
return sum
)JIT";
auto cu = compile(str_func_def);
auto& f = toGraphFunction(cu->get_function("test_nested_loops"));
auto stack = createStack({});
{
LoopsPeeler peeler(true_pred, 1);
auto copy = f.graph()->copy();
peeler.run(copy);
ASSERT_EQ(countNodes(copy, is_loop), 5);
Code code(copy, "");
InterpreterState interpreter{code};
interpreter.run(stack);
ASSERT_EQ(stack.back().toInt(), 900);
}
{
LoopsPeeler peeler(true_pred, 5);
auto copy = f.graph()->copy();
peeler.run(copy);
ASSERT_EQ(countNodes(copy, is_loop), 5);
Code code(copy, "");
InterpreterState interpreter{code};
interpreter.run(stack);
ASSERT_EQ(stack.back().toInt(), 900);
}
}
TEST(LoopPeelerTest, SimpleNestedLoops2) {
static const auto str_func_def = R"JIT(
def test_nested_loops():
sum = 0
i = 0
for i in range(10):
j = 0
while sum < 2:
sum += i + j
j += 1
return sum
)JIT";
auto cu = compile(str_func_def);
auto& f = toGraphFunction(cu->get_function("test_nested_loops"));
auto stack = createStack({});
{
LoopsPeeler peeler(true_pred, 1);
auto copy = f.graph()->copy();
peeler.run(copy);
ASSERT_EQ(countNodes(copy, is_loop), 5);
Code code(copy, "");
InterpreterState interpreter{code};
interpreter.run(stack);
ASSERT_EQ(stack.back().toInt(), 3);
}
{
LoopsPeeler peeler(true_pred, 5);
auto copy = f.graph()->copy();
peeler.run(copy);
ASSERT_EQ(countNodes(copy, is_loop), 5);
Code code(copy, "");
InterpreterState interpreter{code};
interpreter.run(stack);
ASSERT_EQ(stack.back().toInt(), 3);
}
}
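// Traces the LSTM graph with concrete example inputs, checks that the
// traced graph's input types match those examples, and that running both
// the traced and the original graph reproduces the profiled output.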
TEST(JitTracing, Basic) {
constexpr int batch_size = 4;
constexpr int input_size = 256;
int hidden_size = 2 * input_size;
auto input = at::randn({batch_size, input_size}, at::kCPU);
auto hx = at::randn({batch_size, hidden_size}, at::kCPU);
auto cx = at::randn({batch_size, hidden_size}, at::kCPU);
auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCPU));
auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCPU));
auto graph = build_lstm();
auto stack = createStack({input, hx, cx, w_ih, w_hh});
auto traced = TraceGraph(graph, stack);
// Check that the inputs of the traced graph have the same types as the
// inputs specified here.
ASSERT_EQ(*traced->inputs().at(0)->type(), *TensorType::create(input));
ASSERT_EQ(*traced->inputs().at(1)->type(), *TensorType::create(hx));
ASSERT_EQ(*traced->inputs().at(2)->type(), *TensorType::create(cx));
ASSERT_EQ(*traced->inputs().at(3)->type(), *TensorType::create(w_ih));
ASSERT_EQ(*traced->inputs().at(4)->type(), *TensorType::create(w_hh));
Tensor prof_out;
pop(stack, prof_out);
{
stack = createStack({input, hx, cx, w_ih, w_hh});
Code cd(traced, "traced");
InterpreterState is{cd};
is.run(stack);
Tensor traced_out;
pop(stack, traced_out);
ASSERT_TRUE(torch::allclose(prof_out, traced_out));
}
{
stack = createStack({input, hx, cx, w_ih, w_hh});
Code cd(graph, "graph");
InterpreterState is{cd};
is.run(stack);
Tensor scripted_out;
pop(stack, scripted_out);
ASSERT_TRUE(torch::allclose(prof_out, scripted_out));
}
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InsertAndEliminateRedundantGuardsTest, Basic) {
static const auto basic_example = R"JIT(
def basic(x, y):
a = x + y
b = x * y
c = x + 1
d = a - c
e = b - c
return d + e
)JIT";
auto cu = compile(basic_example);
auto& fun = toGraphFunction(cu->get_function("basic"));
auto pr = ProfilingRecord::instrumentGraph(fun.graph());
auto x = at::randn({2, 3}, at::kCPU);
auto y = at::randn({2, 3}, at::kCPU);
auto stack = createStack({x, y});
// introduce some profiling information
Code cd(pr->profiled_graph_, "");
InterpreterState is{cd};
is.run(stack);
auto copy = pr->profiled_graph_->copy();
ProfilingRecord::removeProfileCounter(copy->block());
InsertGuards(copy);
auto nodes = copy->block()->nodes();
auto guard = std::find_if(nodes.begin(), nodes.end(), [](Node* n) {
return n->kind() == prim::Guard;
});
ASSERT_NE(guard, nodes.end());
ASSERT_EQ(
guard->input()->type()->expectRef<TensorType>().sizes().size(),
c10::nullopt);
checkShape(*guard, {2, 3}, false);
auto is_guard = [](Node* n) { return n->kind() == prim::Guard; };
int num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard);
ASSERT_EQ(num_guards, 12);
// now eliminate as many guards as possible
// we should be left with two guards on x and y's defs
EliminateRedundantGuards(copy);
num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard);
ASSERT_EQ(num_guards, 2);
}
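// After guard insertion and redundant-guard elimination, the remaining
// guards should be converted into prim::BailOut nodes, each referencing a
// prim::BailoutTemplate.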
TEST(InsertBailOutsTest, Basic) {
static const auto basic_example = R"JIT(
def basic_loop(x, y):
a = x + 1
b = y + 2
c = x + y + 3
for i in range(10):
a = a + b
# invariant
d = b * c
#
a = a - d
e = a + 4
return e
)JIT";
auto cu = compile(basic_example);
auto& fun = toGraphFunction(cu->get_function("basic_loop"));
auto pr = ProfilingRecord::instrumentGraph(fun.graph());
auto x = at::randn({2, 3}, at::kCPU);
auto y = at::randn({2, 3}, at::kCPU);
auto stack = createStack({x, y});
// introduce some profiling information
Code cd(pr->profiled_graph_, "");
InterpreterState is{cd};
is.run(stack);
auto copy = pr->profiled_graph_->copy();
ProfilingRecord::removeProfileCounter(copy->block());
InsertGuards(copy);
EliminateRedundantGuards(copy);
auto nodes = copy->block()->nodes();
auto is_guard = [](Node* n) { return n->kind() == prim::Guard; };
auto num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard);
ASSERT_EQ(num_guards, 3);
InsertBailOuts(copy);
auto is_bailout = [](Node* n) { return n->kind() == prim::BailOut; };
auto num_bailouts = std::count_if(nodes.begin(), nodes.end(), is_bailout);
ASSERT_EQ(num_guards, num_bailouts);
std::vector<Node*> bailouts(num_bailouts);
std::copy_if(nodes.begin(), nodes.end(), bailouts.begin(), is_bailout);
for (auto blo : bailouts) {
ASSERT_EQ(blo->inputs().at(0)->node()->kind(), prim::BailoutTemplate);
}
}
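// Runs the LSTM graph through the profiling interpreter and checks that
// prim::profile nodes recorded the expected input shapes for the add, mul,
// and tanh nodes.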
TEST(ProfilerTest, Basic) {
constexpr int batch_size = 4;
constexpr int input_size = 256;
int hidden_size = 2 * input_size;
auto input = at::randn({batch_size, input_size}, at::kCPU);
auto hx = at::randn({batch_size, hidden_size}, at::kCPU);
auto cx = at::randn({batch_size, hidden_size}, at::kCPU);
auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCPU));
auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCPU));
auto g = build_lstm();
auto stack = createStack({input, hx, cx, w_ih, w_hh});
auto& opt_graph = *g.get();
ArgumentSpecCreator arg_spec_creator(opt_graph);
ArgumentSpec spec =
arg_spec_creator.create(autograd::GradMode::is_enabled(), stack);
arg_spec_creator.specializeTypes(opt_graph, spec);
auto pr = ProfilingRecord::instrumentGraph(g);
Code cd(pr->profiled_graph_, "");
InterpreterState is{cd};
is.run(stack);
// profiled types are stored as attributes and show up in the dump, e.g.
// Tensor = prim::profile[profiled_type=Double(4, 256, strides=[256, 1],
// requires_grad=0, device=cpu)
testing::FileCheck()
.check("Tensor = prim::profile[profiled_type")
->check_same("256")
->run(*pr->profiled_graph_);
auto begin = pr->profiled_graph_->block()->nodes().begin();
auto end = pr->profiled_graph_->block()->nodes().end();
auto add_n =
std::find_if(begin, end, [](Node* n) { return n->kind() == aten::add; });
ASSERT_NE(add_n, end);
std::vector<int64_t> add_expected{4, 2048};
std::vector<int64_t> eltwise{4, 512};
checkShape(
add_n->inputs().at(0)->node()->ty(attr::profiled_type), add_expected);
auto mul_n =
std::find_if(begin, end, [](Node* n) { return n->kind() == aten::mul; });
ASSERT_NE(mul_n, end);
checkShape(mul_n->inputs().at(0)->node()->ty(attr::profiled_type), eltwise);
auto tanh_n =
std::find_if(begin, end, [](Node* n) { return n->kind() == aten::tanh; });
checkShape(tanh_n->inputs().at(0)->node()->ty(attr::profiled_type), eltwise);
}
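// Profiles an optional Tensor? input (the bias of aten::linear) and checks
// that both its concrete shape and a subsequently seen None are recorded.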
TEST(ProfilerTest, OptionalProfiling) {
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(
R"IR(
graph(%inp : Tensor,
%weight : Tensor,
%bias : Tensor?):
%1 : Tensor = aten::linear(%inp, %weight, %bias)
return (%1))IR",
&*graph,
vmap);
auto pr = ProfilingRecord::instrumentGraph(graph);
pr->profiling_count_ = 2;
auto input = torch::randn({1, 2});
auto weight = torch::randn({2, 2});
auto bias = torch::randn({1, 2});
auto stack = createStack({input, weight, bias});
Code cd(pr->profiled_graph_, "");
InterpreterState is{cd};
is.run(stack);
testing::FileCheck()
.check_count("Tensor? = prim::profile[profiled_type", 1, true)
->run(*pr->profiled_graph_);
// make sure we recorded the shape
auto begin = pr->profiled_graph_->block()->nodes().begin();
auto end = pr->profiled_graph_->block()->nodes().end();
auto linear = std::find_if(
begin, end, [](Node* n) { return n->kind() == aten::linear; });
ASSERT_NE(linear, end);
std::vector<int64_t> bias_expected_shape = {1, 2};
auto profiled_bias = linear->namedInput("bias")->node();
checkShape(profiled_bias->ty(attr::profiled_type), bias_expected_shape);
ASSERT_EQ(0, profiled_bias->i(attr::seen_none));
auto none_bias = c10::IValue();
stack.clear();
stack.emplace_back(input);
stack.emplace_back(weight);
stack.emplace_back(none_bias);
is = InterpreterState{cd};
is.run(stack);
// make sure we recorded that "None" was seen.
begin = pr->profiled_graph_->block()->nodes().begin();
end = pr->profiled_graph_->block()->nodes().end();
linear = std::find_if(
begin, end, [](Node* n) { return n->kind() == aten::linear; });
ASSERT_NE(linear, end);
profiled_bias = linear->namedInput("bias")->node();
checkShape(profiled_bias->ty(attr::profiled_type), bias_expected_shape);
ASSERT_EQ(1, profiled_bias->i(attr::seen_none));
}
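// Inlined constants should carry the callstack of the functions they were
// inlined through, while nodes defined directly in the graph carry none.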
TEST(CallStackTest, Basic) {
const auto text = R"(
def ham(x):
return x/7
def bar(x):
return x*3
def baz(x):
return ham(x)*x
def foo(x):
return bar(x)*baz(x)*11
)";
auto cu = compile(text);
const auto& foo = toGraphFunction(cu->get_function("foo"));
for (Node* n : foo.optimized_graph()->nodes()) {
if (n->kind() == prim::Constant) {
if (!n->hasAttribute(attr::value) ||
n->kindOf(attr::value) != AttributeKind::i) {
continue;
}
int v = n->i(attr::value);
switch (v) {
case 3: {
// Const 3 comes from function 'bar', which gets inlined to 'foo'.
// The callstack for the corresponding node should contain only the
// function 'bar'.
ASSERT_TRUE(n->callstack());
auto callstack_vector = (*n->callstack())->vec();
ASSERT_EQ(callstack_vector.size(), 1);
ASSERT_EQ(std::get<0>(callstack_vector[0]), &cu->get_function("bar"));
break;
}
case 7: {
// Const 7 comes from function 'ham', which gets inlined to 'baz',
// which is then inlined to 'foo'. The callstack for the corresponding
// node should contain these two functions.
ASSERT_TRUE(n->callstack());
auto callstack_vector = (*n->callstack())->vec();
ASSERT_EQ(callstack_vector.size(), 2);
ASSERT_EQ(std::get<0>(callstack_vector[0]), &cu->get_function("baz"));
ASSERT_EQ(std::get<0>(callstack_vector[1]), &cu->get_function("ham"));
break;
}
case 11: {
// Const 11 comes from function 'foo', which is not inlined anywhere
// and thus it should not have a callstack.
ASSERT_FALSE(n->callstack());
break;
}
}
}
}
// Check that inlining doesn't corrupt callstack of the callee's nodes.
const auto& baz = toGraphFunction(cu->get_function("baz"));
for (Node* n : baz.optimized_graph()->nodes()) {
if (n->kind() == prim::Constant) {
if (!n->hasAttribute(attr::value) ||
n->kindOf(attr::value) != AttributeKind::i) {
continue;
}
int v = n->i(attr::value);
ASSERT_TRUE(v == 7);
// Const 7 comes from function 'ham', which gets inlined to 'baz'. 'baz'
// was also inlined into 'foo', but when looking at the graph of 'baz' we
// should only see a callstack of depth 1 (containing only 'ham').
ASSERT_TRUE(n->callstack());
auto callstack_vector = (*n->callstack())->vec();
ASSERT_EQ(callstack_vector.size(), 1);
ASSERT_EQ(std::get<0>(callstack_vector[0]), &cu->get_function("ham"));
}
}
}
TEST(CallStackTest, Caching) {
const auto text = R"(
def a(x):
print("a1")
print("a2")
return x
def b(x):
print("b1")
print("b2")
a(x)
return x
def c(x):
print("c1")
print("c2")
b(x)
return x
)";
auto cu = compile(text);
const auto& baz = toGraphFunction(cu->get_function("c"));
std::unordered_map<std::string, InlinedCallStack*> callstack_objects;
for (Node* n : baz.optimized_graph()->nodes()) {
if (n->kind() == prim::Constant) {
if (!n->hasAttribute(attr::value) ||
n->kindOf(attr::value) != AttributeKind::s) {
continue;
}
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
std::string v = n->s(attr::value);
if (n->callstack()) {
callstack_objects[v] = n->callstack()->get();
}
}
}
// We expect to see nodes prim::Constant[value="a1"] and
// prim::Constant[value="a2"] inlined to function 'c'. Their callstacks are
// the same (a->b->c), so we want to make sure we're not creating different
// callstack entries for them.
ASSERT_TRUE(callstack_objects.count("a1") && callstack_objects.count("a2"));
ASSERT_TRUE(callstack_objects.at("a1") == callstack_objects.at("a2"));
}
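// Nodes inside the blocks of a prim::If should keep their inlined
// callstacks and source ranges after module inlining.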
TEST(InlinedCallStackTest, BlockAnnotation) {
Module a("A");
a.define(R"(
def forward(self, x, y, z: int):
if (z == 1):
return x + y
else:
return x * y
)");
Module b("B");
b.define(R"(
def forward(self, x):
return x + 2
)");
Module c("C");
c.register_module("A0", a);
c.register_module("B0", b);
c.define(R"(
def forward(self, x, y, z: int):
return self.A0.forward(x, y, z) + self.B0.forward(x)
)");
auto graph =
toGraphFunction(c.get_method("forward").function()).optimized_graph();
std::stringstream add_ss, mul_ss;
for (Node* n : graph->nodes()) {
if (n->kind() == prim::If) {
for (Block* block : n->blocks()) {
for (Node* if_node : block->nodes()) {
if (if_node->kind() == aten::add) {
for (const auto& e : if_node->callstack().value()->vec()) {
add_ss << std::get<1>(e);
}
add_ss << if_node->sourceRange();
}
if (if_node->kind() == aten::mul) {
for (const auto& e : if_node->callstack().value()->vec()) {
mul_ss << std::get<1>(e);
}
mul_ss << if_node->sourceRange();
}
}
}
}
}
ASSERT_NE(add_ss.str().find("line 3"), std::string::npos);
ASSERT_NE(add_ss.str().find("line 4"), std::string::npos);
ASSERT_NE(
add_ss.str().find("return self.A0.forward(x, y, z)"), std::string::npos);
ASSERT_NE(add_ss.str().find("return x + y"), std::string::npos);
ASSERT_NE(mul_ss.str().find("line 3"), std::string::npos);
ASSERT_NE(mul_ss.str().find("line 6"), std::string::npos);
ASSERT_NE(
mul_ss.str().find("return self.A0.forward(x, y, z)"), std::string::npos);
ASSERT_NE(mul_ss.str().find("return x * y"), std::string::npos);
}
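// Checks the module hierarchy reported for nodes coming from methods that
// call other methods on self as well as from submodule forwards.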
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InlinedCallStackTest, SelfCallMethods) {
Module a("A");
a.define(R"(
def my_new_method(self, x):
return x * 3
def forward_impl_(self, x, y):
return self.my_new_method(x) + y
def forward(self, x, y):
y = y + 2
return self.forward_impl_(x, y)
)");
Module b("B");
b.define(R"(
def forward(self, x):
return x + 2
)");
Module c("C");
c.register_module("A0", a);
c.register_module("B0", b);
c.define(R"(
def call_b(self, x):
return self.B0.forward(x)
def forward(self, x, y):
return self.A0.forward(x, y) + self.call_b(x)
)");
auto graph =
toGraphFunction(c.get_method("forward").function()).optimized_graph();
std::unordered_map<std::string, size_t> module_hierarchies;
for (Node* n : graph->nodes()) {
auto hierarchy = torch::jit::utils::getNodesModuleHierarchy(*n);
if (module_hierarchies.count(hierarchy) == 0) {
module_hierarchies[hierarchy] = 0;
}
module_hierarchies[hierarchy] += 1;
}
ASSERT_EQ(module_hierarchies["A0(A)"], 2);
ASSERT_EQ(module_hierarchies["A0(A).SELF(A).SELF(A)"], 2);
ASSERT_EQ(module_hierarchies["A0(A).SELF(A)"], 1);
ASSERT_EQ(module_hierarchies["SELF(C)"], 1);
ASSERT_EQ(module_hierarchies["SELF(C).B0(B)"], 1);
}
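// canRunWithAutograd accepts aten and prim symbols but rejects fusion
// groups and symbols from custom namespaces.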
TEST(AutogradSymbolsTest, Basic) {
Symbol sym = Symbol::fromQualString("aten::test_symbol");
Graph graph;
auto node = graph.create(sym);
TORCH_CHECK(canRunWithAutograd(node));
sym = Symbol::fromQualString("prim::test_symbol");
node = graph.create(sym);
TORCH_CHECK(canRunWithAutograd(node));
sym = Symbol::fromQualString("prim::FusionGroup");
node = graph.create(sym);
TORCH_CHECK(!canRunWithAutograd(node));
sym = Symbol::fromQualString("custom::test_symbol");
node = graph.create(sym);
TORCH_CHECK(!canRunWithAutograd(node));
}
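// A default argument without a type hint should fail to compile, while the
// hinted version compiles.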
TEST(DefaultArgTypeHintingTest, Basic) {
const auto text_non_hinted = R"(
def a(x, y=1):
print("a1")
print("a2")
return x
)";
const auto text_hinted = R"(
def a(x, y:int=1):
print("a1")
print("a2")
return x
)";
try {
compile(text_non_hinted);
ASSERT_TRUE(0); // compiling without the type hint should have thrown
} catch (const std::exception&) {
// expected: the default argument is missing a type hint
}
// the hinted version compiles cleanly
compile(text_hinted);
}
// Basic set case.
TEST(FuturesTest, Basic) {
auto f1 = c10::make_intrusive<Future>(IntType::get());
ASSERT_FALSE(f1->completed());
ASSERT_FALSE(f1->hasValue());
int32_t sat1 = 0;
int32_t sat2 = 0;
f1->addCallback([&](Future& /* unused */) { ++sat1; });
f1->markCompleted(43);
ASSERT_TRUE(f1->completed());
ASSERT_TRUE(f1->hasValue());
ASSERT_FALSE(f1->hasError());
ASSERT_EQ(sat1, 1);
ASSERT_EQ(f1->constValue().toInt(), 43);
ASSERT_EQ(f1->value().toInt(), 43);
f1->addCallback([&](Future& /* unused */) { ++sat2; });
ASSERT_EQ(sat1, 1);
ASSERT_EQ(sat2, 1);
}
// Basic error cases.
TEST(FuturesTest, Error) {
auto f1 = c10::make_intrusive<Future>(IntType::get());
int sat1 = 0;
int sat2 = 0;
f1->addCallback([&](Future& /* unused */) { ++sat1; });
f1->setError(
std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed")));
ASSERT_EQ(sat1, 1);
ASSERT_TRUE(f1->completed());
ASSERT_TRUE(f1->hasError());
ASSERT_FALSE(f1->hasValue());
try {
(void)f1->value();
ASSERT_TRUE(false); // Supposed to throw.
} catch (const std::exception& e) {
ASSERT_TRUE(strcmp(e.what(), "Failed") == 0);
}
f1->addCallback([&](Future& /* unused */) { ++sat2; });
ASSERT_EQ(sat1, 1);
ASSERT_EQ(sat2, 1);
f1->setErrorIfNeeded(
std::make_exception_ptr(c10::ivalue::Future::FutureError("Dup")));
ASSERT_TRUE(strcmp(f1->tryRetrieveErrorMessage().c_str(), "Failed") == 0);
ASSERT_EQ(sat1, 1);
ASSERT_EQ(sat2, 1);
try {
(void)f1->constValue();
ASSERT_TRUE(false); // Supposed to throw.
} catch (const std::exception& e) {
// Original error should be logged.
ASSERT_TRUE(std::string(e.what()).find("Failed") != std::string::npos);
}
}
// then
TEST(FuturesTest, Then) {
auto f1 = c10::make_intrusive<Future>(IntType::get());
auto f2 = f1->then(
[](Future& f1) -> IValue { return f1.constValue().toInt() + 1; },
IntType::get());
auto f3 = f2->then(
[](Future& f2) -> IValue { return f2.constValue().toInt() * 3; },
IntType::get());
bool done = false;
f3->addCallback([&done](Future& f3) {
ASSERT_EQ(f3.constValue().toInt(), (42 + 1) * 3);
done = true;
});
ASSERT_FALSE(done);
f1->markCompleted(42);
ASSERT_TRUE(done);
}
// collectAll()
TEST(FuturesTest, CollectAll) {
auto s1 = c10::make_intrusive<Future>(IntType::get());
auto s2 = c10::make_intrusive<Future>(IntType::get());
auto s3 = c10::make_intrusive<Future>(IntType::get());
// Empty case
c10::List<intrusive_ptr<ivalue::Future>> futures(
FutureType::create(IntType::get()));
auto c1 = collectAll(futures);
ASSERT_TRUE(c1->completed());
ASSERT_EQ(c1->value().toList().size(), 0);
ASSERT_TRUE(
*(c1->value().toList().elementType()) ==
*FutureType::create(IntType::get()));
// 1-element, initially not completed.
futures.push_back(s1);
auto c2 = collectAll(futures);
ASSERT_FALSE(c2->completed());
s1->markCompleted(5);
ASSERT_TRUE(c2->completed());
ASSERT_EQ(c2->value().toList().size(), 1);
ASSERT_TRUE(
*(c2->value().toList().elementType()) ==
*FutureType::create(IntType::get()));
ASSERT_EQ(c2->value().toList().get(0).toFuture()->value().toInt(), 5);
// 1-element, already completed
auto c3 = collectAll(futures);
ASSERT_TRUE(c3->completed());
ASSERT_EQ(c3->value().toList().size(), 1);
ASSERT_EQ(c3->value().toList().get(0).toFuture()->value().toInt(), 5);
// 3 elements.
futures.push_back(s2);
futures.push_back(s3);
auto c4 = collectAll(futures);
ASSERT_FALSE(c4->completed());
s3->markCompleted(7);
ASSERT_FALSE(c4->completed());
s2->markCompleted(6);
ASSERT_TRUE(c4->completed());
ASSERT_EQ(c4->value().toList().size(), 3);
ASSERT_EQ(c4->value().toList().get(0).toFuture()->value().toInt(), 5);
ASSERT_EQ(c4->value().toList().get(1).toFuture()->value().toInt(), 6);
ASSERT_EQ(c4->value().toList().get(2).toFuture()->value().toInt(), 7);
ASSERT_TRUE(
*(c4->value().toList().elementType()) ==
*FutureType::create(IntType::get()));
// Handle exception in the list.
auto s4 = c10::make_intrusive<Future>(IntType::get());
futures.push_back(s4);
auto c5 = collectAll(futures);
ASSERT_FALSE(c5->completed());
s4->setError(
std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed")));
ASSERT_TRUE(c5->completed());
try {
c5->value();
ASSERT_TRUE(false); // supposed to throw
} catch (const std::exception& e) {
ASSERT_EQ(std::string(e.what()), "Failed");
}
}
// collectAny()
TEST(FuturesTest, CollectAny) {
auto s1 = c10::make_intrusive<Future>(IntType::get());
// Empty case
c10::List<intrusive_ptr<ivalue::Future>> futures(
FutureType::create(IntType::get()));
auto c1 = collectAny(futures);
ASSERT_TRUE(c1->completed());
// 1 element, not yet satisfied
futures.push_back(s1);
auto c2 = collectAny(futures);
ASSERT_FALSE(c2->completed());
s1->markCompleted(5);
ASSERT_TRUE(c2->completed());
ASSERT_TRUE(c2->value().isInt());
ASSERT_EQ(c2->value().toInt(), 5);
// 1 element already satisfied.
auto c3 = collectAny(futures);
ASSERT_TRUE(c3->completed());
ASSERT_TRUE(c3->value().isInt());
ASSERT_EQ(c3->value().toInt(), 5);
// 2 elements
futures.clear();
auto s2 = c10::make_intrusive<Future>(IntType::get());
auto s3 = c10::make_intrusive<Future>(IntType::get());
futures.push_back(s2);
futures.push_back(s3);
auto c4 = collectAny(futures);
ASSERT_FALSE(c4->completed());
s3->markCompleted(7);
ASSERT_TRUE(c4->completed());
ASSERT_EQ(c4->value().toInt(), 7);
s2->markCompleted(1);
ASSERT_EQ(c4->value().toInt(), 7);
}
TEST(TLSFutureCallbacksTest, Basic) {
// cb that verifies the profiler is enabled
auto profilerEnabledCb = [](Future& /* unused */) {
ASSERT_TRUE(torch::autograd::profiler::profilerEnabled());
};
// test running callbacks with propagation of TLS state.
{
// Enable the profiler in this thread
torch::autograd::profiler::enableProfilerLegacy(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::CPU, false, false));
auto s1 = c10::make_intrusive<Future>(IntType::get());
s1->addCallback(wrapPropagateTLSState(profilerEnabledCb));
std::thread t([s1 = std::move(s1)]() { s1->markCompleted(); });
// Since we join here, we can ensure that all callbacks corresponding to
// markCompleted() have finished.
t.join();
torch::autograd::profiler::disableProfilerLegacy();
}
// then() with TLS State
{
// Enable the profiler in this thread
torch::autograd::profiler::enableProfilerLegacy(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::CPU, false, false));
auto s1 = c10::make_intrusive<Future>(IntType::get());
auto s2 = s1->then(
wrapPropagateTLSState([&profilerEnabledCb](Future& s1) {
profilerEnabledCb(s1);
return at::IValue(1);
}),
IntType::get());
std::thread t([s1 = std::move(s1)]() { s1->markCompleted(); });
t.join();
s2->wait();
torch::autograd::profiler::disableProfilerLegacy();
}
}
TEST(ProfilerDisableInCallbackTest, Basic) {
// cb that verifies the profiler is enabled
auto profilerEnabledCb = []() {
ASSERT_TRUE(torch::autograd::profiler::profilerEnabled());
};
torch::autograd::profiler::enableProfilerLegacy(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::CPU, false, false));
auto s1 = c10::make_intrusive<Future>(IntType::get());
auto verifyProfilerCb =
wrapPropagateTLSState([&profilerEnabledCb](Future& /* unused */) {
// Ensure the profiler is still enabled in this thread.
profilerEnabledCb();
auto t1 = torch::ones({2, 2});
auto t2 = torch::ones({2, 2});
torch::add(t1, t2);
// Don't clean up TLS state; just consolidate.
auto opts =
torch::autograd::profiler::ProfilerDisableOptions(false, true);
auto thread_event_lists =
// NOLINTNEXTLINE(performance-move-const-arg)
torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
// Ensure that the events from this thread are still profiled and we
// obtain the expected events in our consolidated list when calling
// disableProfilerLegacy().
bool found_ones = false;
bool found_add = false;
for (const auto& li : thread_event_lists) {
for (const auto& evt : li) {
if (strcmp(evt.name(), "aten::add") == 0) {
found_add = true;
} else if (strcmp(evt.name(), "aten::ones") == 0) {
found_ones = true;
}
}
if (found_add && found_ones) {
break;
}
}
ASSERT_TRUE(found_ones);
ASSERT_TRUE(found_add);
});
s1->addCallback(verifyProfilerCb);
// Disable the profiler, but do not consolidate results in the main thread.
auto opts = torch::autograd::profiler::ProfilerDisableOptions(true, false);
// NOLINTNEXTLINE(performance-move-const-arg)
torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
std::thread t([s1 = std::move(s1)]() { s1->markCompleted(at::IValue(1)); });
t.join();
// Similar to the above case, but verifies correctness when the
// continuation runs on the main thread.
torch::autograd::profiler::enableProfilerLegacy(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::CPU, false, false));
s1 = c10::make_intrusive<Future>(IntType::get());
s1->addCallback(verifyProfilerCb);
// Runs callback inline
s1->markCompleted(at::IValue(1));
opts = torch::autograd::profiler::ProfilerDisableOptions(true, false);
// NOLINTNEXTLINE(performance-move-const-arg)
torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
}
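// The debug handle passed to RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS
// should show up on the matching Kineto event; a plain user scope reports
// the default handle of -1.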
TEST(RecordDebugHandles, Basic) {
// Enable the profiler in this thread
const std::set<torch::autograd::profiler::ActivityType> activities(
{torch::autograd::profiler::ActivityType::CPU});
torch::autograd::profiler::prepareProfiler(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::KINETO, false, false),
activities);
torch::autograd::profiler::enableProfiler(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::KINETO, false, false),
activities);
{
RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS("my_function", 42, {});
float x{5.9999}, y{2.1212};
float z = x / y;
(void)z;
}
{
RECORD_USER_SCOPE_WITH_INPUTS("not_my_function", {});
float x{5.9999}, y{2.1212};
float z = x / y;
(void)z;
}
auto profiler_results_ptr = torch::autograd::profiler::disableProfiler();
const auto& kineto_events = profiler_results_ptr->events();
size_t my_events{0};
for (const auto& e : kineto_events) {
if (e.name() == "my_function") {
ASSERT_EQ(e.debugHandle(), 42);
my_events++;
} else if (e.name() == "not_my_function") {
ASSERT_EQ(e.debugHandle(), -1);
my_events++;
}
}
ASSERT_EQ(my_events, 2);
}
TEST(RecordDebugHandles, ScopedCallbacks) {
// Enable the profiler in this thread
torch::autograd::profiler::prepareProfiler(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::KINETO, false, false),
{torch::autograd::profiler::ActivityType::CPU});
torch::autograd::profiler::enableProfiler(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::KINETO, false, false),
{torch::autograd::profiler::ActivityType::CPU});
{
auto a = torch::rand({128, 128});
auto b = torch::rand({128, 128});
auto c = a + b;
}
auto profiler_results_ptr = torch::autograd::profiler::disableProfiler();
ASSERT_TRUE(profiler_results_ptr->events().size() > 0);
// Enable the profiler in this thread
torch::autograd::profiler::prepareProfiler(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::KINETO, false, false),
{torch::autograd::profiler::ActivityType::CPU});
torch::autograd::profiler::enableProfiler(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::KINETO, false, false),
{torch::autograd::profiler::ActivityType::CPU},
{at::RecordScope::LITE_INTERPRETER});
{
auto a = torch::rand({128, 128});
auto b = torch::rand({128, 128});
auto c = a + b;
}
profiler_results_ptr = torch::autograd::profiler::disableProfiler();
ASSERT_TRUE(profiler_results_ptr->events().size() == 0);
torch::autograd::profiler::prepareProfiler(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::KINETO, false, false),
{torch::autograd::profiler::ActivityType::CPU});
torch::autograd::profiler::enableProfiler(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::KINETO, false, false),
{torch::autograd::profiler::ActivityType::CPU},
{at::RecordScope::LITE_INTERPRETER});
{
RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS("my_function", 42, {});
auto a = torch::rand({128, 128});
auto b = torch::rand({128, 128});
auto c = a + b;
}
{
RECORD_USER_SCOPE_WITH_INPUTS("not_my_function", {});
auto a = torch::rand({128, 128});
auto b = torch::rand({128, 128});
auto c = a + b;
}
profiler_results_ptr = torch::autograd::profiler::disableProfiler();
const auto& kineto_events = profiler_results_ptr->events();
for (const auto& e : kineto_events) {
if (e.name() == "my_function") {
ASSERT_EQ(e.debugHandle(), 42);
}
}
ASSERT_TRUE(profiler_results_ptr->events().size() == 1);
}
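// Calls a scripted function with one positional argument, one keyword
// argument, and one defaulted parameter.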
TEST(IValueKWargsTest, Basic) {
const auto text = R"(
def foo(a : int, b : int, c : int = 4):
return a + 2*b + 3*c
)";
auto cu = compile(text);
auto result = cu->get_function("foo")({1}, {{"b", 3}});
ASSERT_EQ(result.toInt(), 19);
}
TEST(ComputeFlopsTest, Basic) {
uint64_t flops = 0;
// Test unknown operator
std::unordered_map<std::string, c10::IValue> extra_args;
flops = torch::profiler::impl::computeFlops(
std::string("aten::unknown"), extra_args);
ASSERT_EQ(flops, 0);
// Test aten::conv2d
extra_args.clear();
std::vector<int64_t> input_size = {4, 5, 6, 7};
std::vector<int64_t> weight_size = {3, 5, 2, 1};
std::vector<int64_t> padding = {1, 0};
std::vector<int64_t> stride = {1, 1};
std::vector<int64_t> dilation = {0, 0};
extra_args["input_size"] = at::IValue(at::IntArrayRef(input_size));
extra_args["weight_size"] = at::IValue(at::IntArrayRef(weight_size));
extra_args["groups"] = 1;
extra_args["padding"] = at::IValue(at::IntArrayRef(padding));
extra_args["stride"] = at::IValue(at::IntArrayRef(stride));
extra_args["dilation"] = at::IValue(at::IntArrayRef(dilation));
flops = torch::profiler::impl::computeFlops(
std::string("aten::conv2d"), extra_args);
ASSERT_EQ(flops, 13440);
// Test aten::conv2d fail
input_size = {4, 5, 6, 7};
weight_size = {4, 5, 6};
extra_args["input_size"] = at::IValue(at::IntArrayRef(input_size));
extra_args["weight_size"] = at::IValue(at::IntArrayRef(weight_size));
flops = torch::profiler::impl::computeFlops(
std::string("aten::conv2d"), extra_args);
ASSERT_EQ(flops, 0);
// Test aten::conv2d fail 2
weight_size = {3, 5, 2, 1};
stride = {0, 0};
extra_args["weight_size"] = at::IValue(at::IntArrayRef(input_size));
extra_args["stride"] = at::IValue(at::IntArrayRef(stride));
flops = torch::profiler::impl::computeFlops(
std::string("aten::conv2d"), extra_args);
ASSERT_EQ(flops, 0);
// Test aten::conv2d fail 3
extra_args.clear();
input_size = {4, 5, 6, 7};
extra_args["input_size"] = at::IValue(at::IntArrayRef(input_size));
flops = torch::profiler::impl::computeFlops(
std::string("aten::conv2d"), extra_args);
ASSERT_EQ(flops, 0);
// Test aten::mm
extra_args.clear();
std::vector<int64_t> mat1_sizes = {3, 4, 5, 6};
std::vector<int64_t> mat2_sizes = {6, 5, 4, 3};
extra_args["mat1_size"] = at::IValue(at::IntArrayRef(mat1_sizes));
extra_args["mat2_size"] = at::IValue(at::IntArrayRef(mat2_sizes));
flops =
torch::profiler::impl::computeFlops(std::string("aten::mm"), extra_args);
ASSERT_EQ(flops, 43200);
// Test aten::addmm
flops = torch::profiler::impl::computeFlops(
std::string("aten::addmm"), extra_args);
ASSERT_EQ(flops, 43200);
// Test aten::bmm
extra_args.clear();
mat1_sizes = {7, 5, 6};
mat2_sizes = {7, 6, 3};
extra_args["mat1_size"] = at::IValue(at::IntArrayRef(mat1_sizes));
extra_args["mat2_size"] = at::IValue(at::IntArrayRef(mat2_sizes));
flops =
torch::profiler::impl::computeFlops(std::string("aten::bmm"), extra_args);
ASSERT_EQ(flops, 1260);
// Test aten::baddbmm
flops = torch::profiler::impl::computeFlops(
std::string("aten::baddbmm"), extra_args);
ASSERT_EQ(flops, 1260);
// Test mm out of range
extra_args.clear();
flops =
torch::profiler::impl::computeFlops(std::string("aten::mm"), extra_args);
ASSERT_EQ(flops, 0);
// Test aten::add.Tensor
extra_args.clear();
std::vector<int64_t> mat_sizes = {3, 4, 5, 6};
extra_args["mat_size"] = at::IValue(at::IntArrayRef(mat_sizes));
flops =
torch::profiler::impl::computeFlops(std::string("aten::add"), extra_args);
ASSERT_EQ(flops, 360);
// Test aten::mul.Tensor
extra_args.clear();
mat_sizes = {3, 4, 5, 6};
extra_args["mat_size"] = at::IValue(at::IntArrayRef(mat_sizes));
flops =
torch::profiler::impl::computeFlops(std::string("aten::mul"), extra_args);
ASSERT_EQ(flops, 360);
}
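// A tensor that requires grad cannot be inserted as a constant.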
TEST(TestConstant, TensorGrad) {
auto graph = std::make_shared<Graph>();
IValue ten = torch::randn({3, 5}).requires_grad_(true);
auto con = tryInsertConstant(*graph, ten);
ASSERT_TRUE(con == c10::nullopt);
}
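// RemoveTensorMutation removes in-place ops only when the filter accepts
// the node.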
TEST(TestMutation, Basic) {
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(
R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=1]()
%9 : int = prim::Constant[value=4]()
%x.3 : Tensor = aten::add(%x.1, %2, %2)
%7 : Tensor = aten::add_(%x.3, %2, %2)
%y.1 : Tensor = aten::add(%x.3, %9, %2)
return (%y.1))IR",
&*graph,
vmap);
RemoveTensorMutation(graph, [](Node*) { return false; });
testing::FileCheck().check("aten::add_")->run(*graph);
RemoveTensorMutation(graph, [](Node*) { return true; });
testing::FileCheck().check_not("aten::add_")->run(*graph);
}
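// The in-place relu_ should be rewritten into functional relu.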
TEST(TestInplaceToFunctionalActivation, Basic) {
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(
R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=1]()
%x.3 : Tensor = aten::add(%x.1, %2, %2)
%y : Tensor = aten::relu_(%x.3)
return (%y))IR",
&*graph,
vmap);
InplaceToFunctionalActivation(graph);
testing::FileCheck().check("aten::relu")->run(*graph);
testing::FileCheck().check_not("aten::relu_")->run(*graph);
}
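// Registers a custom shape-compute graph for a schema and checks that
// shape propagation picks it up.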
TEST(TestRegisterShapeOp, Basic) {
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(
R"IR(
graph():
%2 : int = prim::Constant[value=5]()
%3: int[] = prim::ListConstruct(%2, %2)
return (%3))IR",
&*graph,
vmap);
auto g2 = std::make_shared<Graph>();
parseIR(
R"IR(
graph():
%2 : Tensor = prim::MakeTestTensor()
return (%2))IR",
&*g2,
vmap);
const FunctionSchema& schema = g2->nodes().begin()->schema();
torch::jit::RegisterShapeComputeGraphForSchema(schema, graph);
PropagateShapesOnGraph(g2);
testing::FileCheck().check("5, 5")->run(*g2);
}
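// The functional relu should be rewritten into in-place relu_.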
TEST(TestFunctionalToInplaceActivation, Basic) {
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(
R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=1]()
%x.3 : Tensor = aten::add(%x.1, %2, %2)
%y : Tensor = aten::relu(%x.3)
return (%y))IR",
&*graph,
vmap);
FunctionalToInplaceActivation(graph);
testing::FileCheck().check("aten::relu_")->run(*graph);
testing::FileCheck().check_not("aten::relu(")->run(*graph);
}
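// The profiling executor instruments the graph with prim::profile nodes;
// the simple executor runs it without instrumentation.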
TEST(TestFunctionExecutor, SimpleExecutorTest) {
auto graph = std::make_shared<Graph>();
parseIR(
R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=1]()
%x.3 : Tensor = aten::add(%x.1, %2, %2)
%y : Tensor = aten::relu(%x.3)
return (%y))IR",
&*graph);
{
auto func = torch::make_unique<GraphFunction>(
"name", graph, [](GraphFunction&) {}, ExecutorExecutionMode::PROFILING);
auto a = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
Stack stack = {a};
func->run(stack);
auto g = lastExecutedOptimizedGraph();
testing::FileCheck()
.check("prim::profile")
->check("aten::add")
->check("aten::relu")
->run(*g);
}
{
auto func = torch::make_unique<GraphFunction>(
"name", graph, [](GraphFunction&) {}, ExecutorExecutionMode::SIMPLE);
auto a = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
Stack stack = {a};
func->run(stack);
auto g = func->getDebugState().graph;
testing::FileCheck()
.check_not("prim::profile")
->check("aten::add")
->check("aten::relu")
->run(*g);
}
}
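// The registered decomposition for aten::var should match the eager result
// for both unbiased settings.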
TEST(TestFunctionExecutor, RunDecompositionTest) {
static auto* func = torch::jit::GetDecompositionExecutor(
"aten::var(Tensor self, bool unbiased=True) -> Tensor");
for (bool unbiased : {true, false}) {
auto input = at::rand({4, 4});
Stack stack = {input, unbiased};
func->run(stack);
at::Tensor out = pop(stack).toTensor();
ASSERT_TRUE(at::allclose(out, input.var(unbiased)));
}
}
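// Every registered shape-compute graph should pass the shape-graph linter.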
TEST(TestShapeGraphLinting, Basic) {
auto schemas = RegisteredShapeComputeSchemas();
for (const auto& schema : schemas) {
// arange does not actually support complex; leave it as
// union[int, float] for now
if (schema->name() == "aten::arange") {
continue;
}
auto g = shapeComputeGraphForSchema(*schema);
TORCH_INTERNAL_ASSERT(g);
LintShapeComputeGraph(schema, *g);
}
}
// TODO: move to test_kernel when global settings are explicit
// fusion parameters
class Composed : public ::testing::Test {
public:
void SetUp() override {
torch::jit::tensorexpr::getTEMustUseLLVMOnCPU() = false;
}
};
TEST_F(Composed, ComposedOp) {
struct WithCPUFuser {
WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) {
overrideCanFuseOnCPU(val);
}
~WithCPUFuser() {
overrideCanFuseOnCPU(cpuFuserEnabled);
}
bool cpuFuserEnabled;
};
#ifdef TORCH_ENABLE_LLVM
const auto graph_string = R"IR(
graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
%1 : Float(5, 3, strides=[1, 5], device=cpu)):
%2 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %1)
%3 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %2)
%4 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %3)
return (%3, %4))IR";
auto graph = std::make_shared<Graph>();
parseIR(graph_string, &*graph);
// wrong input sizes so we hit the fallback path
auto a = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
auto b = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat))
.transpose(0, 1);
auto ref1 = a * (a * b);
auto ref2 = a * ref1;
WithCPUFuser g(true);
bool fusable_on_device = torch::jit::tensorexpr::getTEMustUseLLVMOnCPU();
torch::jit::tensorexpr::getTEMustUseLLVMOnCPU() = false;
FuseTensorExprs(
graph,
/*min_group_size*/ 2,
/*add_composed_op*/ true,
/*fuse_to_dynamic_shapes*/ true);
Code code(graph, "");
InterpreterState interpreter{code};
std::vector<IValue> stack = {a, b};
interpreter.run(stack);
at::Tensor out2 = pop(stack).toTensor();
at::Tensor out1 = pop(stack).toTensor();
ASSERT_TRUE(at::allclose(ref1, out1));
ASSERT_TRUE(at::allclose(ref2, out2));
auto inp_1 = at::ones({4, 4}, TensorOptions(kCPU).dtype(at::kFloat));
auto inp_2 = at::ones({4, 4}, TensorOptions(kCPU).dtype(at::kFloat));
stack = {inp_1, inp_2, a, b};
InterpreterState interpreter2{code};
interpreter2.run(stack);
out2 = pop(stack).toTensor();
out1 = pop(stack).toTensor();
ASSERT_TRUE(at::allclose(ref1, out1));
ASSERT_TRUE(at::allclose(ref2, out2));
// inp_1 is on the bottom of the stack and corresponds to the second
// output; inp_2 is on the top and corresponds to the first output.
ASSERT_TRUE(at::allclose(inp_1, ref2));
ASSERT_TRUE(at::allclose(inp_2, ref1));
torch::jit::tensorexpr::getTEMustUseLLVMOnCPU() = fusable_on_device;
#endif
}
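// quantized::linear_prepack over constant inputs should be folded away by
// constant propagation.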
TEST(ConstantPropagation, CustomClassesCanBePropagated) {
const auto src = R"IR(
graph():
%none: NoneType = prim::Constant()
%dim: int = prim::Constant[value=3]()
%shape: int[] = prim::ListConstruct(%dim, %dim)
%weight: Tensor = aten::ones(%shape, %none, %none, %none, %none)
%scale: float = prim::Constant[value=1.]()
%zero_point: int = prim::Constant[value=0]()
%dtype: int = prim::Constant[value=12]()
%weight_q: Tensor = aten::quantize_per_tensor(%weight, %scale, %zero_point, %dtype)
%params: __torch__.torch.classes.quantized.LinearPackedParamsBase = quantized::linear_prepack(%weight_q, %none)
return (%params)
)IR";
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(src, graph.get(), vmap);
ConstantPropagation(graph);
testing::FileCheck().check_not("quantized::linear_prepack")->run(*graph);
}
} // namespace jit
} // namespace torch