mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
add test for capture_dynamic_output_shape_ops=True changing expected output between eager and compiled versions (#145821)
Followup from https://github.com/pytorch/pytorch/issues/130290 Pull Request resolved: https://github.com/pytorch/pytorch/pull/145821 Approved by: https://github.com/eellison, https://github.com/ezyang
This commit is contained in:
parent
776bdb962c
commit
8696e59ae2
|
|
@ -6989,6 +6989,25 @@ utils_device.CURRENT_DEVICE == None""".split(
|
|||
inputs = [torch.randn(10, 10) for _ in range(4)]
|
||||
self.assertTrue(same(fn(iter(tuple(inputs))), opt_fn(iter(tuple(inputs)))))
|
||||
|
||||
@torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
|
||||
def test_argwhere_with_dynamic_shapes(self):
|
||||
def fn(
|
||||
tensor: torch.Tensor,
|
||||
mapping: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
xx, yy = torch.meshgrid(mapping, tensor, indexing="ij")
|
||||
indices = torch.argwhere(xx == yy)
|
||||
|
||||
mapped_values = torch.zeros_like(tensor)
|
||||
mapped_values[indices[:, 1]] = indices[:, 0]
|
||||
|
||||
return mapped_values
|
||||
|
||||
tensor = torch.tensor([1, 2, 3, 5, 6, 7])
|
||||
mapping = torch.tensor([0, 3, 4, 5, 7])
|
||||
opt = torch.compile(fn, fullgraph=True)
|
||||
self.assertEqual(fn(tensor, mapping), opt(tensor, mapping))
|
||||
|
||||
def test_torch_package_working_with_trace(self):
|
||||
# from torch._dynamo.test_case import run_tests
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user