Revert "[hop][inductor] track the dependency on unbacked symbols correctly with constant_args for hops (#143456)"
This reverts commit 68a3635484.
Reverted https://github.com/pytorch/pytorch/pull/143456 on behalf of https://github.com/atalman due to New tests are failing internally ([comment](https://github.com/pytorch/pytorch/pull/143456#issuecomment-2631475900))
This commit is contained in:
parent 01554c7b5a
commit c0979d72b5
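For context on what the reverted change covered: the tests removed below exercise higher-order ops (`torch.cond`, `torch._higher_order_ops.while_loop`) whose branches close over unbacked symbols, i.e. integer values only known at runtime, such as the result of `.item()` or the size of a `nonzero()` output. A minimal sketch of that pattern, modeled on the `CondModels.UnbackedSymIntClosure` test removed below (module name and shapes here are illustrative only):

import torch

class UnbackedSymIntClosureSketch(torch.nn.Module):
    def forward(self, p, x, y, z):
        a = y.shape[0]                      # shape-derived (backed) size
        b = z.sum().to(torch.int64).item()  # data-dependent scalar -> unbacked symint under compile

        def true_fn(x):
            return x + a

        def false_fn(x):
            return x + b * z                # this branch closes over the unbacked value

        return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))

# Eager run; compiling this generally needs torch._dynamo.config.capture_scalar_outputs = True.
m = UnbackedSymIntClosureSketch()
out = m(torch.tensor(True),
        torch.randn(10, 20), torch.randn(15, 20), torch.randn(10, 20))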
@@ -1305,29 +1305,6 @@ class AOTInductorTestsTemplate:
             dynamic_shapes=dynamic_shapes,
         )
 
-    @common_utils.parametrize("dynamic", [False, True])
-    def test_cond_unbacked_symint_closure(self, dynamic):
-        inputs = (
-            torch.randn((10, 20), device=self.device),
-            torch.randn((15, 20), device=self.device),
-            torch.randn((10, 20), device=self.device),
-        )
-        dynamic_shapes = None
-        if dynamic:
-            dim0_a = Dim("s0", min=2, max=1024)
-            dim0_b = Dim("s1", min=2, max=1024)
-            dynamic_shapes = {
-                "p": {},
-                "x": {0: dim0_a, 1: None},
-                "y": {0: dim0_b, 1: None},
-                "z": {0: dim0_a, 1: None},
-            }
-        self.check_model_with_multiple_inputs(
-            CondModels.UnbackedSymIntClosure(),
-            prepend_predicates(inputs),
-            dynamic_shapes=dynamic_shapes,
-        )
-
     def test_cond_symint_input(self):
         class M(torch.nn.Module):
             def forward(self, x, y, z):
@@ -1462,26 +1439,6 @@ class AOTInductorTestsTemplate:
             dynamic_shapes=None,
         )
 
-    @common_utils.parametrize("dynamic", [False, True])
-    def test_while_loop_with_unbacked_symint_closure(self, dynamic):
-        inputs = (
-            torch.randn(10, 20, device=self.device),
-            torch.randn(10, 20, device=self.device),
-        )
-        dim0_ab = Dim("s0", min=2, max=1024)
-        dynamic_shapes = None
-        if dynamic:
-            dynamic_shapes = {
-                "c": {},
-                "a": {0: dim0_ab, 1: None},
-                "b": {0: dim0_ab, 1: None},
-            }
-        self.check_model_with_multiple_inputs(
-            WhileLoopModels.UnbackedSymIntClosure(),
-            prepend_counters(inputs),
-            dynamic_shapes=dynamic_shapes,
-        )
-
     @config.patch({"is_predispatch": True})
     def test_constant(self):
         class M(torch.nn.Module):
@@ -183,19 +183,6 @@ class CondModels:
 
             return torch.cond(a.size(0) > b.size(0), true_fn, false_fn, [a, b])
 
-    class UnbackedSymIntClosure(torch.nn.Module):
-        def forward(self, p, x, y, z):
-            a = y.shape[0]
-            b = z.sum().to(torch.int64).item()
-
-            def true_fn(x):
-                return x + a
-
-            def false_fn(x):
-                return x + b * z
-
-            return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
-
 
 class CondTests(TestCase):
     def _run_test(
@@ -261,22 +248,6 @@ class CondTests(TestCase):
             device=device,
         )
 
-    @requires_gpu
-    @parametrize("device", ["cpu", GPU_TYPE])
-    @parametrize("dynamic", [False, True])
-    @torch._dynamo.config.patch("capture_scalar_outputs", True)
-    def test_cond_unbacked_symint_closure(self, device, dynamic):
-        self._run_test(
-            model=CondModels.UnbackedSymIntClosure(),
-            inputs=(
-                torch.randn(10, 20),
-                torch.randn(10, 20),
-                torch.randn(10, 20),
-            ),
-            device=device,
-            dynamic=dynamic,
-        )
-
     @requires_gpu
     def test_cond_control_flow_with_precomputed_size(self):
         class TestModel(torch.nn.Module):
@@ -859,23 +830,6 @@ class WhileLoopModels:
             )
             return out1 + 1, out2 + 2
 
-    class UnbackedSymIntClosure(torch.nn.Module):
-        def forward(self, c, a, b):
-            d = a.sum().to(torch.int64).item()
-            e = torch.nonzero(b).size(0)
-
-            def cond_fn(c, a, b):
-                return c > d + e + a.shape[0] - b.shape[0]
-
-            def body_fn(c, a, b):
-                return c - 1, a + e, b + d
-
-            return torch._higher_order_ops.while_loop(
-                cond_fn,
-                body_fn,
-                [c, a, b],
-            )
-
 
 class WhileLoopTests(TestCase):
     def _run_test(
@@ -1122,23 +1076,6 @@ class WhileLoopTests(TestCase):
             dynamic=dynamic,
         )
 
-    @requires_gpu
-    @parametrize("device", ["cpu", GPU_TYPE])
-    @parametrize("dynamic", [True, False])
-    @torch._dynamo.config.patch(
-        {"capture_scalar_outputs": True, "capture_dynamic_output_shape_ops": True}
-    )
-    def test_while_loop_with_unbacked_symint_closure(self, device, dynamic):
-        self._run_test(
-            model=WhileLoopModels.UnbackedSymIntClosure(),
-            inputs=(
-                torch.randn(10, 20),
-                torch.randn(10, 20),
-            ),
-            device=device,
-            dynamic=dynamic,
-        )
-
 
 class AssociativeScanTests(TestCase):
     @requires_gpu
@@ -7209,7 +7209,7 @@ class InvokeSubgraph(ExternKernel):
 @ir_dataclass(frozen=False)
 class Conditional(ExternKernel):
     predicate: Optional[IRNode] = None
-    operands: Optional[list[Union[TensorBox, ShapeAsConstantBuffer]]] = None
+    operands: Optional[list[TensorBox]] = None
     true_subgraph: Optional[Subgraph] = None
     false_subgraph: Optional[Subgraph] = None
     outputs: Optional[list[MultiOutput]] = None
@@ -7217,7 +7217,7 @@ class Conditional(ExternKernel):
     def __init__(
         self,
        predicate: IRNode,
-        operands: list[Union[TensorBox, ShapeAsConstantBuffer]],
+        operands: list[TensorBox],
        true_subgraph: Subgraph,
        false_subgraph: Subgraph,
        layout: MultiOutputLayout,
@@ -7227,13 +7227,15 @@ class Conditional(ExternKernel):
         self.true_subgraph = true_subgraph
         self.false_subgraph = false_subgraph
 
-        sym_args, tensor_args = _split_by_sym_type([predicate] + operands)
+        inputs = []
+        if not isinstance(predicate, ShapeAsConstantBuffer):
+            inputs.append(predicate)
+        inputs.extend(operands)
 
         super().__init__(
             name=None,
             layout=layout,
-            inputs=tensor_args,
-            constant_args=sym_args,
+            inputs=inputs,
         )
 
         self.name = V.graph.register_buffer(self)
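The hunk above is the substance of the revert on the IR side: the reverted constructor split the predicate and operands by type, keeping tensors as `inputs` and routing `ShapeAsConstantBuffer` scalars (their sympy expressions) into `constant_args`, which, per the PR title, is how the dependency on unbacked symbols was tracked; the restored code simply appends the tensor predicate and operands to `inputs`. A rough, self-contained sketch of that split, using plain Python with a toy stand-in for `ShapeAsConstantBuffer` (it mirrors the removed `_split_by_sym_type` helper shown further down in this diff, not the actual Inductor types):

import sympy

class SymScalarArg:
    """Toy stand-in for ShapeAsConstantBuffer: a scalar argument backed by a sympy expression."""
    def __init__(self, expr: sympy.Expr):
        self.expr = expr

def split_by_sym_type(args):
    # Symbolic scalars go to constant_args (as sympy expressions);
    # everything else stays a tensor input.
    sym_args, tensor_args = [], []
    for arg in args:
        if isinstance(arg, SymScalarArg):
            sym_args.append(arg.expr)
        else:
            tensor_args.append(arg)
    return sym_args, tensor_args

u0 = sympy.Symbol("u0", integer=True)  # e.g. an unbacked symbol produced by .item()
sym_args, tensor_args = split_by_sym_type(["predicate_buf", SymScalarArg(u0 + 1), "operand_buf"])
print(sym_args)                                        # [u0 + 1]
print({s for e in sym_args for s in e.free_symbols})   # {u0}: the unbacked symbol this argument depends on
print(tensor_args)                                     # ['predicate_buf', 'operand_buf']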
@@ -7245,10 +7247,11 @@ class Conditional(ExternKernel):
         predicate: TensorBox,
         true_fn: Subgraph,
         false_fn: Subgraph,
-        operands: list[Union[TensorBox, ShapeAsConstantBuffer]],
+        operands: list[TensorBox],
     ):
         predicate = cls.realize_input(predicate)
         operands = [cls.realize_input(x) for x in operands]
 
         fx_operands = V.graph.current_node.args[-1]
         fake_operands = [x.meta["val"] for x in fx_operands]  # type: ignore[union-attr]
 
@@ -7282,12 +7285,16 @@ class Conditional(ExternKernel):
             assert to.get_dtype() == fo.get_dtype(), (i, to, fo)
             assert to.get_layout().offset == fo.get_layout().offset, (i, to, fo)
 
-        device = next(
-            o.get_device()
-            for o in [predicate] + operands
-            if not isinstance(o, ShapeAsConstantBuffer)
-        )
-        assert device is not None, "cannot determine device"
+        if not isinstance(predicate, ShapeAsConstantBuffer):
+            # use predicate device for consistent codegen-ing
+            device = predicate.get_device()
+        else:
+            # predicate is not a Tensor: use first operand's device
+            assert (
+                len(operands) > 0
+            ), "When predicate is not a Tensor, there must be at least one operand in torch.cond."
+            device = operands[0].get_device()
 
         conditional = Conditional(
             predicate=predicate,
             operands=operands,
@@ -7320,32 +7327,18 @@ class Conditional(ExternKernel):
         wrapper.codegen_conditional(self)
 
 
-def _split_by_sym_type(
-    args: list[Any],
-) -> tuple[list[ShapeAsConstantBuffer], list[Any]]:
-    non_sym_args = []
-    sym_args = []
-    for arg in args:
-        if isinstance(arg, ShapeAsConstantBuffer):
-            sym_args.append(arg.expr)
-        else:
-            non_sym_args.append(arg)
-
-    return sym_args, non_sym_args
-
-
 @ir_dataclass(frozen=False)
 class WhileLoop(ExternKernel):
-    carried_inputs: Optional[list[Union[TensorBox, ShapeAsConstantBuffer]]] = None
-    additional_inputs: Optional[list[Union[TensorBox, ShapeAsConstantBuffer]]] = None
+    carried_inputs: Optional[list[TensorBox]] = None
+    additional_inputs: Optional[list[TensorBox]] = None
     cond_subgraph: Optional[Subgraph] = None
     body_subgraph: Optional[Subgraph] = None
     outputs: Optional[list[MultiOutput]] = None
 
     def __init__(
         self,
-        carried_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]],
-        additional_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]],
+        carried_inputs: list[TensorBox],
+        additional_inputs: list[TensorBox],
         cond_subgraph: Subgraph,
         body_subgraph: Subgraph,
         layout: MultiOutputLayout,
@@ -7355,12 +7348,10 @@ class WhileLoop(ExternKernel):
         self.cond_subgraph = cond_subgraph
         self.body_subgraph = body_subgraph
 
-        sym_args, tensor_args = _split_by_sym_type(carried_inputs + additional_inputs)
         super().__init__(
             name=None,
             layout=layout,
-            inputs=tensor_args,
-            constant_args=sym_args,
+            inputs=carried_inputs + additional_inputs,
         )
 
         self.name = V.graph.register_buffer(self)
@@ -7371,8 +7362,8 @@ class WhileLoop(ExternKernel):
         cls,
         cond_fn: Subgraph,
         body_fn: Subgraph,
-        carried_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]],
-        additional_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]],
+        carried_inputs: list[TensorBox],
+        additional_inputs: list[TensorBox],
     ):
         carried_inputs = [cls.realize_input(x) for x in carried_inputs]
         additional_inputs = [cls.realize_input(x) for x in additional_inputs]