[nnc] Get rid of fuser trigger counters (#57334)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57334

Here's a possibly controversial PR. These counters got in the way of generalizing the fuser tests to handle arbitrary devices, and I'm just generally skeptical that they provide much value. While it's true that they let us observe whether fusion groups were created, we already have assertions based on the shape of the graph, and I don't trust those any less than these counters.

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D29471484

Pulled By: bertmaher

fbshipit-source-id: f6d76f6e72dbfb581acff1d834b0c74500941b57
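For context, the graph-shape assertions mentioned above inspect the optimized JIT graph for a fusion-group node instead of polling a global execution counter. A minimal sketch of that style of check, assuming the profiling executor and TensorExpr fuser are enabled (fn and the warmup count are placeholders, not part of this change):

import torch

FUSION_GROUP = 'prim::TensorExprGroup'

def has_fusion_group(scripted_fn, *inputs):
    # Warm up so the profiling executor specializes and optimizes the graph.
    for _ in range(3):
        scripted_fn(*inputs)
    # Search the optimized graph (including nested blocks) for a fusion group.
    graph = scripted_fn.graph_for(*inputs)
    return len(graph.findAllNodes(FUSION_GROUP)) > 0

def fn(x, y):  # hypothetical element-wise function that should fuse
    return (x + y) * 2

print(has_fusion_group(torch.jit.script(fn), torch.rand(8), torch.rand(8)))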
This commit is contained in:
parent c4f718cb72
commit 93772792e3
benchmarks/cpp/tensorexpr/bench_parallel.cpp
@@ -36,7 +36,6 @@ class ParallelAdd : public benchmark::Fixture {

 BENCHMARK_DEFINE_F(ParallelAdd, Simple)(benchmark::State& state) {
   KernelScope kernel_scope;
-  ExecutionCounter counter(llvm_codegen_parallel_dispatched);
   Placeholder a_buf("a", kFloat, {M});
   Placeholder b_buf("b", kFloat, {M});
   Tensor* c_tensor = Compute(
@@ -56,8 +55,6 @@ BENCHMARK_DEFINE_F(ParallelAdd, Simple)(benchmark::State& state) {
   float* c_ptr = C.data_ptr<float>();
   std::vector<void*> args({c_ptr, a_ptr, b_ptr});
   cg.value<int>(args);
-  int count = counter.elapsed_value();
-  TORCH_CHECK(count > 0);
   for (int i = 0; i < M; i++) {
     float diff = fabs(a_ptr[i] + b_ptr[i] - c_ptr[i]);
     TORCH_CHECK(diff < 1e-5);

test/cpp/tensorexpr/test_llvm.cpp
@@ -1584,7 +1584,6 @@ TEST(LLVM, SimpleParallel) {
   for (int test_cfg = 0; test_cfg < 4; test_cfg++) {
     // Compute a simple operation, and try all loop-axis combination to be
     // parallel or sequential.
-    ExecutionCounter counter(llvm_codegen_parallel_dispatched);
     KernelScope kernel_scope;
     const int M = 4;
     const int N = 6;
@@ -1617,12 +1616,6 @@ TEST(LLVM, SimpleParallel) {
       }
     }
     ExpectAllNear(f_v, f_ref, 1e-5);
-    int count = counter.elapsed_value();
-    if (test_cfg == 0) {
-      ASSERT_EQ(count, 0);
-    } else {
-      ASSERT_GT(count, 0);
-    }
   }
 }

@@ -1632,7 +1625,6 @@ TEST(LLVM, CompositeParallel) {
   // Compute a composite operation, and try all loop-axis combination to be
   // parallel or sequential.
   for (int test_cfg = 0; test_cfg < test_count; test_cfg++) {
-    ExecutionCounter counter(llvm_codegen_parallel_dispatched);
     KernelScope kernel_scope;
     int M = 5;
     int N = 7;
@@ -1693,12 +1685,6 @@ TEST(LLVM, CompositeParallel) {
       }
     }
     ExpectAllNear(t4_v, t4_ref, 1e-5);
-    int count = counter.elapsed_value();
-    if (test_cfg == 0) {
-      ASSERT_EQ(count, 0);
-    } else {
-      ASSERT_GT(count, 0);
-    }
   }
 }

test/test_jit_fuser_te.py
@@ -30,8 +30,6 @@ from itertools import product, permutations
 from test_jit import backward_graph, get_lstm_inputs, get_milstm_inputs, \
     LSTMCellC, LSTMCellF, LSTMCellS, MiLSTMCell

-from torch.testing._internal.te_utils import CudaCodeGenExecuted
-
 from jit.test_fuser_common import TestFuserCommon  # noqa: F401

 FUSION_GROUP = 'prim::TensorExprGroup'
@@ -913,9 +911,7 @@ class TestTEFuser(JitTestCase):
         x = torch.zeros([3, 4, 5], dtype=torch.float, device='cuda')
         m = M()
         out1 = m.create(x)
-        cx = CudaCodeGenExecuted()
         out2 = m.create(x)
-        assert cx.elapsed_value() == 1
         self.assertNotEqual(out1, out2)
         self.assertTrue(torch.all(out1 >= 0))
         self.assertTrue(torch.all(out1 < 1))
@@ -994,9 +990,7 @@ class TestTEFuser(JitTestCase):
         y = torch.randn(4, 4, dtype=torch.float, device='cuda')
         script_f = torch.jit.script(fn_test_diamond)
         warmup_forward(script_f, x, y)
-        cx = CudaCodeGenExecuted()
         out = script_f(x, y)
-        assert cx.elapsed_value() == 1
         self.assertEqual(out, x + y)

     def test_scalar(self):

test/test_tensorexpr.py
@@ -6,9 +6,6 @@ import unittest

 from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs, run_tests

-from torch.testing._internal.te_utils import CudaCodeGenCreated, CudaCodeGenExecuted, \
-    LLVMCodeGenExecuted, SimpleIREvalExecuted
-
 from torch.testing._internal.jit_utils import JitTestCase

@@ -69,9 +66,6 @@ class TestTensorExprFuser(BaseTestClass):
         np.testing.assert_allclose(a.numpy() + b.numpy(), x.numpy())

     def test_three_arg(self):
-        llvm_executed = LLVMCodeGenExecuted()
-        simple_ir_eval_executed = SimpleIREvalExecuted()
-
         def easy(x, y, z):
             aaa = torch.add(x, y)
             bbb = torch.add(aaa, z)
@@ -88,10 +82,6 @@ class TestTensorExprFuser(BaseTestClass):
         self.assertLastGraphAllFused()
         npr = a.numpy() + b.numpy() + c.numpy()
         np.testing.assert_allclose(npr, x.numpy())
-        assert (
-            llvm_executed.elapsed_value() >= 1
-            or simple_ir_eval_executed.elapsed_value() >= 1
-        )

     def test_four_arg(self):
         def run_addcmul(x, y, z, w):
@@ -1132,16 +1122,12 @@
             return torch.add(torch.add(x, y, alpha=a), z, alpha=b)

         for test in (test_float, test_int):
-            llvm = LLVMCodeGenExecuted()
-            interp = SimpleIREvalExecuted()
             x, y, z = [torch.rand(4) for i in range(3)]
             a, b = 1, 2
             test(x, y, z, a, b)
             r = test(x, y, z, a, b)
             xn, yn, zn = [t.numpy() for t in (x, y, z)]
             np.testing.assert_allclose(r.numpy(), xn + yn * a + zn * b)
-            # FIXME: interp.elapsed_value() also increments due to simplifier
-            assert llvm.elapsed_value() == 1 or interp.elapsed_value() > 1

     def test_loop(self):
         @torch.jit.script
@@ -1152,12 +1138,9 @@
                 b = b + y
             return b

-        llvm = LLVMCodeGenExecuted()
-        interp = SimpleIREvalExecuted()
         x, y, z = (torch.zeros(32, 32), torch.ones(32, 32), 4)
         test(x, y, z)
         r = test(x, y, z)
-        assert llvm.elapsed_value == 1 or interp.elapsed_value() > 1

     def test_slice(self):
         def easy(x, y):
@@ -1167,16 +1150,11 @@

         traced = torch.jit.trace(easy, (torch.ones(1024, 1024), torch.zeros(1024, 1024)))

-        llvm = LLVMCodeGenExecuted()
-        interp = SimpleIREvalExecuted()
-
         a = torch.ones(1024, 1024)
         x = traced(a, a)
         npr = a[0:512:2]
         npr = npr + npr
         np.testing.assert_allclose(npr.numpy(), x.numpy())
-        # FIXME: interp.elapsed_value() also increments due to simplifier
-        assert llvm.elapsed_value() == 1 or interp.elapsed_value() > 1

     def test_unsqueeze(self, N=256):
         def easy(x, y):
@@ -1186,16 +1164,11 @@

         traced = torch.jit.trace(easy, (torch.ones(N, N), torch.zeros(N, N)))

-        llvm = LLVMCodeGenExecuted()
-        interp = SimpleIREvalExecuted()
-
         a = torch.rand(N, N)
         x = traced(a, a)
         npr = np.expand_dims(a, 0)
         npr = npr + npr
         np.testing.assert_allclose(npr, x.numpy())
-        # FIXME: interp.elapsed_value() also increments due to simplifier
-        assert llvm.elapsed_value() == 1 or interp.elapsed_value() > 1

     def _test_softmax(self, device):
         def test_softmax(x, y):
@@ -1230,18 +1203,12 @@
             torch._C._jit_set_texpr_reductions_enabled(old)

     def test_softmax_cpu(self):
-        llvm = LLVMCodeGenExecuted()
-        interp = SimpleIREvalExecuted()
         self._test_softmax('cpu')
-        # FIXME: interp.elapsed_value() also increments due to simplifier
-        assert llvm.elapsed_value() == 1 or interp.elapsed_value() > 1

     @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
     @unittest.skip("global allocs are not supported yet.")
     def test_softmax_cuda(self):
-        cuda = CudaCodeGenExecuted()
         self._test_softmax('cuda')
-        assert cuda.elapsed_value() == 1

     def test_half_gelu(self):
         devices = ["cuda"] if torch.cuda.is_available() else []
@@ -1275,31 +1242,23 @@
         @torch.jit.script
         def test(x, y, z):
             return x.transpose(0, 1) + y + z
-        llvm = LLVMCodeGenExecuted()
-        interp = SimpleIREvalExecuted()
         x = torch.rand(4, 5, 2, 3)
         y = torch.rand(5, 4, 2, 3)
         z = torch.rand(5, 4, 2, 3)
         ref = test(x, y, z)
         res = test(x, y, z)
         np.testing.assert_allclose(ref.numpy(), res.numpy())
-        # FIXME: interp.elapsed_value() also increments due to simplifier
-        assert llvm.elapsed_value() == 1 or interp.elapsed_value() > 1

     def test_sliced_stride(self):
         @torch.jit.script
         def test(x, y, z):
             return x + y + z
-        llvm = LLVMCodeGenExecuted()
-        interp = SimpleIREvalExecuted()
         x = torch.rand(16, 4, 2, 3)[::2]
         y = torch.rand(8, 4, 2, 3)
         z = torch.rand(8, 4, 2, 3)
         ref = test(x, y, z)
         res = test(x, y, z)
         np.testing.assert_allclose(ref.numpy(), res.numpy())
-        # FIXME: interp.elapsed_value() also increments due to simplifier
-        assert llvm.elapsed_value() == 1 or interp.elapsed_value() > 1

     @unittest.skip("dynamic shapes are not quite there yet")
     @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
@@ -1308,13 +1267,11 @@
         @torch.jit.script
         def test(x, y, z):
             return x * y * z
-        cuda = CudaCodeGenCreated()
         x, y, z = [torch.rand(4, 8).cuda() for _ in range(3)]
         ref = test(x, y, z)
         _ = test(*[torch.rand(6, 8).cuda() for _ in range(3)])
         res = test(x, y, z)
         np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy())
-        assert cuda.elapsed_value() == 1

         # A wild broadcast appears.
         x = torch.rand(4, 8).cuda()
@@ -1323,7 +1280,6 @@
         res = test(x, y, z)
         xn, yn, zn = [t.cpu().numpy() for t in (x, y, z)]
         np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn)
-        assert cuda.elapsed_value() == 1

         # Mismatched shapes shouldn't reach codegen.
         x = torch.rand(4, 8).cuda()
@@ -1333,7 +1289,6 @@
             res = test(x, y, z)
         except RuntimeError as e:
             assert "The size of tensor a (4) must match" in e.args[0]
-        assert cuda.elapsed_value() == 1

         # Changing a static dimension fails guards.
         # x, y, z = [torch.rand(4, 7).cuda() for _ in range(3)]
@@ -1341,22 +1296,16 @@
         # res = test(x, y, z)
         # print(test.graph_for(x, y, z))
         # np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn)
-        # assert cuda.elapsed_value() == 1

     @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
     def test_guard_fails(self):
         @torch.jit.script
         def test(x, y, z):
             return x * y * z
-        cuda = CudaCodeGenExecuted()
         r1 = test(*[torch.rand(4).cuda() for _ in range(3)])
-        assert cuda.elapsed_value() == 0
         r2 = test(*[torch.rand(4).cuda() for _ in range(3)])
-        assert cuda.elapsed_value() == 1
         r3 = test(*[torch.rand(4).cuda() for _ in range(3)])
-        assert cuda.elapsed_value() == 2
         r4 = test(*[torch.rand(7).cuda() for _ in range(3)])
-        assert cuda.elapsed_value() == 2

     def test_bitwise_ops(self):
         def run_and(x, y):

torch/csrc/jit/python/init.cpp
@@ -94,7 +94,6 @@
 #include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
 #include <torch/csrc/jit/serialization/export.h>
 #include <torch/csrc/jit/serialization/import.h>
-#include <torch/csrc/jit/tensorexpr/execution_counter.h>
 #include <torch/csrc/jit/tensorexpr/kernel.h>
 #include <torch/csrc/jit/tensorexpr/tensorexpr_init.h>
@@ -653,14 +652,6 @@ void initJITBindings(PyObject* module) {
           [](py::object obj) -> InferredType {
             return tryToInferType(std::move(obj));
           })
-      .def(
-          "_jit_get_trigger_value",
-          [](const std::string& trigger_name) -> int {
-            using namespace torch::jit::tensorexpr;
-            ExecutionTrigger* trigger =
-                ExecutionTriggerList::GetInstance().FindByName(trigger_name);
-            return trigger->value();
-          })
       .def(
           "_jit_get_te_cuda_pointwise_loop_levels",
           []() -> int {

torch/csrc/jit/tensorexpr/block_codegen.cpp
@@ -4,15 +4,12 @@
 #include <torch/csrc/jit/tensorexpr/analysis.h>
 #include <torch/csrc/jit/tensorexpr/eval.h>
 #include <torch/csrc/jit/tensorexpr/exceptions.h>
-#include <torch/csrc/jit/tensorexpr/execution_counter.h>
 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>

 namespace torch {
 namespace jit {
 namespace tensorexpr {

-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-DEFINE_TRIGGER(block_codegen_created);
 std::string blockDtypeCppString(const Dtype& dtype) {
   switch (dtype.scalar_type()) {
     case ScalarType::Bool:
@@ -360,8 +357,6 @@ void BlockCodeGen::Initialize() {
   stmt_v->accept(printer_.get());

   GRAPH_DEBUG("Generated Block code: ", oss_.str(), "\n");
-
-  USE_TRIGGER(block_codegen_created);
 }

 void BlockCodeGen::call(const std::vector<CallArg>& args) {

torch/csrc/jit/tensorexpr/cuda_codegen.cpp
@@ -10,7 +10,6 @@
 #include <torch/csrc/jit/tensorexpr/cuda_random.h>
 #include <torch/csrc/jit/tensorexpr/eval.h>
 #include <torch/csrc/jit/tensorexpr/exceptions.h>
-#include <torch/csrc/jit/tensorexpr/execution_counter.h>
 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
 #include <torch/csrc/jit/tensorexpr/registerizer.h>
@@ -18,9 +17,6 @@ namespace torch {
 namespace jit {
 namespace tensorexpr {

-DEFINE_TRIGGER(cuda_codegen_created);
-DEFINE_TRIGGER(cuda_codegen_executed);
-
 // A RAII wrapper to manage a variable and name pair in the look-up table.
 // TODO: move this to a more shared place.
 class ScopedVarName {
@@ -1045,7 +1041,6 @@ void CudaCodeGen::Initialize() {
       ")");

   CompileToNVRTC(oss_.str(), func_name);
-  USE_TRIGGER(cuda_codegen_created);
 }

 void CudaCodeGen::call_raw(const std::vector<void*>& raw_args) {
@@ -1147,7 +1142,6 @@ void CudaCodeGen::call_raw(const std::vector<void*>& raw_args) {
       stream,
       ptr_to_args.data(),
       nullptr));
-  USE_TRIGGER(cuda_codegen_executed);

   if (prior_device != this->device().index()) {
     at::cuda::set_device(prior_device);

torch/csrc/jit/tensorexpr/eval.cpp
@@ -8,9 +8,6 @@ namespace torch {
 namespace jit {
 namespace tensorexpr {

-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-DEFINE_TRIGGER(simple_ir_eval_executed);
-
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 RegisterCodeGen<SimpleIREvaluator> ir_eval_codegen_reg("simple_ir_eval");

@@ -993,7 +990,6 @@ void SimpleIREvaluator::call_raw(const std::vector<void*>& args) {
   }
   stmt()->accept(&*impl_);
   impl_->clear();
-  USE_TRIGGER(simple_ir_eval_executed);
 }

 void SimpleIREvaluator::bindArg(const BufferArg& bufArg, void* data) {

torch/csrc/jit/tensorexpr/eval.h
@@ -12,7 +12,6 @@
 #include <c10/util/string_utils.h>
 #include <torch/csrc/jit/tensorexpr/codegen.h>
 #include <torch/csrc/jit/tensorexpr/exceptions.h>
-#include <torch/csrc/jit/tensorexpr/execution_counter.h>
 #include <torch/csrc/jit/tensorexpr/ir.h>
 #include <torch/csrc/jit/tensorexpr/ir_printer.h>
 #include <torch/csrc/jit/tensorexpr/tensor.h>
@@ -23,9 +22,6 @@ namespace torch {
 namespace jit {
 namespace tensorexpr {

-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-DECLARE_TRIGGER(simple_ir_eval_executed);
-
 class Value {
  public:
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)

torch/csrc/jit/tensorexpr/execution_counter.h (deleted)
@@ -1,120 +0,0 @@
-#pragma once
-
-#include <torch/csrc/WindowsTorchApiMacro.h>
-
-#include <iostream>
-#include <string>
-#include <unordered_map>
-
-namespace torch {
-namespace jit {
-namespace tensorexpr {
-
-/*
-ExecutionTrigger and ExecutionCounter builds instrumentation counters so
-underlying functionalities can be checked.
-
-In the code to be instrumented:
-
-// worker.cpp
-DEFINE_TRIGGER(useful_work_done); // this defines a trigger "useful_work_done"
-void run() {
-  USE_TRIGGER(useful_work_done); // this triggers the underlying counter
-                                 // in "useful_work_done"
-}
-
-// in C++ client.cpp
-
-DECLARE_TRIGGER(useful_work_done); // Optional: this declares a trigger that
-                                   // will be defined elsewhere
-ExecutionCounter counter(useful_work_done); // This starts the counter from the
-                                            // underlying trigger.
-... call run() ...
-counter.elapsed_value(); // this returns the incremented value from the
-                         // trigger since the creation of the counter
-
-// in Python client.py
-counter = ExecutionCounter("useful_work_done") // this starts the counter from
-                                               // the underlying trigger
-... call C++ run() ...
-counter.elapsed_value() // This returns the incremented value from the
-                        // trigger since the creation of the counter.
-*/
-
-class ExecutionTrigger;
-class ExecutionTriggerList {
- public:
-  TORCH_API static ExecutionTriggerList& GetInstance() {
-    static ExecutionTriggerList instance;
-    return instance;
-  }
-
-  ExecutionTrigger* FindByName(const std::string& name) const {
-    auto iter = trigger_list_.find(name);
-    if (iter == trigger_list_.end()) {
-      throw std::runtime_error("Invalid trigger name: " + name);
-    }
-    return iter->second;
-  }
-
-  ExecutionTriggerList(const ExecutionTriggerList&) = delete;
-  ExecutionTriggerList& operator=(const ExecutionTriggerList&) = delete;
-
- private:
-  friend class ExecutionTrigger;
-
-  ExecutionTriggerList() = default;
-
-  void AddTrigger(const std::string& name, ExecutionTrigger* trigger) {
-    auto insert_ret = trigger_list_.insert(std::make_pair(name, trigger));
-    if (!insert_ret.second) {
-      std::cerr << "Warning: duplicated trigger name: " << name << "\n";
-    }
-  }
-
-  std::unordered_map<std::string, ExecutionTrigger*> trigger_list_;
-};
-
-class ExecutionTrigger {
- public:
-  explicit ExecutionTrigger(const std::string& name) : name_(name) {
-    ExecutionTriggerList::GetInstance().AddTrigger(name, this);
-  }
-  ExecutionTrigger(const ExecutionTrigger&) = delete;
-  ExecutionTrigger& operator=(const ExecutionTrigger&) = delete;
-
-  int value() const {
-    return value_;
-  }
-
-  void trigger() {
-    value_++;
-  }
-
- private:
-  int value_ = 0;
-  const std::string name_;
-};
-
-class ExecutionCounter {
- public:
-  explicit ExecutionCounter(ExecutionTrigger& trigger) : trigger_(trigger) {
-    start_value_ = trigger_.value();
-  }
-
-  int elapsed_value() const {
-    return trigger_.value() - start_value_;
-  }
-
- private:
-  ExecutionTrigger& trigger_;
-  int start_value_ = 0;
-};
-
-#define DEFINE_TRIGGER(name) ExecutionTrigger name(#name)
-#define DECLARE_TRIGGER(name) TORCH_API extern ExecutionTrigger name
-#define USE_TRIGGER(name) (name).trigger()
-
-} // namespace tensorexpr
-} // namespace jit
-} // namespace torch

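The comment block in this deleted header describes the whole mechanism: DEFINE_TRIGGER creates a named global counter, USE_TRIGGER increments it, and ExecutionCounter snapshots the value at construction so elapsed_value() reports only increments that happen afterwards. A minimal Python sketch of the same pattern, assuming a plain dict stands in for the global ExecutionTriggerList (all names here are illustrative, not part of the PyTorch API):

_triggers = {}  # name -> current count; stands in for ExecutionTriggerList

def define_trigger(name):  # analogue of DEFINE_TRIGGER(name)
    _triggers[name] = 0

def use_trigger(name):  # analogue of USE_TRIGGER(name): bump the named counter
    _triggers[name] += 1

class ExecutionCounter:
    def __init__(self, name):
        self.name = name
        self.start_value = _triggers[name]  # snapshot at construction

    def elapsed_value(self):
        return _triggers[self.name] - self.start_value

define_trigger("useful_work_done")
use_trigger("useful_work_done")             # happens before the counter exists
counter = ExecutionCounter("useful_work_done")
use_trigger("useful_work_done")
assert counter.elapsed_value() == 1         # only the post-construction use counts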
torch/csrc/jit/tensorexpr/llvm_codegen.cpp
@@ -33,7 +33,6 @@
 #include <llvm/Support/TypeSize.h>
 #endif

-#include <torch/csrc/jit/tensorexpr/execution_counter.h>
 #include <torch/csrc/jit/tensorexpr/expr.h>
 #include <torch/csrc/jit/tensorexpr/external_functions_registry.h>
 #include <torch/csrc/jit/tensorexpr/half_support.h>
@@ -53,13 +52,9 @@ C10_DEFINE_bool(
     false,
     "Use fast (but slightly less accurate) implementations of tanh and sigmoid");

-DEFINE_TRIGGER(llvm_codegen_created);
-DEFINE_TRIGGER(llvm_codegen_executed);
-
 namespace torch {
 namespace jit {
 namespace tensorexpr {
-DEFINE_TRIGGER(llvm_codegen_parallel_dispatched);
 namespace {

 llvm::CmpInst::Predicate llvm_comparison_predicate(
@@ -288,7 +283,6 @@ void DispatchParallel(int8_t* func, int start, int stop, int8_t* packed_data) {
       callee(index, packed_data);
     }
   });
-  USE_TRIGGER(llvm_codegen_parallel_dispatched);
 }

 } // namespace tensorexpr
@@ -315,7 +309,6 @@ LLVMCodeGen::LLVMCodeGen(

 void LLVMCodeGen::call_raw(const std::vector<void*>& args) {
   value<float>(const_cast<void**>(args.data()));
-  USE_TRIGGER(llvm_codegen_executed);
 }

 void LLVMCodeGen::call(const std::vector<CallArg>& args) {
@@ -333,7 +326,6 @@ void LLVMCodeGen::call(const std::vector<CallArg>& args) {
     argv[i] = argToPtr(bufferArg, callArg);
   }
   value<float>(argv.data());
-  USE_TRIGGER(llvm_codegen_executed);
 }

 at::Tensor LLVMCodeGen::empty_strided(
@@ -438,8 +430,6 @@ LLVMCodeGenImpl::LLVMCodeGenImpl(
   jit_->addModule(std::move(module_), std::move(context_));
   auto sym = jit_->findSymbol("wrapper");
   kernelAddress_ = assertSuccess(sym.getAddress());
-
-  USE_TRIGGER(llvm_codegen_created);
 }

 llvm::LLVMContext& LLVMCodeGenImpl::getContext() {

torch/csrc/jit/tensorexpr/llvm_codegen.h
@@ -4,7 +4,6 @@
 #include <torch/csrc/WindowsTorchApiMacro.h>

 #include <torch/csrc/jit/tensorexpr/codegen.h>
-#include <torch/csrc/jit/tensorexpr/execution_counter.h>
 #include <torch/csrc/jit/tensorexpr/ir.h>
 #include <torch/csrc/jit/tensorexpr/ir_visitor.h>
@@ -17,8 +16,6 @@ namespace torch {
 namespace jit {
 namespace tensorexpr {

-DECLARE_TRIGGER(llvm_codegen_parallel_dispatched);
-
 class LLVMCodeGenImpl;

 class TORCH_API LLVMCodeGen : public CodeGen {

torch/testing/_internal/te_utils.py (deleted)
@@ -1,36 +0,0 @@
-import torch
-
-class ExecutionCounter(object):
-    def try_get_trigger_value(self):
-        try:
-            return torch._C._jit_get_trigger_value(self.name)
-        except Exception:
-            return 0
-
-    def __init__(self, name):
-        self.name = name
-        self.start_value = self.try_get_trigger_value()
-
-    def elapsed_value(self):
-        value = self.try_get_trigger_value()
-        return value - self.start_value
-
-class CudaCodeGenCreated(ExecutionCounter):
-    def __init__(self):
-        super(CudaCodeGenCreated, self).__init__("cuda_codegen_created")
-
-class CudaCodeGenExecuted(ExecutionCounter):
-    def __init__(self):
-        super(CudaCodeGenExecuted, self).__init__("cuda_codegen_executed")
-
-class LLVMCodeGenCreated(ExecutionCounter):
-    def __init__(self):
-        super(LLVMCodeGenCreated, self).__init__("llvm_codegen_created")
-
-class LLVMCodeGenExecuted(ExecutionCounter):
-    def __init__(self):
-        super(LLVMCodeGenExecuted, self).__init__("llvm_codegen_executed")
-
-class SimpleIREvalExecuted(ExecutionCounter):
-    def __init__(self):
-        super(SimpleIREvalExecuted, self).__init__("simple_ir_eval_executed")

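For reference, the removed test assertions used these wrappers in a snapshot-run-compare pattern, roughly like this (a sketch reconstructed from the deleted test lines above; run_workload is a placeholder for the scripted or traced call, and this no longer works once the _jit_get_trigger_value binding is gone):

llvm = LLVMCodeGenExecuted()    # snapshot the "llvm_codegen_executed" trigger
interp = SimpleIREvalExecuted()
run_workload()                  # execute the (hopefully fused) function once
assert llvm.elapsed_value() == 1 or interp.elapsed_value() > 1

Note that try_get_trigger_value() swallows exceptions and returns 0, so in a build without the binding these counters silently reported nothing. After this change, the tests rely on graph-shape checks such as assertLastGraphAllFused() instead.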