Revert "[hop][inductor] track the dependency on unbacked symbols correctly with constant_args for hops (#143456)"

This reverts commit 68a3635484.

Reverted https://github.com/pytorch/pytorch/pull/143456 on behalf of https://github.com/atalman due to New tests are failing internally ([comment](https://github.com/pytorch/pytorch/pull/143456#issuecomment-2631475900))
PyTorch MergeBot 2025-02-03 16:25:58 +00:00
parent 01554c7b5a
commit c0979d72b5
3 changed files with 26 additions and 141 deletions
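For context: the reverted PR made Inductor's higher-order-op IR nodes (Conditional, WhileLoop) track unbacked symbols captured by branch closures via constant_args instead of tensor inputs, and added the tests below that exercise that pattern. A minimal standalone sketch of the pattern those tests cover, mirroring the CondModels.UnbackedSymIntClosure model removed below (the compile call at the end is illustrative and not part of the reverted tests):

import torch


class UnbackedSymIntClosure(torch.nn.Module):
    # One branch closes over a backed SymInt (a shape size), the other over an
    # unbacked SymInt produced by .item(); torch.cond has to carry both into
    # the branch subgraphs when this is compiled.
    def forward(self, p, x, y, z):
        a = y.shape[0]                      # backed symbol (a size)
        b = z.sum().to(torch.int64).item()  # unbacked symbol (data-dependent)

        def true_fn(x):
            return x + a

        def false_fn(x):
            return x + b * z

        return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))


# Illustrative usage only (not from the reverted tests):
# capture_scalar_outputs lets Dynamo trace .item() as an unbacked SymInt.
torch._dynamo.config.capture_scalar_outputs = True
m = torch.compile(UnbackedSymIntClosure(), fullgraph=True)
out = m(
    torch.tensor(True),
    torch.randn(10, 20),
    torch.randn(15, 20),
    torch.randn(10, 20),
)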


@@ -1305,29 +1305,6 @@ class AOTInductorTestsTemplate:
             dynamic_shapes=dynamic_shapes,
         )
 
-    @common_utils.parametrize("dynamic", [False, True])
-    def test_cond_unbacked_symint_closure(self, dynamic):
-        inputs = (
-            torch.randn((10, 20), device=self.device),
-            torch.randn((15, 20), device=self.device),
-            torch.randn((10, 20), device=self.device),
-        )
-        dynamic_shapes = None
-        if dynamic:
-            dim0_a = Dim("s0", min=2, max=1024)
-            dim0_b = Dim("s1", min=2, max=1024)
-            dynamic_shapes = {
-                "p": {},
-                "x": {0: dim0_a, 1: None},
-                "y": {0: dim0_b, 1: None},
-                "z": {0: dim0_a, 1: None},
-            }
-        self.check_model_with_multiple_inputs(
-            CondModels.UnbackedSymIntClosure(),
-            prepend_predicates(inputs),
-            dynamic_shapes=dynamic_shapes,
-        )
-
     def test_cond_symint_input(self):
         class M(torch.nn.Module):
             def forward(self, x, y, z):
@@ -1462,26 +1439,6 @@ class AOTInductorTestsTemplate:
             dynamic_shapes=None,
         )
 
-    @common_utils.parametrize("dynamic", [False, True])
-    def test_while_loop_with_unbacked_symint_closure(self, dynamic):
-        inputs = (
-            torch.randn(10, 20, device=self.device),
-            torch.randn(10, 20, device=self.device),
-        )
-        dim0_ab = Dim("s0", min=2, max=1024)
-        dynamic_shapes = None
-        if dynamic:
-            dynamic_shapes = {
-                "c": {},
-                "a": {0: dim0_ab, 1: None},
-                "b": {0: dim0_ab, 1: None},
-            }
-        self.check_model_with_multiple_inputs(
-            WhileLoopModels.UnbackedSymIntClosure(),
-            prepend_counters(inputs),
-            dynamic_shapes=dynamic_shapes,
-        )
-
     @config.patch({"is_predispatch": True})
     def test_constant(self):
         class M(torch.nn.Module):


@@ -183,19 +183,6 @@ class CondModels:
 
             return torch.cond(a.size(0) > b.size(0), true_fn, false_fn, [a, b])
 
-    class UnbackedSymIntClosure(torch.nn.Module):
-        def forward(self, p, x, y, z):
-            a = y.shape[0]
-            b = z.sum().to(torch.int64).item()
-
-            def true_fn(x):
-                return x + a
-
-            def false_fn(x):
-                return x + b * z
-
-            return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
-
 
 class CondTests(TestCase):
     def _run_test(
@@ -261,22 +248,6 @@ class CondTests(TestCase):
             device=device,
         )
 
-    @requires_gpu
-    @parametrize("device", ["cpu", GPU_TYPE])
-    @parametrize("dynamic", [False, True])
-    @torch._dynamo.config.patch("capture_scalar_outputs", True)
-    def test_cond_unbacked_symint_closure(self, device, dynamic):
-        self._run_test(
-            model=CondModels.UnbackedSymIntClosure(),
-            inputs=(
-                torch.randn(10, 20),
-                torch.randn(10, 20),
-                torch.randn(10, 20),
-            ),
-            device=device,
-            dynamic=dynamic,
-        )
-
     @requires_gpu
     def test_cond_control_flow_with_precomputed_size(self):
         class TestModel(torch.nn.Module):
@@ -859,23 +830,6 @@ class WhileLoopModels:
             )
             return out1 + 1, out2 + 2
 
-    class UnbackedSymIntClosure(torch.nn.Module):
-        def forward(self, c, a, b):
-            d = a.sum().to(torch.int64).item()
-            e = torch.nonzero(b).size(0)
-
-            def cond_fn(c, a, b):
-                return c > d + e + a.shape[0] - b.shape[0]
-
-            def body_fn(c, a, b):
-                return c - 1, a + e, b + d
-
-            return torch._higher_order_ops.while_loop(
-                cond_fn,
-                body_fn,
-                [c, a, b],
-            )
-
 
 class WhileLoopTests(TestCase):
     def _run_test(
@@ -1122,23 +1076,6 @@ class WhileLoopTests(TestCase):
             dynamic=dynamic,
         )
 
-    @requires_gpu
-    @parametrize("device", ["cpu", GPU_TYPE])
-    @parametrize("dynamic", [True, False])
-    @torch._dynamo.config.patch(
-        {"capture_scalar_outputs": True, "capture_dynamic_output_shape_ops": True}
-    )
-    def test_while_loop_with_unbacked_symint_closure(self, device, dynamic):
-        self._run_test(
-            model=WhileLoopModels.UnbackedSymIntClosure(),
-            inputs=(
-                torch.randn(10, 20),
-                torch.randn(10, 20),
-            ),
-            device=device,
-            dynamic=dynamic,
-        )
-
 
 class AssociativeScanTests(TestCase):
     @requires_gpu


@@ -7209,7 +7209,7 @@ class InvokeSubgraph(ExternKernel):
 @ir_dataclass(frozen=False)
 class Conditional(ExternKernel):
     predicate: Optional[IRNode] = None
-    operands: Optional[list[Union[TensorBox, ShapeAsConstantBuffer]]] = None
+    operands: Optional[list[TensorBox]] = None
     true_subgraph: Optional[Subgraph] = None
     false_subgraph: Optional[Subgraph] = None
     outputs: Optional[list[MultiOutput]] = None
@@ -7217,7 +7217,7 @@ class Conditional(ExternKernel):
     def __init__(
         self,
         predicate: IRNode,
-        operands: list[Union[TensorBox, ShapeAsConstantBuffer]],
+        operands: list[TensorBox],
         true_subgraph: Subgraph,
         false_subgraph: Subgraph,
         layout: MultiOutputLayout,
@@ -7227,13 +7227,15 @@ class Conditional(ExternKernel):
         self.true_subgraph = true_subgraph
         self.false_subgraph = false_subgraph
 
-        sym_args, tensor_args = _split_by_sym_type([predicate] + operands)
+        inputs = []
+        if not isinstance(predicate, ShapeAsConstantBuffer):
+            inputs.append(predicate)
+        inputs.extend(operands)
 
         super().__init__(
             name=None,
             layout=layout,
-            inputs=tensor_args,
-            constant_args=sym_args,
+            inputs=inputs,
         )
 
         self.name = V.graph.register_buffer(self)
@@ -7245,10 +7247,11 @@ class Conditional(ExternKernel):
         predicate: TensorBox,
         true_fn: Subgraph,
         false_fn: Subgraph,
-        operands: list[Union[TensorBox, ShapeAsConstantBuffer]],
+        operands: list[TensorBox],
     ):
         predicate = cls.realize_input(predicate)
+        operands = [cls.realize_input(x) for x in operands]
 
         fx_operands = V.graph.current_node.args[-1]
         fake_operands = [x.meta["val"] for x in fx_operands]  # type: ignore[union-attr]
@@ -7282,12 +7285,16 @@ class Conditional(ExternKernel):
             assert to.get_dtype() == fo.get_dtype(), (i, to, fo)
             assert to.get_layout().offset == fo.get_layout().offset, (i, to, fo)
 
-        device = next(
-            o.get_device()
-            for o in [predicate] + operands
-            if not isinstance(o, ShapeAsConstantBuffer)
-        )
-        assert device is not None, "cannot determine device"
+        if not isinstance(predicate, ShapeAsConstantBuffer):
+            # use predicate device for consistent codegen-ing
+            device = predicate.get_device()
+        else:
+            # predicate is not a Tensor: use first operand's device
+            assert (
+                len(operands) > 0
+            ), "When predicate is not a Tensor, there must be at least one operand in torch.cond."
+            device = operands[0].get_device()
+
         conditional = Conditional(
             predicate=predicate,
             operands=operands,
@@ -7320,32 +7327,18 @@ class Conditional(ExternKernel):
         wrapper.codegen_conditional(self)
 
 
-def _split_by_sym_type(
-    args: list[Any],
-) -> tuple[list[ShapeAsConstantBuffer], list[Any]]:
-    non_sym_args = []
-    sym_args = []
-    for arg in args:
-        if isinstance(arg, ShapeAsConstantBuffer):
-            sym_args.append(arg.expr)
-        else:
-            non_sym_args.append(arg)
-    return sym_args, non_sym_args
-
-
 @ir_dataclass(frozen=False)
 class WhileLoop(ExternKernel):
-    carried_inputs: Optional[list[Union[TensorBox, ShapeAsConstantBuffer]]] = None
-    additional_inputs: Optional[list[Union[TensorBox, ShapeAsConstantBuffer]]] = None
+    carried_inputs: Optional[list[TensorBox]] = None
+    additional_inputs: Optional[list[TensorBox]] = None
     cond_subgraph: Optional[Subgraph] = None
     body_subgraph: Optional[Subgraph] = None
     outputs: Optional[list[MultiOutput]] = None
 
     def __init__(
         self,
-        carried_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]],
-        additional_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]],
+        carried_inputs: list[TensorBox],
+        additional_inputs: list[TensorBox],
         cond_subgraph: Subgraph,
         body_subgraph: Subgraph,
         layout: MultiOutputLayout,
@@ -7355,12 +7348,10 @@ class WhileLoop(ExternKernel):
         self.cond_subgraph = cond_subgraph
         self.body_subgraph = body_subgraph
 
-        sym_args, tensor_args = _split_by_sym_type(carried_inputs + additional_inputs)
         super().__init__(
             name=None,
             layout=layout,
-            inputs=tensor_args,
-            constant_args=sym_args,
+            inputs=carried_inputs + additional_inputs,
        )
 
         self.name = V.graph.register_buffer(self)
@@ -7371,8 +7362,8 @@ class WhileLoop(ExternKernel):
         cls,
         cond_fn: Subgraph,
         body_fn: Subgraph,
-        carried_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]],
-        additional_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]],
+        carried_inputs: list[TensorBox],
+        additional_inputs: list[TensorBox],
     ):
         carried_inputs = [cls.realize_input(x) for x in carried_inputs]
         additional_inputs = [cls.realize_input(x) for x in additional_inputs]
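
For reference, the net effect of the ir.py changes is that Conditional and WhileLoop go back to treating every operand as a tensor input: the deleted _split_by_sym_type helper routed ShapeAsConstantBuffer arguments (symbolic shape/scalar expressions) into constant_args while real tensors stayed in inputs, which is what recorded the kernel's dependency on unbacked symbols. A minimal sketch of that classification rule, using hypothetical stand-in types rather than the actual Inductor IR classes:

from dataclasses import dataclass
from typing import Any


@dataclass
class FakeShapeAsConstantBuffer:
    # Stand-in for ShapeAsConstantBuffer: carries a symbolic expression.
    expr: str


def split_by_sym_type(args: list[Any]) -> tuple[list[Any], list[Any]]:
    # Same splitting rule as the deleted helper: symbolic args become
    # constant_args (their expressions), everything else stays a tensor input.
    sym_args, non_sym_args = [], []
    for arg in args:
        if isinstance(arg, FakeShapeAsConstantBuffer):
            sym_args.append(arg.expr)
        else:
            non_sym_args.append(arg)
    return sym_args, non_sym_args


# e.g. a predicate tensor, an unbacked size u0 captured by a branch closure,
# and an ordinary tensor operand:
sym, tensors = split_by_sym_type(["predicate", FakeShapeAsConstantBuffer("u0"), "operand"])
print(sym)      # ['u0']                   -> was passed as constant_args before the revert
print(tensors)  # ['predicate', 'operand'] -> remain ExternKernel inputs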