In this PR, we extract `codegen_unbacked_symbol_defs` of FallbackKernel into a `codegen_unbacked_symbol_defs_for_outputs` method on the wrapper. With it, HOPs can support the case where the subgraph returns a tensor with unbacked symints. This PR only does it for cond; follow-up PRs will cover the others (e.g. while_loop).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147567
Approved by: https://github.com/jansel
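For illustration, a minimal sketch (not part of the test file below, and only an assumption of usage) of the newly supported pattern: both branches of a `torch.cond` return a tensor whose leading size is data-dependent, so the output carries an unbacked symint. It mirrors `test_cond_unbacked_symint_inner` in the file, including the config patches used there.

```python
import torch
import torch._dynamo.config
import torch._inductor.config


def branch(x):
    nz = torch.nonzero(x)  # data-dependent row count
    # output size nz.size(0) is an unbacked symint
    return torch.ones([nz.size(0), 8], device=nz.device).cos()


@torch.compile(backend="inductor", fullgraph=True)
def f(p, x):
    return torch.cond(p, branch, branch, (x,))


# same knobs as the tests below
with torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True):
    with torch._inductor.config.patch(size_asserts=False):
        out = f(torch.tensor(True), torch.randn(4, 4))
```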
# Owner(s): ["module: inductor"]
import itertools
import unittest

import torch
import torch._dynamo.testing
from torch._higher_order_ops.associative_scan import associative_scan
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import (
    decorateIf,
    instantiate_parametrized_tests,
    parametrize,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
from torch.testing._internal.triton_utils import requires_gpu


def _prepend_product_of_values(inputs, possible_values, num_to_prepend=1):
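    """Return copies of `inputs` with `num_to_prepend` scalar tensors prepended,
    one tuple per element of the Cartesian product of `possible_values`.

    E.g. _prepend_product_of_values((a, b), [False, True], 1) returns
    [(tensor(False), a, b), (tensor(True), a, b)].
    """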
    result = []
    device = inputs[0].device
    # iterate over the cartesian product of predicate values
    for values in itertools.product(*([possible_values] * num_to_prepend)):
        prepended = [torch.tensor(v, device=device) for v in values]
        result.append((*prepended, *inputs))
    return result


def prepend_predicates(inputs, num_predicates=1):
    return _prepend_product_of_values(inputs, [False, True], num_predicates)


def prepend_counters(inputs, num_counters=1, counter_values=(0, 1, 5)):
    return _prepend_product_of_values(inputs, counter_values, num_counters)


class CondModels:
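    """Small modules wrapping various torch.cond patterns; used as test subjects by CondTests below."""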
|
|
class Simple(torch.nn.Module):
|
|
def forward(self, p, a, b):
|
|
def true_fn(x, y):
|
|
return x + y
|
|
|
|
def false_fn(x, y):
|
|
return x - y
|
|
|
|
return torch.cond(p, true_fn, false_fn, [a, b])
|
|
|
|
class SimpleWithIntClosure(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.num = 3
|
|
|
|
def forward(self, p, a, b):
|
|
return torch.cond(
|
|
pred=p,
|
|
true_fn=lambda a, b: [a + b + self.num],
|
|
false_fn=lambda a, b: [a - b - self.num],
|
|
operands=(a, b),
|
|
)
|
|
|
|
class Nested(torch.nn.Module):
|
|
def forward(self, p0, p1, p2, a, b, c):
|
|
def true_fn(x0, y0, z0):
|
|
def true_true_fn(x1, y1, z1):
|
|
return (x1 - y1 * z1) * 3.14
|
|
|
|
def true_false_fn(x1, y1, z1):
|
|
def true_false_true_fn(x2, y2, z2):
|
|
return (x2 * y2 * z2) / 2.71
|
|
|
|
def true_false_false_fn(x2, y2, z2):
|
|
return (x2 + y2 + z2) * 1.23
|
|
|
|
return torch.cond(
|
|
p2, true_false_true_fn, true_false_false_fn, [x1, y1, z1]
|
|
)
|
|
|
|
return torch.cond(p1, true_true_fn, true_false_fn, [x0, y0, z0])
|
|
|
|
def false_fn(x0, y0, z0):
|
|
def false_true_fn(x1, y1, z1):
|
|
def false_true_true_fn(x2, y2, z2):
|
|
return (x2 - y2 - z2) + 1.23
|
|
|
|
def false_true_false_fn(x2, y2, z2):
|
|
return (x2 / y2 / z2) - 3.14
|
|
|
|
return torch.cond(
|
|
p2, false_true_true_fn, false_true_false_fn, [x1, y1, z1]
|
|
)
|
|
|
|
def false_false_fn(x1, y1, z1):
|
|
return (x1 - y1 * z1) / 2.71
|
|
|
|
return torch.cond(p1, false_true_fn, false_false_fn, [x0, y0, z0])
|
|
|
|
return torch.cond(p0, true_fn, false_fn, [a, b, c])
|
|
|
|
class Parameters(torch.nn.Module):
|
|
class InnerModel1(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.layer = torch.nn.Linear(20, 30, device=device)
|
|
|
|
def forward(self, x):
|
|
return self.layer(x + 1) * 3.14
|
|
|
|
class InnerModel2(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.layer1 = torch.nn.Linear(20, 10, device=device)
|
|
self.layer2 = torch.nn.Linear(10, 30, device=device)
|
|
|
|
def forward(self, x):
|
|
return self.layer2(self.layer1(x - 2)) * 3.14
|
|
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.true_fn = self.InnerModel1(device)
|
|
self.false_fn = self.InnerModel2(device)
|
|
|
|
def forward(self, p, a):
|
|
return torch.cond(p, self.true_fn, self.false_fn, [a])
|
|
|
|
class ReinterpretView(torch.nn.Module):
|
|
def forward(self, p, a, b):
|
|
def true_fn(x, y):
|
|
z1 = x + y
|
|
z2 = x - y
|
|
return z1[2:], z2[:, 4:].contiguous()
|
|
|
|
def false_fn(x, y):
|
|
z1 = x - y
|
|
z2 = x + y
|
|
return z1[2:], z2[:, 4:].contiguous()
|
|
|
|
return torch.cond(p, true_fn, false_fn, [a[:-1], b[:-1]])
|
|
|
|
class MultipleOutputs(torch.nn.Module):
|
|
def forward(self, p, a, b, c):
|
|
def true_fn(x, y, z):
|
|
return x * y, z / 2.71, (y - x).sum(dim=1)
|
|
|
|
def false_fn(x, y, z):
|
|
return y / x, z * 3.14, (x + y).mean(dim=1)
|
|
|
|
return torch.cond(p, true_fn, false_fn, [a, b, c])
|
|
|
|
class OuterCode(torch.nn.Module):
|
|
def forward(self, p, a, b):
|
|
c = a * b + 3.14
|
|
d = a / b - 2.71
|
|
|
|
def true_fn(x, y):
|
|
return x + y
|
|
|
|
def false_fn(x, y):
|
|
return x - y
|
|
|
|
e = torch.cond(p, true_fn, false_fn, [c, d])
|
|
|
|
return e * e / 1.41
|
|
|
|
class OuterBuffers(torch.nn.Module):
|
|
def forward(self, p, a, b, c):
|
|
d = a * 2
|
|
e = b / 2
|
|
|
|
def true_fn(x):
|
|
return x + d
|
|
|
|
def false_fn(x):
|
|
return x - e
|
|
|
|
return torch.cond(p, true_fn, false_fn, [c])
|
|
|
|
class WithNonTensorPredicate(torch.nn.Module):
|
|
def forward(self, a, b):
|
|
def true_fn(x, y):
|
|
return x.sum(0) / 3.14
|
|
|
|
def false_fn(x, y):
|
|
return y.sum(0) * 2.71
|
|
|
|
return torch.cond(a.size(0) > b.size(0), true_fn, false_fn, [a, b])
|
|
|
|
class UnbackedSymIntClosure(torch.nn.Module):
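        # `a` is an input size, while `b` comes from .item() and is therefore an
        # unbacked symint; both are captured by the branches as closures.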
|
|
def forward(self, p, x, y, z):
|
|
a = y.shape[0]
|
|
b = z.sum().to(torch.int64).item()
|
|
|
|
def true_fn(x):
|
|
return x + a
|
|
|
|
def false_fn(x):
|
|
return x + b * z
|
|
|
|
return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
|
|
|
|
class MismatchedOutputSize(torch.nn.Module):
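        # true_fn and false_fn return slices of different lengths ([2:] vs. [:2]),
        # so the size of the cond output is data-dependent.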
|
|
def forward(self, p, x, y, z):
|
|
a = y.shape[0]
|
|
b = z.shape[0]
|
|
|
|
def true_fn(x):
|
|
return (x + a)[2:].sin()
|
|
|
|
def false_fn(x):
|
|
return (x + b * z)[:2].cos()
|
|
|
|
return y.sum() - torch.cond(x.sum() > 0, true_fn, false_fn, (x,))
|
|
|
|
|
|
class CondTests(TestCase):
|
|
def _run_test(
|
|
self,
|
|
model,
|
|
inputs,
|
|
device,
|
|
dynamic=False,
|
|
num_predicates=1,
|
|
):
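        """Compile `model` with the inductor backend (fullgraph=True), then run it
        eager vs. compiled for every combination of prepended predicate values.
        When `dynamic` is True, also run a 5x-tiled second input set and mark the
        first dim of every input as dynamic. Checks that inputs are not mutated,
        that results match, and that only a single compilation happened."""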
|
|
cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
|
|
compiled_model = torch.compile(backend=cnt, fullgraph=True)(model)
|
|
|
|
inputs = [inp.to(device=device) for inp in inputs]
|
|
input_sets = [inputs]
|
|
if dynamic:
|
|
larger_inputs = []
|
|
for inp in inputs:
|
|
# tile every first dim 5x
|
|
tiling = [5] + [1] * (inp.ndim - 1)
|
|
larger_inputs.append(torch.tile(inp, tiling))
|
|
input_sets.append(larger_inputs)
|
|
for inputs in input_sets:
|
|
for inp in inputs:
|
|
# mark every first dim as dynamic
|
|
torch._dynamo.mark_dynamic(inp, 0)
|
|
|
|
for inputs in input_sets:
|
|
for inputs_with_predicates in prepend_predicates(inputs, num_predicates):
|
|
cloned_inputs = [inp.clone() for inp in inputs_with_predicates]
|
|
result = model(*inputs_with_predicates)
|
|
result_compiled = compiled_model(*inputs_with_predicates)
|
|
# inputs must not be mutated
|
|
torch.testing.assert_close(cloned_inputs, inputs_with_predicates)
|
|
torch.testing.assert_close(result, result_compiled)
|
|
|
|
self.assertEqual(cnt.frame_count, 1, "only one compilation expected")
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [False, True])
|
|
def test_cond_simple_control_flow(self, device, dynamic):
|
|
# cond control flow without nesting
|
|
self._run_test(
|
|
model=CondModels.Simple(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
def test_cond_simple_with_int_closure(self, device):
|
|
self._run_test(
|
|
model=torch.compile(CondModels.SimpleWithIntClosure(), dynamic=True),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=device,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [False, True])
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
def test_cond_unbacked_symint_closure(self, device, dynamic):
|
|
self._run_test(
|
|
model=CondModels.UnbackedSymIntClosure(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
@requires_gpu
|
|
def test_cond_control_flow_with_precomputed_size(self):
|
|
class TestModel(torch.nn.Module):
|
|
def __init__(
|
|
self,
|
|
):
|
|
super().__init__()
|
|
self.conv2d = torch.nn.Conv2d(
|
|
512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
|
|
)
|
|
self.threshold = 20
|
|
|
|
def forward(self, x: torch.Tensor, index) -> torch.Tensor:
|
|
def true_fn(x: torch.Tensor):
|
|
return self.conv2d(x)
|
|
|
|
def false_fn(x: torch.Tensor):
|
|
return self.conv2d(x)
|
|
|
|
return torch.cond(
|
|
index < self.threshold and index >= 0, true_fn, false_fn, (x,)
|
|
)
|
|
|
|
main_model = TestModel().to(GPU_TYPE)
|
|
x1 = torch.rand(2, 512, 128, 72).to(GPU_TYPE)
|
|
x2 = torch.rand(2, 512, 96, 96).to(GPU_TYPE)
|
|
|
|
opt_model = torch.compile(main_model)
|
|
out1 = main_model(x1, 1)
|
|
opt_out1 = opt_model(x1, 1)
|
|
self.assertTrue(torch.allclose(out1, opt_out1, atol=1e-5))
|
|
|
|
out2 = main_model(x2, 30)
|
|
opt_out2 = opt_model(x2, 30)
|
|
self.assertTrue(torch.allclose(out2, opt_out2, atol=1e-5))
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [False, True])
|
|
def test_cond_nested_control_flow(self, device, dynamic):
|
|
# cond control flow with nesting
|
|
self._run_test(
|
|
model=CondModels.Nested(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
num_predicates=3,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [False, True])
|
|
def test_cond_outer_code_before_after(self, device, dynamic):
|
|
# some code before and after the conditional
|
|
self._run_test(
|
|
model=CondModels.OuterCode(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [False, True])
|
|
def test_cond_multiple_outputs(self, device, dynamic):
|
|
# multiple outputs with different shapes
|
|
self._run_test(
|
|
model=CondModels.MultipleOutputs(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
torch.randn(30, 40),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
def test_cond_advanced_dynamic_shapes(self, device):
|
|
# subgraphs input shapes include symbolic expressions
|
|
class Model(torch.nn.Module):
|
|
def forward(self, p, a, b):
|
|
def true_fn(x, y):
|
|
return torch.cat([x - 3, y * 3], dim=1)
|
|
|
|
def false_fn(x, y):
|
|
return torch.cat([x / 3, y - 3], dim=1)
|
|
|
|
c = torch.cat([a, b], dim=0)
|
|
d = c * 2
|
|
e = c / 2
|
|
|
|
return torch.cond(p, true_fn, false_fn, [d, e])
|
|
|
|
self._run_test(
|
|
model=Model(),
|
|
inputs=(
|
|
torch.randn(2, 3, 3),
|
|
torch.randn(4, 3, 3),
|
|
),
|
|
device=device,
|
|
dynamic=True,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
def test_cond_unbacked_symint_outer_to_inner(self, device):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, p, a):
|
|
def true_fn(x):
|
|
return torch.cos(x)
|
|
|
|
def false_fn(x):
|
|
return torch.sin(x)
|
|
|
|
nz = torch.nonzero(a)
|
|
b = torch.ones([nz.size(0), 8], device=nz.device)
|
|
|
|
return torch.cond(p, true_fn, false_fn, [b])
|
|
|
|
with torch._dynamo.config.patch(
|
|
{
|
|
"capture_dynamic_output_shape_ops": True,
|
|
}
|
|
):
|
|
self._run_test(
|
|
model=Model(),
|
|
inputs=(torch.randn(2, 3, 3),),
|
|
device=device,
|
|
dynamic=True,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@torch._inductor.config.patch(size_asserts=False)
|
|
def test_cond_unbacked_symint_inner(self, device):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, p, a):
|
|
def true_fn(x):
|
|
nz = torch.nonzero(x)
|
|
b = torch.ones([nz.size(0), 8], device=nz.device)
|
|
return torch.cos(b)
|
|
|
|
def false_fn(x):
|
|
nz = torch.nonzero(x)
|
|
b = torch.ones([nz.size(0), 8], device=nz.device)
|
|
return torch.sin(b)
|
|
|
|
b = torch.sin(a)
|
|
|
|
return torch.cond(p, true_fn, false_fn, [b])
|
|
|
|
with torch._dynamo.config.patch(
|
|
{
|
|
"capture_dynamic_output_shape_ops": True,
|
|
}
|
|
):
|
|
self._run_test(
|
|
model=Model(),
|
|
inputs=(torch.randn(2, 3, 3),),
|
|
device=device,
|
|
dynamic=True,
|
|
)
|
|
|
|
@unittest.skip("unbacked symints from inner to outer graph not supported yet")
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
def test_cond_unbacked_symint_inner_to_outer(self, device):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, p, a):
|
|
def true_fn(x):
|
|
nz = torch.nonzero(x)
|
|
b = torch.ones([nz.size(0), 8], device=nz.device)
|
|
return torch.cos(b)
|
|
|
|
def false_fn(x):
|
|
nz = torch.nonzero(x)
|
|
b = torch.ones([nz.size(0), 8], device=nz.device)
|
|
return torch.sin(b)
|
|
|
|
b = torch.sin(a)
|
|
|
|
y = torch.cond(p, true_fn, false_fn, [b])
|
|
return torch.sin(y)
|
|
|
|
with torch._dynamo.config.patch(
|
|
{
|
|
"capture_dynamic_output_shape_ops": True,
|
|
}
|
|
):
|
|
self._run_test(
|
|
model=Model(),
|
|
inputs=(torch.randn(2, 3, 3),),
|
|
device=device,
|
|
dynamic=True,
|
|
)
|
|
|
|
@requires_gpu
|
|
def test_cond_use_buffers_from_outer_scope(self):
|
|
        # buffers from the outer scope are used inside the subgraphs
|
|
self._run_test(
|
|
model=CondModels.OuterBuffers(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=GPU_TYPE,
|
|
dynamic=False,
|
|
)
|
|
|
|
@requires_gpu
|
|
def test_cond_reintepret_view_inputs_outputs(self):
|
|
# ReinterpretView in inputs and outputs of the subgraphs
|
|
self._run_test(
|
|
model=CondModels.ReinterpretView(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=GPU_TYPE,
|
|
dynamic=True,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [False, True])
|
|
def test_cond_subgraphs_with_parameters(self, device, dynamic):
|
|
# nested Modules with parameters
|
|
self._run_test(
|
|
model=CondModels.Parameters(device),
|
|
inputs=(torch.randn(10, 20),),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [False, True])
|
|
def test_cond_non_tensor_predicates(self, device, dynamic):
|
|
        # model with a non-tensor predicate computed from input shapes
|
|
for b_size_0 in [5, 15]:
|
|
torch._dynamo.reset()
|
|
self._run_test(
|
|
model=CondModels.WithNonTensorPredicate(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(b_size_0, 20),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
num_predicates=0,
|
|
)
|
|
|
|
@requires_gpu
|
|
def test_cond_aliasing_outputs(self):
|
|
# output aliasing in subgraphs: not supported
|
|
class Model(torch.nn.Module):
|
|
def forward(self, p, a, b):
|
|
def true_fn(x, y):
|
|
z = x + y
|
|
return z, z[1:]
|
|
|
|
def false_fn(x, y):
|
|
z = x - y
|
|
return z, z[1:]
|
|
|
|
return torch.cond(p, true_fn, false_fn, [a, b])
|
|
|
|
# AssertionError: Output aliasing is currently not supported...
|
|
with self.assertRaises(torch._dynamo.exc.BackendCompilerFailed):
|
|
torch.compile(Model())(
|
|
torch.tensor(True),
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
def test_cond_decompose_ops_in_subgraph(self, device):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, p, a):
|
|
def true_fn(x):
|
|
return torch.zeros_like(x)
|
|
|
|
def false_fn(x):
|
|
return torch.ones_like(x)
|
|
|
|
b = torch.ones_like(a)
|
|
c = torch.cond(p, true_fn, false_fn, [b])
|
|
return c
|
|
|
|
self._run_test(
|
|
model=Model(),
|
|
inputs=(torch.rand(10, 20),),
|
|
device=device,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
def test_cond_decompose_ops_in_subgraph_recursive(self, device):
|
|
def inner_fn1(x):
|
|
return torch.zeros_like(x)
|
|
|
|
def inner_fn2(x):
|
|
return torch.ones_like(x)
|
|
|
|
class Model(torch.nn.Module):
|
|
def forward(self, p, a):
|
|
def true_fn(x):
|
|
return torch.cond(p, inner_fn2, inner_fn1, [x])
|
|
|
|
def false_fn(x):
|
|
return torch.cond(p, inner_fn1, inner_fn2, [x])
|
|
|
|
b = torch.ones_like(a)
|
|
c = torch.cond(p, true_fn, false_fn, [b])
|
|
return c
|
|
|
|
self._run_test(
|
|
model=Model(),
|
|
inputs=(torch.rand(10, 20),),
|
|
device=device,
|
|
)
|
|
|
|
@requires_gpu
|
|
def test_cond_inductor_fx_passes_recursively_applied(self):
|
|
counters = {"pre_grad": 0, "post_grad": 0}
|
|
|
|
def pre_grad_pass_counter(gm):
|
|
counters["pre_grad"] += 1
|
|
|
|
def post_grad_pass_counter(gm):
|
|
counters["post_grad"] += 1
|
|
|
|
with torch._inductor.config.patch(
|
|
{
|
|
"pre_grad_custom_pass": pre_grad_pass_counter,
|
|
"post_grad_custom_pre_pass": post_grad_pass_counter,
|
|
# The above patches don't pickle
|
|
"fx_graph_cache": False,
|
|
}
|
|
):
|
|
self._run_test(
|
|
model=CondModels.Nested(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=GPU_TYPE,
|
|
dynamic=True,
|
|
num_predicates=3,
|
|
)
|
|
|
|
self.assertEqual(counters["pre_grad"], 11)
|
|
self.assertEqual(counters["post_grad"], 11)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [True, False])
|
|
def test_cond_mismatched_branch_output_size(self, device, dynamic):
|
|
self._run_test(
|
|
model=CondModels.MismatchedOutputSize(),
|
|
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
                torch.randn(10, 20),
            ),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
|
|
class WhileLoopModels:
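    """Small modules wrapping various torch._higher_order_ops.while_loop patterns; used as test subjects by WhileLoopTests below."""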
|
|
class Simple(torch.nn.Module):
|
|
def forward(self, ci, a, b):
|
|
def cond_fn(i, x, y):
|
|
return i > 0
|
|
|
|
def body_fn(i, x, y):
|
|
return i - 1, x + y, y - x
|
|
|
|
return torch._higher_order_ops.while_loop(cond_fn, body_fn, [ci, a, b])
|
|
|
|
class Nested(torch.nn.Module):
|
|
def forward(self, ci, cj, a, b):
|
|
def cond_fn(i1, j1, x1, y1):
|
|
return i1 > 0
|
|
|
|
def body_fn(i1, j1, x1, y1):
|
|
def cond_fn_nested(i2, j2, x2, y2):
|
|
return j2 > 0
|
|
|
|
def body_fn_nested(i2, j2, x2, y2):
|
|
return i2.clone(), j2 - 1, x2 + 3.14, y2 - 2.71
|
|
|
|
i1, j1, x1, y1 = torch._higher_order_ops.while_loop(
|
|
cond_fn_nested, body_fn_nested, [i1, j1, x1, y1]
|
|
)
|
|
|
|
return i1 - 1, j1.clone(), x1 * 2, y1 / 2
|
|
|
|
return torch._higher_order_ops.while_loop(cond_fn, body_fn, (ci, cj, a, b))
|
|
|
|
class Parameters(torch.nn.Module):
|
|
class InnerModel(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.layer1 = torch.nn.Linear(20, 30, device=device)
|
|
self.layer2 = torch.nn.Linear(30, 20, device=device)
|
|
|
|
def forward(self, c, x):
|
|
return c - 1, self.layer2(self.layer1(x - 2)) * 3.14
|
|
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.body_fn = self.InnerModel(device)
|
|
self.cond_fn = lambda c, x: c > 0
|
|
|
|
def forward(self, c, a):
|
|
return torch._higher_order_ops.while_loop(
|
|
self.cond_fn, self.body_fn, [c, a]
|
|
)
|
|
|
|
class OuterCode(torch.nn.Module):
|
|
def forward(self, c, a, b):
|
|
d = a * b + 3.14
|
|
e = a / b - 2.71
|
|
|
|
def cond_fn(c, x, y):
|
|
return c > 0
|
|
|
|
def body_fn(c, x, y):
|
|
return c - 1, y - x, x + y
|
|
|
|
_, f, g = torch._higher_order_ops.while_loop(cond_fn, body_fn, [c, d, e])
|
|
|
|
return f * g / 1.41
|
|
|
|
# TODO(aakhundov): add while_loop test with outer buffers
|
|
# with dynamic=True once dynamo / export allows while_loop
|
|
# closure capture with mark_dynamic:
|
|
# https://github.com/pytorch/pytorch/issues/123596
|
|
class OuterBuffers(torch.nn.Module):
|
|
def forward(self, c, a, b):
|
|
d = a * 2
|
|
e = b / 2
|
|
|
|
def cond_fn(c, x, y):
|
|
return c > 0
|
|
|
|
def body_fn(c, x, y):
|
|
return c - 1, x + d, y - e
|
|
|
|
return torch._higher_order_ops.while_loop(cond_fn, body_fn, [c, a, b])
|
|
|
|
class PytreeCarry(torch.nn.Module):
|
|
def forward(self, it, pytree_input):
|
|
def cond_fn(it, pytree_input):
|
|
return it > 0
|
|
|
|
def body_fn(it, pytree_input):
|
|
x = pytree_input[0][0]
|
|
y = pytree_input[1]["x"]
|
|
z = pytree_input[1]["y"]
|
|
new_x = y.sin()
|
|
new_y = z.cos()
|
|
new_z = x + 1
|
|
return it - 1, ([new_x], {"x": new_y, "y": new_z})
|
|
|
|
return torch._higher_order_ops.while_loop(
|
|
cond_fn, body_fn, (it, pytree_input)
|
|
)
|
|
|
|
class DataDependentOpInSubgraph(torch.nn.Module):
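        # masked_select inside body_fn produces a tensor with a data-dependent
        # (unbacked) size, while the carry itself stays a fixed [1]-shaped tensor.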
|
|
def forward(self, c, a, b):
|
|
def cond_fn(c, reduced_carry):
|
|
return c > 0
|
|
|
|
def body_fn(c, reduced_carry):
|
|
k = torch.masked_select(a, b)
|
|
d = torch.concat([k, k * 2])
|
|
return c - 1, torch.min(d).unsqueeze(0) + reduced_carry
|
|
|
|
return torch._higher_order_ops.while_loop(
|
|
cond_fn,
|
|
body_fn,
|
|
[c, torch.zeros([1], dtype=torch.int64, device=c.device)],
|
|
)
|
|
|
|
class DataDependentInOut(torch.nn.Module):
|
|
def forward(self, c, a, b):
|
|
inp = torch.zeros(
|
|
a.sum().to(torch.int64).item(), 3, device=a.device, dtype=torch.int64
|
|
)
|
|
|
|
def cond_fn(c, inp):
|
|
return c > 0
|
|
|
|
def body_fn(c, inp):
|
|
return c - 1, (inp.sin() + 1).to(torch.int64)
|
|
|
|
return torch._higher_order_ops.while_loop(
|
|
cond_fn,
|
|
body_fn,
|
|
[c, inp],
|
|
)
|
|
|
|
class DataDependentInOutMismatch(torch.nn.Module):
|
|
def forward(self, c, a, b):
|
|
def cond_fn(c, a, b):
|
|
return c > 0
|
|
|
|
def body_fn(c, a, b):
|
|
return c - 1, a.nonzero(), b.nonzero()
|
|
|
|
return torch._higher_order_ops.while_loop(
|
|
cond_fn,
|
|
body_fn,
|
|
[c, a, b],
|
|
)
|
|
|
|
class InfiniteLoop(torch.nn.Module):
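        # cond_fn reduces to a constant True (a_view.size(-1) is 1), so this would
        # loop forever; the test expects an UncapturedHigherOrderOpError.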
|
|
def forward(self, c, a):
|
|
a_view = a.view(-1, 1)
|
|
|
|
def cond_fn(c, a_view):
|
|
return a_view.size(-1) > 0
|
|
|
|
def body_fn(c, a_view):
|
|
return c - 1, a_view + 1
|
|
|
|
return torch._higher_order_ops.while_loop(
|
|
cond_fn,
|
|
body_fn,
|
|
[c, a_view],
|
|
)
|
|
|
|
class ZeroLoop(torch.nn.Module):
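        # cond_fn is False on entry (a_view.size(-1) == 0 never holds), so the body
        # never runs; ZeroLoop2/ZeroLoop3 below use constant False / 0 conditions.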
|
|
def forward(self, c, a):
|
|
a_view = torch.sin(a.view(-1, 1))
|
|
|
|
def cond_fn(c, a_view):
|
|
return a_view.size(-1) == 0
|
|
|
|
def body_fn(c, a_view):
|
|
return c - 1, a_view + 1
|
|
|
|
out1, out2 = torch._higher_order_ops.while_loop(
|
|
cond_fn,
|
|
body_fn,
|
|
[c, a_view],
|
|
)
|
|
return out1 + 1, out2 + 2
|
|
|
|
class ZeroLoop2(torch.nn.Module):
|
|
def forward(self, c, a):
|
|
a_view = torch.sin(a.view(-1, 1))
|
|
|
|
def cond_fn(c, a_view):
|
|
return False
|
|
|
|
def body_fn(c, a_view):
|
|
return c - 1, a_view + 1
|
|
|
|
out1, out2 = torch._higher_order_ops.while_loop(
|
|
cond_fn,
|
|
body_fn,
|
|
[c, a_view],
|
|
)
|
|
return out1 + 1, out2 + 2
|
|
|
|
class ZeroLoop3(torch.nn.Module):
|
|
def forward(self, c, a):
|
|
a_view = torch.sin(a.view(-1, 1))
|
|
|
|
def cond_fn(c, a_view):
|
|
return 0
|
|
|
|
def body_fn(c, a_view):
|
|
return c - 1, a_view + 1
|
|
|
|
out1, out2 = torch._higher_order_ops.while_loop(
|
|
cond_fn,
|
|
body_fn,
|
|
[c, a_view],
|
|
)
|
|
return out1 + 1, out2 + 2
|
|
|
|
class UnbackedSymIntClosure(torch.nn.Module):
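        # d (from .item()) and e (a nonzero count) are unbacked symints captured
        # by both cond_fn and body_fn as closures.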
|
|
def forward(self, c, a, b):
|
|
d = a.sum().to(torch.int64).item()
|
|
e = torch.nonzero(b).size(0)
|
|
|
|
def cond_fn(c, a, b):
|
|
return c > d + e + a.shape[0] - b.shape[0]
|
|
|
|
def body_fn(c, a, b):
|
|
return c - 1, a + e, b + d
|
|
|
|
return torch._higher_order_ops.while_loop(
|
|
cond_fn,
|
|
body_fn,
|
|
[c, a, b],
|
|
)
|
|
|
|
class SymExprCond(torch.nn.Module):
|
|
def forward(self, c, a, b):
|
|
d = a.sum().to(torch.int64).item()
|
|
e = torch.nonzero(b).size(0)
|
|
|
|
def cond_fn(c, a, b):
|
|
return d + e + a.shape[0] - b.shape[0] < 10
|
|
|
|
def body_fn(c, a, b):
|
|
return c + 1, a + e, b + d
|
|
|
|
return torch._higher_order_ops.while_loop(
|
|
cond_fn,
|
|
body_fn,
|
|
[c, a, b],
|
|
)
|
|
|
|
|
|
class WhileLoopTests(TestCase):
|
|
def _run_test(
|
|
self,
|
|
model,
|
|
inputs,
|
|
device,
|
|
dynamic=False,
|
|
num_counters=1,
|
|
):
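        """Same protocol as CondTests._run_test, but inputs may be arbitrary pytrees
        (flattened and unflattened around the prepended loop counters), the compiled
        run happens under torch.no_grad(), and results are compared with 1e-4
        tolerances."""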
|
|
import torch.utils._pytree as pytree
|
|
|
|
cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
|
|
compiled_model = torch.compile(backend=cnt, fullgraph=True)(model)
|
|
|
|
inputs = pytree.tree_map(lambda t: t.to(device=device), inputs)
|
|
input_sets = [inputs]
|
|
if dynamic:
|
|
|
|
def mark_first_dim_dyn(inp):
|
|
torch._dynamo.mark_dynamic(inp, 0)
|
|
|
|
pytree.tree_map(mark_first_dim_dyn, input_sets)
|
|
|
|
def tile_fn(inp):
|
|
# tile every first dim 5x
|
|
tiling = [5] + [1] * (inp.ndim - 1)
|
|
t = torch.tile(inp, tiling)
|
|
# mark every first dim as dynamic
|
|
torch._dynamo.mark_dynamic(inp, 0)
|
|
return t
|
|
|
|
larger_inputs = pytree.tree_map(tile_fn, inputs)
|
|
input_sets.append(larger_inputs)
|
|
|
|
for inputs in input_sets:
|
|
flat_inputs, inp_spec = pytree.tree_flatten(inputs)
|
|
for flat_inputs_with_counters in prepend_counters(
|
|
flat_inputs, num_counters
|
|
):
|
|
counters, flat = (
|
|
flat_inputs_with_counters[:num_counters],
|
|
flat_inputs_with_counters[num_counters:],
|
|
)
|
|
unflat_inputs = pytree.tree_unflatten(flat, inp_spec)
|
|
inputs_with_counters = counters + unflat_inputs
|
|
cloned_inputs = pytree.tree_map(
|
|
lambda t: t.clone(), inputs_with_counters
|
|
)
|
|
result = model(*inputs_with_counters)
|
|
with torch.no_grad():
|
|
result_compiled = compiled_model(*inputs_with_counters)
|
|
# inputs must not be mutated
|
|
torch.testing.assert_close(cloned_inputs, inputs_with_counters)
|
|
torch.testing.assert_close(
|
|
result, result_compiled, atol=1e-4, rtol=1e-4
|
|
)
|
|
|
|
self.assertEqual(cnt.frame_count, 1, "only one compilation expected")
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [False, True])
|
|
def test_while_loop_simple_control_flow(self, device, dynamic):
|
|
# while_loop control flow without nesting
|
|
self._run_test(
|
|
model=WhileLoopModels.Simple(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [False, True])
|
|
def test_while_loop_nested_control_flow(self, device, dynamic):
|
|
# while_loop control flow with nesting
|
|
self._run_test(
|
|
model=WhileLoopModels.Nested(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
num_counters=2,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [False, True])
|
|
def test_while_loop_with_outer_code(self, device, dynamic):
|
|
# while_loop control flow with outer code
|
|
self._run_test(
|
|
model=WhileLoopModels.OuterCode(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [False, True])
|
|
def test_while_loop_with_parameters(self, device, dynamic):
|
|
# while_loop control flow with parameters
|
|
self._run_test(
|
|
model=WhileLoopModels.Parameters(device),
|
|
inputs=(torch.randn(10, 20),),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
# dynamic=True doesn't work now due to
|
|
# https://github.com/pytorch/pytorch/issues/123596
|
|
@parametrize("dynamic", [False])
|
|
def test_while_loop_with_outer_buffers(self, device, dynamic):
|
|
        # while_loop control flow with buffers from the outer scope
|
|
self._run_test(
|
|
model=WhileLoopModels.OuterBuffers(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
    # dynamic=True doesn't work because we haven't handled lifted symbols yet
|
|
@parametrize("dynamic", [True, False])
|
|
def test_while_loop_with_pytree_inputs(self, device, dynamic):
|
|
self._run_test(
|
|
model=WhileLoopModels.PytreeCarry(),
|
|
inputs=(
|
|
(
|
|
[torch.randn(10, 20)],
|
|
{"x": torch.randn(10, 20), "y": torch.randn(10, 20)},
|
|
),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [True, False])
|
|
def test_while_loop_with_data_dependent_ops(self, device, dynamic):
|
|
with torch._dynamo.config.patch(
|
|
{
|
|
"capture_dynamic_output_shape_ops": True,
|
|
}
|
|
):
|
|
self._run_test(
|
|
model=WhileLoopModels.DataDependentOpInSubgraph(),
|
|
inputs=(
|
|
torch.tensor([1, 2, 3, 4, 5]),
|
|
torch.tensor(
|
|
[True, True, True, True, True],
|
|
),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [True, False])
|
|
def test_while_loop_with_data_dependent_in_out(self, device, dynamic):
|
|
with torch._dynamo.config.patch(
|
|
{
|
|
"capture_dynamic_output_shape_ops": True,
|
|
"capture_scalar_outputs": True,
|
|
}
|
|
):
|
|
self._run_test(
|
|
model=WhileLoopModels.DataDependentInOut(),
|
|
inputs=(
|
|
torch.tensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]),
|
|
torch.tensor(
|
|
[True, True, True, True, True],
|
|
),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
@parametrize("dynamic", [True, False])
|
|
def test_while_loop_with_data_dependent_in_out_mismatch(self, dynamic):
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Expected body_fn_output and carried_inputs to have same metadata but found",
|
|
):
|
|
with torch._dynamo.config.patch(
|
|
{
|
|
"capture_dynamic_output_shape_ops": True,
|
|
}
|
|
):
|
|
self._run_test(
|
|
model=WhileLoopModels.DataDependentInOutMismatch(),
|
|
inputs=(
|
|
torch.tensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]),
|
|
torch.tensor(
|
|
[True, True, True, True, True],
|
|
),
|
|
),
|
|
device="cpu",
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
def test_while_loop_infinite_loop_error(self):
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"while_loop doesn't work unless it is captured completely",
|
|
):
|
|
self._run_test(
|
|
model=WhileLoopModels.InfiniteLoop(),
|
|
inputs=(torch.tensor([1, 2, 3, 4, 5]),),
|
|
device="cpu",
|
|
dynamic=False,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [True, False])
|
|
def test_while_loop_zero_loop(self, device, dynamic):
|
|
for model in [
|
|
WhileLoopModels.ZeroLoop(),
|
|
WhileLoopModels.ZeroLoop2(),
|
|
WhileLoopModels.ZeroLoop3(),
|
|
]:
|
|
self._run_test(
|
|
model=model,
|
|
inputs=(torch.tensor([1, 2, 3, 4, 5]),),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [True, False])
|
|
@torch._dynamo.config.patch(
|
|
{"capture_scalar_outputs": True, "capture_dynamic_output_shape_ops": True}
|
|
)
|
|
def test_while_loop_with_unbacked_symint_closure(self, device, dynamic):
|
|
self._run_test(
|
|
model=WhileLoopModels.UnbackedSymIntClosure(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [True, False])
|
|
@torch._dynamo.config.patch(
|
|
{"capture_scalar_outputs": True, "capture_dynamic_output_shape_ops": True}
|
|
)
|
|
def test_while_loop_with_sym_expr_cond(self, device, dynamic):
|
|
self._run_test(
|
|
model=WhileLoopModels.SymExprCond(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
|
|
class AssociativeScanTests(TestCase):
|
|
@requires_gpu
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("backend", ["inductor"])
|
|
@parametrize("device", [torch.device("cpu"), GPU_TYPE])
|
|
    # This test will fail because flip, in combination with particular input
    # lengths, produces incorrect results.
    # This is under investigation in
    # https://github.com/pytorch/pytorch/issues/131805
|
|
@decorateIf(unittest.skip, lambda params: params["device"] == GPU_TYPE)
|
|
def test_associative_scan_CUDA_flip(self, combine_mode, backend, device):
|
|
def fct(x: torch.Tensor, y: torch.Tensor):
|
|
return x + y
|
|
|
|
# for n in range(10):
|
|
for n in [9]:
|
|
x = torch.arange(n, device=device)
|
|
torch.compiler.reset()
|
|
associative_scan1 = torch.compile(
|
|
associative_scan, backend=backend, fullgraph=True
|
|
)
|
|
associative_scan2 = associative_scan
|
|
|
|
if combine_mode == "pointwise" and device == torch.device("cpu"):
|
|
with self.assertRaisesRegex(Exception, r"."):
|
|
associative_scan1(
|
|
fct, x, 0, reverse=False, combine_mode=combine_mode
|
|
)
|
|
|
|
                # Skipping test because combine_mode currently only supports CUDA tensors
|
|
return
|
|
|
|
result1 = associative_scan1(
|
|
fct, x, 0, reverse=False, combine_mode=combine_mode
|
|
)
|
|
result2 = associative_scan2(
|
|
fct, x, 0, reverse=False, combine_mode=combine_mode
|
|
)
|
|
result3 = torch.cumsum(x, 0)
|
|
|
|
self.assertEqual(result1, result2)
|
|
self.assertEqual(result1, result3)
|
|
|
|
# Flip only non-compiled and compare with compiled reverse=True
|
|
result1 = associative_scan1(
|
|
fct, x, 0, reverse=True, combine_mode=combine_mode
|
|
)
|
|
result2 = torch.flip(
|
|
associative_scan2(
|
|
fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode
|
|
),
|
|
[0],
|
|
)
|
|
result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0])
|
|
|
|
self.assertEqual(result1, result2)
|
|
self.assertEqual(result1, result3)
|
|
|
|
# Flip only compiled and compare with non-compiled reverse=True
|
|
result1 = torch.flip(
|
|
associative_scan1(
|
|
fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode
|
|
),
|
|
[0],
|
|
)
|
|
result2 = associative_scan2(
|
|
fct, x, 0, reverse=True, combine_mode=combine_mode
|
|
)
|
|
result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0])
|
|
|
|
self.assertEqual(result1, result2)
|
|
self.assertEqual(result1, result3)
|
|
|
|
# Use reverse=False, but flip both results before and after
|
|
result1 = torch.flip(
|
|
associative_scan1(
|
|
fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode
|
|
),
|
|
[0],
|
|
)
|
|
result2 = torch.flip(
|
|
associative_scan2(
|
|
fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode
|
|
),
|
|
[0],
|
|
)
|
|
result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0])
|
|
|
|
self.assertEqual(result1, result2)
|
|
self.assertEqual(result1, result3)
|
|
|
|
# Reverse=True
|
|
result1 = associative_scan1(
|
|
fct, x, 0, reverse=True, combine_mode=combine_mode
|
|
)
|
|
result2 = associative_scan2(
|
|
fct, x, 0, reverse=True, combine_mode=combine_mode
|
|
)
|
|
result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0])
|
|
|
|
self.assertEqual(result1, result2)
|
|
self.assertEqual(result1, result3)
|
|
|
|
|
|
instantiate_parametrized_tests(CondTests)
|
|
instantiate_parametrized_tests(WhileLoopTests)
|
|
instantiate_parametrized_tests(AssociativeScanTests)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
from torch._inductor.test_case import run_tests
|
|
|
|
if HAS_CPU or HAS_GPU:
|
|
run_tests(needs="filelock")
|