[nnc] Get rid of fuser trigger counters (#57334)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57334

Here's a possibly controversial PR.  These counters got in the way of
generalizing the fuser tests to handle arbitrary devices, and I'm
generally skeptical that they provide much value.  While it's true that
they let us observe whether fusion groups were created, we already have
assertions based on the shape of the graph, and I don't trust those any
less than these counters.
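
For context, the kind of graph-shape assertion referred to above is
sketched below.  This is an illustrative sketch, not code from this PR:
FUSION_GROUP and FileCheck are idioms the TE fuser tests already use,
the enable-fuser calls mirror what the test harness does, and the
warmup count is approximate.

    import torch
    from torch.testing import FileCheck

    FUSION_GROUP = 'prim::TensorExprGroup'

    # Assumes the TensorExpr fuser is enabled, as the TE fuser tests do.
    torch._C._jit_set_texpr_fuser_enabled(True)
    torch._C._jit_override_can_fuse_on_cpu(True)

    @torch.jit.script
    def fn(x, y):
        return x + y * 2

    # Warm up so the profiling executor specializes and fuses the graph.
    x, y = torch.rand(4), torch.rand(4)
    for _ in range(3):
        fn(x, y)

    # Instead of counting codegen invocations, assert on the shape of
    # the graph that actually ran.
    graph = torch.jit.last_executed_optimized_graph()
    FileCheck().check(FUSION_GROUP).run(graph)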

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D29471484

Pulled By: bertmaher

fbshipit-source-id: f6d76f6e72dbfb581acff1d834b0c74500941b57
Bert Maher 2021-06-29 22:19:46 -07:00 committed by Facebook GitHub Bot
parent c4f718cb72
commit 93772792e3
13 changed files with 0 additions and 271 deletions

View File

@@ -36,7 +36,6 @@ class ParallelAdd : public benchmark::Fixture {
BENCHMARK_DEFINE_F(ParallelAdd, Simple)(benchmark::State& state) {
KernelScope kernel_scope;
ExecutionCounter counter(llvm_codegen_parallel_dispatched);
Placeholder a_buf("a", kFloat, {M});
Placeholder b_buf("b", kFloat, {M});
Tensor* c_tensor = Compute(
@@ -56,8 +55,6 @@ BENCHMARK_DEFINE_F(ParallelAdd, Simple)(benchmark::State& state) {
float* c_ptr = C.data_ptr<float>();
std::vector<void*> args({c_ptr, a_ptr, b_ptr});
cg.value<int>(args);
int count = counter.elapsed_value();
TORCH_CHECK(count > 0);
for (int i = 0; i < M; i++) {
float diff = fabs(a_ptr[i] + b_ptr[i] - c_ptr[i]);
TORCH_CHECK(diff < 1e-5);

View File

@@ -1584,7 +1584,6 @@ TEST(LLVM, SimpleParallel) {
for (int test_cfg = 0; test_cfg < 4; test_cfg++) {
// Compute a simple operation, and try all loop-axis combination to be
// parallel or sequential.
ExecutionCounter counter(llvm_codegen_parallel_dispatched);
KernelScope kernel_scope;
const int M = 4;
const int N = 6;
@@ -1617,12 +1616,6 @@ TEST(LLVM, SimpleParallel) {
}
}
ExpectAllNear(f_v, f_ref, 1e-5);
int count = counter.elapsed_value();
if (test_cfg == 0) {
ASSERT_EQ(count, 0);
} else {
ASSERT_GT(count, 0);
}
}
}
@@ -1632,7 +1625,6 @@ TEST(LLVM, CompositeParallel) {
// Compute a composite operation, and try all loop-axis combination to be
// parallel or sequential.
for (int test_cfg = 0; test_cfg < test_count; test_cfg++) {
ExecutionCounter counter(llvm_codegen_parallel_dispatched);
KernelScope kernel_scope;
int M = 5;
int N = 7;
@@ -1693,12 +1685,6 @@ TEST(LLVM, CompositeParallel) {
}
}
ExpectAllNear(t4_v, t4_ref, 1e-5);
int count = counter.elapsed_value();
if (test_cfg == 0) {
ASSERT_EQ(count, 0);
} else {
ASSERT_GT(count, 0);
}
}
}

View File

@@ -30,8 +30,6 @@ from itertools import product, permutations
from test_jit import backward_graph, get_lstm_inputs, get_milstm_inputs, \
LSTMCellC, LSTMCellF, LSTMCellS, MiLSTMCell
from torch.testing._internal.te_utils import CudaCodeGenExecuted
from jit.test_fuser_common import TestFuserCommon # noqa: F401
FUSION_GROUP = 'prim::TensorExprGroup'
@@ -913,9 +911,7 @@ class TestTEFuser(JitTestCase):
x = torch.zeros([3, 4, 5], dtype=torch.float, device='cuda')
m = M()
out1 = m.create(x)
cx = CudaCodeGenExecuted()
out2 = m.create(x)
assert cx.elapsed_value() == 1
self.assertNotEqual(out1, out2)
self.assertTrue(torch.all(out1 >= 0))
self.assertTrue(torch.all(out1 < 1))
@@ -994,9 +990,7 @@ class TestTEFuser(JitTestCase):
y = torch.randn(4, 4, dtype=torch.float, device='cuda')
script_f = torch.jit.script(fn_test_diamond)
warmup_forward(script_f, x, y)
cx = CudaCodeGenExecuted()
out = script_f(x, y)
assert cx.elapsed_value() == 1
self.assertEqual(out, x + y)
def test_scalar(self):

View File

@@ -6,9 +6,6 @@ import unittest
from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs, run_tests
from torch.testing._internal.te_utils import CudaCodeGenCreated, CudaCodeGenExecuted, \
LLVMCodeGenExecuted, SimpleIREvalExecuted
from torch.testing._internal.jit_utils import JitTestCase
@@ -69,9 +66,6 @@ class TestTensorExprFuser(BaseTestClass):
np.testing.assert_allclose(a.numpy() + b.numpy(), x.numpy())
def test_three_arg(self):
llvm_executed = LLVMCodeGenExecuted()
simple_ir_eval_executed = SimpleIREvalExecuted()
def easy(x, y, z):
aaa = torch.add(x, y)
bbb = torch.add(aaa, z)
@@ -88,10 +82,6 @@ class TestTensorExprFuser(BaseTestClass):
self.assertLastGraphAllFused()
npr = a.numpy() + b.numpy() + c.numpy()
np.testing.assert_allclose(npr, x.numpy())
assert (
llvm_executed.elapsed_value() >= 1
or simple_ir_eval_executed.elapsed_value() >= 1
)
def test_four_arg(self):
def run_addcmul(x, y, z, w):
@@ -1132,16 +1122,12 @@ class TestTensorExprFuser(BaseTestClass):
return torch.add(torch.add(x, y, alpha=a), z, alpha=b)
for test in (test_float, test_int):
llvm = LLVMCodeGenExecuted()
interp = SimpleIREvalExecuted()
x, y, z = [torch.rand(4) for i in range(3)]
a, b = 1, 2
test(x, y, z, a, b)
r = test(x, y, z, a, b)
xn, yn, zn = [t.numpy() for t in (x, y, z)]
np.testing.assert_allclose(r.numpy(), xn + yn * a + zn * b)
# FIXME: interp.elapsed_value() also increments due to simplifier
assert llvm.elapsed_value() == 1 or interp.elapsed_value() > 1
def test_loop(self):
@torch.jit.script
@@ -1152,12 +1138,9 @@ class TestTensorExprFuser(BaseTestClass):
b = b + y
return b
llvm = LLVMCodeGenExecuted()
interp = SimpleIREvalExecuted()
x, y, z = (torch.zeros(32, 32), torch.ones(32, 32), 4)
test(x, y, z)
r = test(x, y, z)
assert llvm.elapsed_value() == 1 or interp.elapsed_value() > 1
def test_slice(self):
def easy(x, y):
@@ -1167,16 +1150,11 @@ class TestTensorExprFuser(BaseTestClass):
traced = torch.jit.trace(easy, (torch.ones(1024, 1024), torch.zeros(1024, 1024)))
llvm = LLVMCodeGenExecuted()
interp = SimpleIREvalExecuted()
a = torch.ones(1024, 1024)
x = traced(a, a)
npr = a[0:512:2]
npr = npr + npr
np.testing.assert_allclose(npr.numpy(), x.numpy())
# FIXME: interp.elapsed_value() also increments due to simplifier
assert llvm.elapsed_value() == 1 or interp.elapsed_value() > 1
def test_unsqueeze(self, N=256):
def easy(x, y):
@@ -1186,16 +1164,11 @@ class TestTensorExprFuser(BaseTestClass):
traced = torch.jit.trace(easy, (torch.ones(N, N), torch.zeros(N, N)))
llvm = LLVMCodeGenExecuted()
interp = SimpleIREvalExecuted()
a = torch.rand(N, N)
x = traced(a, a)
npr = np.expand_dims(a, 0)
npr = npr + npr
np.testing.assert_allclose(npr, x.numpy())
# FIXME: interp.elapsed_value() also increments due to simplifier
assert llvm.elapsed_value() == 1 or interp.elapsed_value() > 1
def _test_softmax(self, device):
def test_softmax(x, y):
@@ -1230,18 +1203,12 @@ class TestTensorExprFuser(BaseTestClass):
torch._C._jit_set_texpr_reductions_enabled(old)
def test_softmax_cpu(self):
llvm = LLVMCodeGenExecuted()
interp = SimpleIREvalExecuted()
self._test_softmax('cpu')
# FIXME: interp.elapsed_value() also increments due to simplifier
assert llvm.elapsed_value() == 1 or interp.elapsed_value() > 1
@unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
@unittest.skip("global allocs are not supported yet.")
def test_softmax_cuda(self):
cuda = CudaCodeGenExecuted()
self._test_softmax('cuda')
assert cuda.elapsed_value() == 1
def test_half_gelu(self):
devices = ["cuda"] if torch.cuda.is_available() else []
@@ -1275,31 +1242,23 @@ class TestTensorExprFuser(BaseTestClass):
@torch.jit.script
def test(x, y, z):
return x.transpose(0, 1) + y + z
llvm = LLVMCodeGenExecuted()
interp = SimpleIREvalExecuted()
x = torch.rand(4, 5, 2, 3)
y = torch.rand(5, 4, 2, 3)
z = torch.rand(5, 4, 2, 3)
ref = test(x, y, z)
res = test(x, y, z)
np.testing.assert_allclose(ref.numpy(), res.numpy())
# FIXME: interp.elapsed_value() also increments due to simplifier
assert llvm.elapsed_value() == 1 or interp.elapsed_value() > 1
def test_sliced_stride(self):
@torch.jit.script
def test(x, y, z):
return x + y + z
llvm = LLVMCodeGenExecuted()
interp = SimpleIREvalExecuted()
x = torch.rand(16, 4, 2, 3)[::2]
y = torch.rand(8, 4, 2, 3)
z = torch.rand(8, 4, 2, 3)
ref = test(x, y, z)
res = test(x, y, z)
np.testing.assert_allclose(ref.numpy(), res.numpy())
# FIXME: interp.elapsed_value() also increments due to simplifier
assert llvm.elapsed_value() == 1 or interp.elapsed_value() > 1
@unittest.skip("dynamic shapes are not quite there yet")
@unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
@@ -1308,13 +1267,11 @@ class TestTensorExprFuser(BaseTestClass):
@torch.jit.script
def test(x, y, z):
return x * y * z
cuda = CudaCodeGenCreated()
x, y, z = [torch.rand(4, 8).cuda() for _ in range(3)]
ref = test(x, y, z)
_ = test(*[torch.rand(6, 8).cuda() for _ in range(3)])
res = test(x, y, z)
np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy())
assert cuda.elapsed_value() == 1
# A wild broadcast appears.
x = torch.rand(4, 8).cuda()
@@ -1323,7 +1280,6 @@ class TestTensorExprFuser(BaseTestClass):
res = test(x, y, z)
xn, yn, zn = [t.cpu().numpy() for t in (x, y, z)]
np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn)
assert cuda.elapsed_value() == 1
# Mismatched shapes shouldn't reach codegen.
x = torch.rand(4, 8).cuda()
@@ -1333,7 +1289,6 @@ class TestTensorExprFuser(BaseTestClass):
res = test(x, y, z)
except RuntimeError as e:
assert "The size of tensor a (4) must match" in e.args[0]
assert cuda.elapsed_value() == 1
# Changing a static dimension fails guards.
# x, y, z = [torch.rand(4, 7).cuda() for _ in range(3)]
@@ -1341,22 +1296,16 @@ class TestTensorExprFuser(BaseTestClass):
# res = test(x, y, z)
# print(test.graph_for(x, y, z))
# np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn)
# assert cuda.elapsed_value() == 1
@unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
def test_guard_fails(self):
@torch.jit.script
def test(x, y, z):
return x * y * z
cuda = CudaCodeGenExecuted()
r1 = test(*[torch.rand(4).cuda() for _ in range(3)])
assert cuda.elapsed_value() == 0
r2 = test(*[torch.rand(4).cuda() for _ in range(3)])
assert cuda.elapsed_value() == 1
r3 = test(*[torch.rand(4).cuda() for _ in range(3)])
assert cuda.elapsed_value() == 2
r4 = test(*[torch.rand(7).cuda() for _ in range(3)])
assert cuda.elapsed_value() == 2
def test_bitwise_ops(self):
def run_and(x, y):

View File

@@ -94,7 +94,6 @@
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/tensorexpr/execution_counter.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/tensorexpr/tensorexpr_init.h>
@@ -653,14 +652,6 @@ void initJITBindings(PyObject* module) {
[](py::object obj) -> InferredType {
return tryToInferType(std::move(obj));
})
.def(
"_jit_get_trigger_value",
[](const std::string& trigger_name) -> int {
using namespace torch::jit::tensorexpr;
ExecutionTrigger* trigger =
ExecutionTriggerList::GetInstance().FindByName(trigger_name);
return trigger->value();
})
.def(
"_jit_get_te_cuda_pointwise_loop_levels",
[]() -> int {

View File

@@ -4,15 +4,12 @@
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/exceptions.h>
#include <torch/csrc/jit/tensorexpr/execution_counter.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
namespace torch {
namespace jit {
namespace tensorexpr {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_TRIGGER(block_codegen_created);
std::string blockDtypeCppString(const Dtype& dtype) {
switch (dtype.scalar_type()) {
case ScalarType::Bool:
@@ -360,8 +357,6 @@ void BlockCodeGen::Initialize() {
stmt_v->accept(printer_.get());
GRAPH_DEBUG("Generated Block code: ", oss_.str(), "\n");
USE_TRIGGER(block_codegen_created);
}
void BlockCodeGen::call(const std::vector<CallArg>& args) {

View File

@@ -10,7 +10,6 @@
#include <torch/csrc/jit/tensorexpr/cuda_random.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/exceptions.h>
#include <torch/csrc/jit/tensorexpr/execution_counter.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/registerizer.h>
@@ -18,9 +17,6 @@ namespace torch {
namespace jit {
namespace tensorexpr {
DEFINE_TRIGGER(cuda_codegen_created);
DEFINE_TRIGGER(cuda_codegen_executed);
// A RAII wrapper to manage a variable and name pair in the look-up table.
// TODO: move this to a more shared place.
class ScopedVarName {
@@ -1045,7 +1041,6 @@ void CudaCodeGen::Initialize() {
")");
CompileToNVRTC(oss_.str(), func_name);
USE_TRIGGER(cuda_codegen_created);
}
void CudaCodeGen::call_raw(const std::vector<void*>& raw_args) {
@@ -1147,7 +1142,6 @@ void CudaCodeGen::call_raw(const std::vector<void*>& raw_args) {
stream,
ptr_to_args.data(),
nullptr));
USE_TRIGGER(cuda_codegen_executed);
if (prior_device != this->device().index()) {
at::cuda::set_device(prior_device);

View File

@@ -8,9 +8,6 @@ namespace torch {
namespace jit {
namespace tensorexpr {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_TRIGGER(simple_ir_eval_executed);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterCodeGen<SimpleIREvaluator> ir_eval_codegen_reg("simple_ir_eval");
@@ -993,7 +990,6 @@ void SimpleIREvaluator::call_raw(const std::vector<void*>& args) {
}
stmt()->accept(&*impl_);
impl_->clear();
USE_TRIGGER(simple_ir_eval_executed);
}
void SimpleIREvaluator::bindArg(const BufferArg& bufArg, void* data) {

View File

@@ -12,7 +12,6 @@
#include <c10/util/string_utils.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/exceptions.h>
#include <torch/csrc/jit/tensorexpr/execution_counter.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
@@ -23,9 +22,6 @@ namespace torch {
namespace jit {
namespace tensorexpr {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DECLARE_TRIGGER(simple_ir_eval_executed);
class Value {
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)

View File

@@ -1,120 +0,0 @@
#pragma once
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <iostream>
#include <string>
#include <unordered_map>
namespace torch {
namespace jit {
namespace tensorexpr {
/*
ExecutionTrigger and ExecutionCounter build instrumentation counters so
that underlying functionality can be checked.
In the code to be instrumented:
// worker.cpp
DEFINE_TRIGGER(useful_work_done); // this defines a trigger "useful_work_done"
void run() {
USE_TRIGGER(useful_work_done); // this triggers the underlying counter
// in "useful_work_done"
}
// in C++ client.cpp
DECLARE_TRIGGER(useful_work_done); // Optional: this declares a trigger that
// will be defined elsewhere
ExecutionCounter counter(useful_work_done); // This starts the counter from the
// underlying trigger.
... call run() ...
counter.elapsed_value(); // this returns the incremented value from the
// trigger since the creation of the counter
// in Python client.py
counter = ExecutionCounter("useful_work_done") // this starts the counter from
// the underlying trigger
... call C++ run() ...
counter.elapsed_value() // This returns the incremented value from the
// trigger since the creation of the counter.
*/
class ExecutionTrigger;
class ExecutionTriggerList {
public:
TORCH_API static ExecutionTriggerList& GetInstance() {
static ExecutionTriggerList instance;
return instance;
}
ExecutionTrigger* FindByName(const std::string& name) const {
auto iter = trigger_list_.find(name);
if (iter == trigger_list_.end()) {
throw std::runtime_error("Invalid trigger name: " + name);
}
return iter->second;
}
ExecutionTriggerList(const ExecutionTriggerList&) = delete;
ExecutionTriggerList& operator=(const ExecutionTriggerList&) = delete;
private:
friend class ExecutionTrigger;
ExecutionTriggerList() = default;
void AddTrigger(const std::string& name, ExecutionTrigger* trigger) {
auto insert_ret = trigger_list_.insert(std::make_pair(name, trigger));
if (!insert_ret.second) {
std::cerr << "Warning: duplicated trigger name: " << name << "\n";
}
}
std::unordered_map<std::string, ExecutionTrigger*> trigger_list_;
};
class ExecutionTrigger {
public:
explicit ExecutionTrigger(const std::string& name) : name_(name) {
ExecutionTriggerList::GetInstance().AddTrigger(name, this);
}
ExecutionTrigger(const ExecutionTrigger&) = delete;
ExecutionTrigger& operator=(const ExecutionTrigger&) = delete;
int value() const {
return value_;
}
void trigger() {
value_++;
}
private:
int value_ = 0;
const std::string name_;
};
class ExecutionCounter {
public:
explicit ExecutionCounter(ExecutionTrigger& trigger) : trigger_(trigger) {
start_value_ = trigger_.value();
}
int elapsed_value() const {
return trigger_.value() - start_value_;
}
private:
ExecutionTrigger& trigger_;
int start_value_ = 0;
};
#define DEFINE_TRIGGER(name) ExecutionTrigger name(#name)
#define DECLARE_TRIGGER(name) TORCH_API extern ExecutionTrigger name
#define USE_TRIGGER(name) (name).trigger()
} // namespace tensorexpr
} // namespace jit
} // namespace torch

View File

@@ -33,7 +33,6 @@
#include <llvm/Support/TypeSize.h>
#endif
#include <torch/csrc/jit/tensorexpr/execution_counter.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/external_functions_registry.h>
#include <torch/csrc/jit/tensorexpr/half_support.h>
@@ -53,13 +52,9 @@ C10_DEFINE_bool(
false,
"Use fast (but slightly less accurate) implementations of tanh and sigmoid");
DEFINE_TRIGGER(llvm_codegen_created);
DEFINE_TRIGGER(llvm_codegen_executed);
namespace torch {
namespace jit {
namespace tensorexpr {
DEFINE_TRIGGER(llvm_codegen_parallel_dispatched);
namespace {
llvm::CmpInst::Predicate llvm_comparison_predicate(
@@ -288,7 +283,6 @@ void DispatchParallel(int8_t* func, int start, int stop, int8_t* packed_data) {
callee(index, packed_data);
}
});
USE_TRIGGER(llvm_codegen_parallel_dispatched);
}
} // namespace tensorexpr
@@ -315,7 +309,6 @@ LLVMCodeGen::LLVMCodeGen(
void LLVMCodeGen::call_raw(const std::vector<void*>& args) {
value<float>(const_cast<void**>(args.data()));
USE_TRIGGER(llvm_codegen_executed);
}
void LLVMCodeGen::call(const std::vector<CallArg>& args) {
@@ -333,7 +326,6 @@ void LLVMCodeGen::call(const std::vector<CallArg>& args) {
argv[i] = argToPtr(bufferArg, callArg);
}
value<float>(argv.data());
USE_TRIGGER(llvm_codegen_executed);
}
at::Tensor LLVMCodeGen::empty_strided(
@@ -438,8 +430,6 @@ LLVMCodeGenImpl::LLVMCodeGenImpl(
jit_->addModule(std::move(module_), std::move(context_));
auto sym = jit_->findSymbol("wrapper");
kernelAddress_ = assertSuccess(sym.getAddress());
USE_TRIGGER(llvm_codegen_created);
}
llvm::LLVMContext& LLVMCodeGenImpl::getContext() {

View File

@@ -4,7 +4,6 @@
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/execution_counter.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
@@ -17,8 +16,6 @@ namespace torch {
namespace jit {
namespace tensorexpr {
DECLARE_TRIGGER(llvm_codegen_parallel_dispatched);
class LLVMCodeGenImpl;
class TORCH_API LLVMCodeGen : public CodeGen {

View File

@@ -1,36 +0,0 @@
import torch
class ExecutionCounter(object):
def try_get_trigger_value(self):
try:
return torch._C._jit_get_trigger_value(self.name)
except Exception:
return 0
def __init__(self, name):
self.name = name
self.start_value = self.try_get_trigger_value()
def elapsed_value(self):
value = self.try_get_trigger_value()
return value - self.start_value
class CudaCodeGenCreated(ExecutionCounter):
def __init__(self):
super(CudaCodeGenCreated, self).__init__("cuda_codegen_created")
class CudaCodeGenExecuted(ExecutionCounter):
def __init__(self):
super(CudaCodeGenExecuted, self).__init__("cuda_codegen_executed")
class LLVMCodeGenCreated(ExecutionCounter):
def __init__(self):
super(LLVMCodeGenCreated, self).__init__("llvm_codegen_created")
class LLVMCodeGenExecuted(ExecutionCounter):
def __init__(self):
super(LLVMCodeGenExecuted, self).__init__("llvm_codegen_executed")
class SimpleIREvalExecuted(ExecutionCounter):
def __init__(self):
super(SimpleIREvalExecuted, self).__init__("simple_ir_eval_executed")