Add more CPU tests (#47369)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47369
Test Plan: Imported from OSS
Reviewed By: ansley
Differential Revision: D24805251
Pulled By: eellison
fbshipit-source-id: f1a8210ffdc3cc88354cb4896652151d83a0345a
This commit is contained in: parent b8a1070ec0, commit fe81faee5f
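The diff below applies one pattern throughout the TE fuser test suite: tests that were gated on RUN_CUDA and hard-coded device='cuda' are renamed (dropping the _cuda suffix where present) and their bodies are re-indented under a loop over self.devices, so the same checks now also run on CPU. A minimal standalone sketch of that pattern follows; the scaleshift body comes from the diff, while the setUp that populates self.devices is an assumption about the test class (not part of this commit), and checkScript is replaced by a plain eager check so the sketch runs on its own.

# Minimal sketch of the device-parameterization pattern used in this PR.
# Assumption: TestTEFuser fills self.devices roughly like this; the real
# class also scripts the function and asserts on fusion via checkScript().
import unittest
import torch

class FuserPatternSketch(unittest.TestCase):
    def setUp(self):
        # CPU is always exercised; CUDA is added only when available.
        self.devices = ['cpu'] + (['cuda'] if torch.cuda.is_available() else [])

    def test_broadcast(self):
        for device in self.devices:
            def scaleshift(x, scale, shift):
                return x * scale + shift

            inputs = [
                torch.randn(4, 4, dtype=torch.float, device=device),
                torch.randn(4, dtype=torch.float, device=device),
                torch.randn(4, dtype=torch.float, device=device),
            ]
            # Stand-in for self.checkScript(scaleshift, inputs).
            self.assertEqual(scaleshift(*inputs).shape, (4, 4))

if __name__ == '__main__':
    unittest.main()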
test/test_jit_fuser_te.py
@@ -170,18 +170,17 @@ class TestTEFuser(JitTestCase):
         traced_f = torch.jit.trace(f, (x, y,))
         self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y))

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    def test_broadcast_cuda(self):
-        def scaleshift(x, scale, shift):
-            return x * scale + shift
+    def test_broadcast(self):
+        for device in self.devices:
+            def scaleshift(x, scale, shift):
+                return x * scale + shift

-        inputs = [
-            torch.randn(4, 4, dtype=torch.float, device='cuda'),
-            torch.randn(4, dtype=torch.float, device='cuda'),
-            torch.randn(4, dtype=torch.float, device='cuda'),
-        ]
-        self.checkScript(scaleshift, inputs)
-        self.assertLastGraphAllFused()
+            inputs = [
+                torch.randn(4, 4, dtype=torch.float, device=device),
+                torch.randn(4, dtype=torch.float, device=device),
+                torch.randn(4, dtype=torch.float, device=device),
+            ]
+            self.checkScript(scaleshift, inputs)

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(not RUN_CUDA_HALF, "no half support")
@@ -219,33 +218,33 @@ class TestTEFuser(JitTestCase):
         grads_half = [t.half() for t in grads]
         self.assertEqual(grads_half, fusion_grads)

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_checks_cat_inputs(self):
-        # We shouldn't treat cat nodes as broadcasting. All their inputs
-        # need to be checked for having the same map size, before we can
-        # run the kernel.
-        def f(x, y):
-            return torch.cat([x + 2 * x + x ** 2, y + 4 * y + y ** 3], dim=0)
+        for device in self.devices:
+            # We shouldn't treat cat nodes as broadcasting. All their inputs
+            # need to be checked for having the same map size, before we can
+            # run the kernel.
+            def f(x, y):
+                return torch.cat([x + 2 * x + x ** 2, y + 4 * y + y ** 3], dim=0)

-        # NOTE: y is broadcastable to x, but output of f(x, y) should have
-        # shape 3x4, and not 4x4.
-        x = torch.randn(2, 4, dtype=torch.float, device='cuda')
-        y = torch.randn(1, 4, dtype=torch.float, device='cuda')
+            # NOTE: y is broadcastable to x, but output of f(x, y) should have
+            # shape 3x4, and not 4x4.
+            x = torch.randn(2, 4, dtype=torch.float, device=device)
+            y = torch.randn(1, 4, dtype=torch.float, device=device)

-        scripted = self.checkScript(f, (x, y))
-        self.assertEqual(scripted(x, y).shape, (3, 4))
-        self.assertAllFused(scripted.graph_for(x, y))
+            scripted = self.checkScript(f, (x, y))
+            self.assertEqual(scripted(x, y).shape, (3, 4))
+            self.assertAllFused(scripted.graph_for(x, y))

-    @unittest.skipIf(not RUN_CUDA, "No CUDA")
-    def test_chunk_cuda(self):
-        def fn(x):
-            a, b, c = x.chunk(3, 1)
-            return a * b + c
+    def test_chunk(self):
+        for device in self.devices:
+            def fn(x):
+                a, b, c = x.chunk(3, 1)
+                return a * b + c

-        inputs = [torch.randn(10, 6, dtype=torch.float, device='cuda')]
+            inputs = [torch.randn(10, 6, dtype=torch.float, device=device)]

-        self.checkScript(fn, inputs)
-        self.assertLastGraphAllFused()
+            self.checkScript(fn, inputs)
+            self.assertLastGraphAllFused()

     @staticmethod
     def _test_chunk_correctness(self, device='cpu'):
@@ -303,99 +302,96 @@ class TestTEFuser(JitTestCase):
             "ConstantChunk", 1, exactly=True
         ).run(str(graph))

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_chunk_motion_deduplicates_inputs(self):
-        def func1(x):
-            z = x * x
-            z0, z1 = z.chunk(2)
-            return z0 * z1
+        for device in self.devices:
+            def func1(x):
+                z = x * x
+                z0, z1 = z.chunk(2)
+                return z0 * z1

-        def func2(x):
-            z = x * x * x
-            z0, z1 = z.chunk(2)
-            return z0 * z1
+            def func2(x):
+                z = x * x * x
+                z0, z1 = z.chunk(2)
+                return z0 * z1

-        inputs = [
-            torch.tensor([1.1, 1.2], device='cuda', dtype=torch.float),
-        ]
-        for func in [func1, func2]:
-            self.checkScript(func, inputs)
-            self.assertLastGraphAllFused()
+            inputs = [
+                torch.tensor([1.1, 1.2], device=device, dtype=torch.float),
+            ]
+            for func in [func1, func2]:
+                self.checkScript(func, inputs)
+                self.assertLastGraphAllFused()

-    @unittest.skipIf(not RUN_CUDA, "No CUDA")
-    def test_chunk_multiple_cuda(self):
-        # The arguments are intentionally used out of order as a test to see
-        # if the fusion compiler adds extra args in the correct order
-        def fn(s, x, y, z):
-            z1, z2 = z.chunk(2, 2)
-            x1, x2, x3 = x.chunk(3, 1)
-            y1, y2 = y.chunk(2, 0)
-            return s + x1 + x2 + x3 + y1 + y2 + z1 + z2
+    def test_chunk_multiple(self):
+        for device in self.devices:
+            # The arguments are intentionally used out of order as a test to see
+            # if the fusion compiler adds extra args in the correct order
+            def fn(s, x, y, z):
+                z1, z2 = z.chunk(2, 2)
+                x1, x2, x3 = x.chunk(3, 1)
+                y1, y2 = y.chunk(2, 0)
+                return s + x1 + x2 + x3 + y1 + y2 + z1 + z2

-        inputs = [
-            torch.randn(5, 2, 3, dtype=torch.float, device='cuda'),
-            torch.randn(5, 6, 3, dtype=torch.float, device='cuda'),
-            torch.randn(10, 2, 3, dtype=torch.float, device='cuda'),
-            torch.randn(5, 2, 6, dtype=torch.float, device='cuda'),
-        ]
+            inputs = [
+                torch.randn(5, 2, 3, dtype=torch.float, device=device),
+                torch.randn(5, 6, 3, dtype=torch.float, device=device),
+                torch.randn(10, 2, 3, dtype=torch.float, device=device),
+                torch.randn(5, 2, 6, dtype=torch.float, device=device),
+            ]

-        ge = self.checkScript(fn, inputs)
-        self.assertAllFused(ge.graph_for(*inputs))
+            ge = self.checkScript(fn, inputs)
+            self.assertAllFused(ge.graph_for(*inputs))

     def test_minmax(self):
-        def tmax(a, b):
-            return torch.max(2 * a, b)
+        for device in self.devices:
+            def tmax(a, b):
+                return torch.max(2 * a, b)

-        def tmin(a, b):
-            return torch.min(2 * a, b)
+            def tmin(a, b):
+                return torch.min(2 * a, b)

-        a = torch.randn(4, 4, dtype=torch.float)
-        b = torch.randn(4, 4, dtype=torch.float)
-        nan = torch.tensor(float('nan'), dtype=torch.float)
+            a = torch.randn(4, 4, dtype=torch.float)
+            b = torch.randn(4, 4, dtype=torch.float)
+            nan = torch.tensor(float('nan'), dtype=torch.float)

-        devices = ["cpu"]
-        if torch.cuda.is_available():
-            devices.append("cuda")
-        for f, inputs, device in product(
-            (tmax, tmin),
-            ([a, b], [a, nan], [b, nan]),
-            devices):
-            inputs = [t.to(device) for t in inputs]
-            s = self.checkScript(f, inputs)
-            self.assertAllFused(s.graph_for(*inputs))
+            for f, inputs, device in product(
+                (tmax, tmin),
+                ([a, b], [a, nan], [b, nan]),
+                self.devices):
+                inputs = [t.to(device) for t in inputs]
+                s = self.checkScript(f, inputs)
+                self.assertAllFused(s.graph_for(*inputs))

     # TODO: reenable the test after backwards passes start working in PE
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_clamp(self):
-        def func2(a, b):
-            return torch.clamp(a + b, min=0, max=2)
+        for device in self.devices:
+            def func2(a, b):
+                return torch.clamp(a + b, min=0, max=2)

-        def funcInf(a, b):
-            return torch.clamp(a + b, min=0, max=float('inf'))
+            def funcInf(a, b):
+                return torch.clamp(a + b, min=0, max=float('inf'))

-        def funcNegInf(a, b):
-            return torch.clamp(a + b, min=float('-inf'), max=0)
+            def funcNegInf(a, b):
+                return torch.clamp(a + b, min=float('-inf'), max=0)

-        def funcOptMin(a, b):
-            return torch.clamp(a + b, max=2)
+            def funcOptMin(a, b):
+                return torch.clamp(a + b, max=2)

-        def funcOptMax(a, b):
-            return torch.clamp(a + b, min=0)
+            def funcOptMax(a, b):
+                return torch.clamp(a + b, min=0)

-        a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
-        b = torch.randn(4, 4, dtype=torch.float, device='cuda')
-        nan = torch.tensor(float('nan'), dtype=torch.float, device='cuda')
+            a = torch.randn(4, 4, dtype=torch.float, device=device, requires_grad=True)
+            b = torch.randn(4, 4, dtype=torch.float, device=device)
+            nan = torch.tensor(float('nan'), dtype=torch.float, device=device)

-        funcs = (func2, funcInf, funcNegInf, funcOptMin, funcOptMax)
-        for f, inputs in product(funcs, [[a, b], [a, nan]]):
-            inp1, inp2 = inputs
-            s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING)
-            self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size', 'aten::_size_if_not_equal'})
-            c = s(inp1, inp2)
-            with enable_profiling_mode_for_profiling_tests():
-                warmup_backward(c.sum())
-            graph = backward_graph(s)
-            self.assertAllFused(graph, except_for={'aten::Float', 'aten::_grad_sum_to_size'})
+            funcs = (func2, funcInf, funcNegInf, funcOptMin, funcOptMax)
+            for f, inputs in product(funcs, [[a, b], [a, nan]]):
+                inp1, inp2 = inputs
+                s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING)
+                self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size', 'aten::_size_if_not_equal'})
+                c = s(inp1, inp2)
+                with enable_profiling_mode_for_profiling_tests():
+                    warmup_backward(c.sum())
+                graph = backward_graph(s)
+                self.assertAllFused(graph, except_for={'aten::Float', 'aten::_grad_sum_to_size'})

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
@@ -425,31 +421,31 @@ class TestTEFuser(JitTestCase):
         ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
         self.assertAllFused(ge.graph_for(x, y, z))

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_mul_bool(self):
-        def f(x, y, z):
-            return x * y * z
+        for device in self.devices:
+            def f(x, y, z):
+                return x * y * z

-        x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda')
-        y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda')
-        z = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda')
+            x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
+            y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
+            z = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)

-        ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
-        self.assertAllFused(ge.graph_for(x, y, z))
+            ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
+            self.assertAllFused(ge.graph_for(x, y, z))

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_div_bool(self):
-        def f(x, y, z):
-            return (x + y) / z
+        for device in self.devices:
+            def f(x, y, z):
+                return (x + y) / z

-        x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda')
-        y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda')
-        z = torch.ones_like(x, dtype=torch.bool, device='cuda')
+            x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
+            y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
+            z = torch.ones_like(x, dtype=torch.bool, device=device)

-        ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
-        self.assertAllFused(ge.graph_for(x, y, z))
+            ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
+            self.assertAllFused(ge.graph_for(x, y, z))

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @unittest.skipIf(not torch._C._llvm_enabled(), "TODO: bugs in ir eval")
     def test_bitwise_ops(self):
         def apply(fn):
             return lambda x, y, z: fn(fn(x, y), z)
@@ -467,7 +463,7 @@ class TestTEFuser(JitTestCase):
             operator.__or__,
             operator.__xor__
         ]
-        devices = ["cuda"]
+        devices = self.devices
         for dtype, op, device in product(dtypes, binary_ops, devices):
             try:
                 x = self.data_for(dtype, device)
@@ -528,20 +524,20 @@ class TestTEFuser(JitTestCase):
                     " ".join(["Failed:", str(dtype), op.__name__, device])
                 )

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_comparison_eq_ne(self):
-        def f(x, y):
-            mask = (x == 0).type_as(x)
-            z = x * mask + y
-            mask = (x != 0).type_as(x)
-            z = z * mask + y
-            return z
+        for device in self.devices:
+            def f(x, y):
+                mask = (x == 0).type_as(x)
+                z = x * mask + y
+                mask = (x != 0).type_as(x)
+                z = z * mask + y
+                return z

-        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
-        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+            x = torch.randn(4, 4, dtype=torch.float, device=device)
+            y = torch.randn(4, 4, dtype=torch.float, device=device)

-        ge = self.checkTrace(f, (x, y))
-        self.assertAllFused(ge.graph_for(x, y))
+            ge = self.checkTrace(f, (x, y))
+            self.assertAllFused(ge.graph_for(x, y))

     @staticmethod
     def fn_test_comparison_gt_lt(x, y):
@@ -551,47 +547,47 @@ class TestTEFuser(JitTestCase):
         z = z * mask + y
         return z

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    def test_comparison_gt_lt_cuda(self):
-        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
-        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+    def test_comparison_gt_lt(self):
+        for device in self.devices:
+            x = torch.randn(4, 4, dtype=torch.float, device=device)
+            y = torch.randn(4, 4, dtype=torch.float, device=device)

-        ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y))
-        self.assertAllFused(ge.graph_for(x, y))
+            ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y))
+            self.assertAllFused(ge.graph_for(x, y))

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    def test_comparison_ge_le_cuda(self):
-        def f(x, y):
-            mask = (x >= 0).type_as(x)
-            z = x * mask + y
-            mask = (x <= 0).type_as(x)
-            z = z * mask + y
-            return z
+    def test_comparison_ge_le(self):
+        for device in self.devices:
+            def f(x, y):
+                mask = (x >= 0).type_as(x)
+                z = x * mask + y
+                mask = (x <= 0).type_as(x)
+                z = z * mask + y
+                return z

-        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
-        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+            x = torch.randn(4, 4, dtype=torch.float, device=device)
+            y = torch.randn(4, 4, dtype=torch.float, device=device)

-        ge = self.checkTrace(f, (x, y))
-        self.assertAllFused(ge.graph_for(x, y))
-        x.requires_grad_(True)
-        y.requires_grad_(True)
-        self.assertAllFused(ge.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes",
-                                                            "aten::_size_if_not_equal"))
+            ge = self.checkTrace(f, (x, y))
+            self.assertAllFused(ge.graph_for(x, y))
+            x.requires_grad_(True)
+            y.requires_grad_(True)
+            self.assertAllFused(ge.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes",
+                                                                "aten::_size_if_not_equal"))

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    def test_addcmul_cuda(self):
-        t = torch.randn(1, 4, dtype=torch.float, device='cuda')
-        t1 = torch.randn(4, 1, dtype=torch.float, device='cuda')
-        t2 = torch.randn(1, 4, dtype=torch.float, device='cuda')
+    def test_addcmul(self):
+        for device in self.devices:
+            t = torch.randn(1, 4, dtype=torch.float, device=device)
+            t1 = torch.randn(4, 1, dtype=torch.float, device=device)
+            t2 = torch.randn(1, 4, dtype=torch.float, device=device)

-        def foo(t, t1, t2):
-            return t.addcmul(t + 1, t2, value=0.1)
+            def foo(t, t1, t2):
+                return t.addcmul(t + 1, t2, value=0.1)

-        ge = self.checkTrace(foo, (t, t1, t2), allow_unused=True)
-        graph = ge.graph_for(t, t1, t2)
-        fusion_groups = self.findFusionGroups(graph)
-        self.assertEqual(len(fusion_groups), 1)
-        FileCheck().check("aten::add(").check("aten::addcmul(").run(str(fusion_groups[0]))
+            ge = self.checkTrace(foo, (t, t1, t2), allow_unused=True)
+            graph = ge.graph_for(t, t1, t2)
+            fusion_groups = self.findFusionGroups(graph)
+            self.assertEqual(len(fusion_groups), 1)
+            FileCheck().check("aten::add(").check("aten::addcmul(").run(str(fusion_groups[0]))

     # TODO: We leak CUDA memory here because the traced graph holds onto a
     # constant-ified tensor. Since the Python-global CompilationUnit is alive
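Because the converted tests no longer require CUDA, they can be exercised on a CPU-only machine; for example (assuming a pytest-style invocation, which is one of the usual ways to run this test file):

    python -m pytest test/test_jit_fuser_te.py -k "test_broadcast or test_chunk or test_clamp"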
torch/_C/__init__.pyi
@@ -184,6 +184,7 @@ def _jit_can_fuse_on_cpu() -> _bool: ...
 def _jit_can_fuse_on_gpu() -> _bool: ...
 def _jit_texpr_fuser_enabled() -> _bool: ...
 def _jit_nvfuser_enabled() -> _bool: ...
+def _llvm_enabled() -> _bool: ...
 def _jit_override_can_fuse_on_cpu(override: _bool): ...
 def _jit_override_can_fuse_on_gpu(override: _bool): ...
 def _jit_set_texpr_fuser_enabled(enable: _bool): ...
torch/csrc/jit/python/init.cpp
@@ -629,6 +629,13 @@ void initJITBindings(PyObject* module) {
             using namespace torch::jit::tensorexpr;
             getTEMustUseLLVMOnCPU() = use_llvm;
           })
+      .def("_llvm_enabled", []() {
+#ifdef TORCH_ENABLE_LLVM
+        return true;
+#else
+        return false;
+#endif
+      })
       .def(
           "_jit_pass_fuse_tensorexprs",
           [](std::shared_ptr<Graph>& g) { return FuseTensorExprs(g); })
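The new binding reports whether the TensorExpr LLVM codegen was compiled in (TORCH_ENABLE_LLVM), and the stub added to torch/_C/__init__.pyi above keeps type checkers aware of it. In the test diff it gates test_bitwise_ops, which per the TODO in the decorator still has bugs under the plain IR evaluator. A small sketch of that kind of guard follows; only torch._C._llvm_enabled() itself comes from this commit, and the helper name skip_without_llvm is illustrative, not part of the PR.

# Illustrative guard built on the new binding; the decorator actually used in
# the test diff is @unittest.skipIf(not torch._C._llvm_enabled(), "TODO: bugs in ir eval").
import unittest
import torch

def skip_without_llvm(reason="NNC LLVM backend not compiled in"):
    # unittest.skipIf returns a decorator, so this can wrap test methods directly.
    return unittest.skipIf(not torch._C._llvm_enabled(), reason)

class LlvmGatedExample(unittest.TestCase):
    @skip_without_llvm()
    def test_needs_llvm_codegen(self):
        self.assertTrue(torch._C._llvm_enabled())

if __name__ == '__main__':
    unittest.main()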