diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py
index 6c0ff3f0ce0..710f66e6900 100644
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -311,6 +311,33 @@ class CPUReproTests(TestCase):
             (v,),
         )
 
+    def test_conv1d_strided_weight_torch_compile(self):
+        def fn(x, w):
+            wt = w.transpose(2, 1)
+            y = F.conv1d(x, wt)
+            return y.clone()
+
+        x_eager = torch.randn(2, 3, 5, requires_grad=True)
+        w_eager = torch.randn(4, 2, 3, requires_grad=True)
+
+        out_eager = fn(x_eager, w_eager)
+        grad = torch.randn_like(out_eager)
+        out_eager_val = out_eager.detach()
+        out_eager.backward(grad)
+        grad_x_eager = x_eager.grad.detach().clone()
+        grad_w_eager = w_eager.grad.detach().clone()
+
+        x_comp = x_eager.detach().requires_grad_(True)
+        w_comp = w_eager.detach().requires_grad_(True)
+        compiled = torch.compile(fn, backend="inductor", fullgraph=True, dynamic=True)
+        out_comp = compiled(x_comp, w_comp)
+        out_comp_val = out_comp.detach()
+        out_comp.backward(grad)
+
+        torch.testing.assert_close(out_comp_val, out_eager_val)
+        torch.testing.assert_close(x_comp.grad, grad_x_eager)
+        torch.testing.assert_close(w_comp.grad, grad_w_eager)
+
     @config.patch(freezing=True)
     @unittest.skipIf(not TEST_MKL, "Test requires MKL")
     @patch("torch.cuda.is_available", lambda: False)
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 1184f456cc7..3d6313c0fa4 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -4639,6 +4639,41 @@ class CommonTemplate:
             (torch.randn([4, 4, 4]),),
         )
 
+    def test_conv1d_with_permute(self):
+        # fix https://github.com/pytorch/pytorch/issues/159462
+        class ConvModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = nn.Conv1d(1, 64, kernel_size=3, padding=1)
+
+            def forward(self, x):
+                x = x.permute(0, 2, 1)
+                return self.conv(x)
+
+        self.common(ConvModel(), (torch.randn([32, 100, 1]),), check_lowp=False)
+
+    def test_conv1d_depthwise(self):
+        class ConvModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = nn.Conv1d(
+                    768,
+                    768,
+                    kernel_size=(9,),
+                    stride=(1,),
+                    padding=(4,),
+                    groups=768,
+                    bias=False,
+                )
+
+            def forward(self, x):
+                return self.conv(x)
+
+        input_tensor = torch.randn([1, 768, 512]).as_strided(
+            (1, 768, 512), (393216, 1, 768)
+        )
+        self.common(ConvModel(), (input_tensor,), check_lowp=False)
+
     def test_convolution1(self):
         m = torch.nn.Sequential(
             torch.nn.Conv2d(5, 6, [3, 3]),
diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
index e45c51becde..54baa27adc4 100644
--- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
+++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
@@ -110,6 +110,7 @@ test_failures = {
     #
     # Failed to find dynamic for loop variable:
     #
+    "test_conv1d_with_permute_dynamic_shapes": TestFailure(("cpu",), is_skip=True),
     "test_arange1_dynamic_shapes": TestFailure(("cpu",)),
     "test_arange2_dynamic_shapes": TestFailure(("cpu",)),
     "test_arange3_dynamic_shapes": TestFailure(("cpu",)),
diff --git a/test/inductor/test_utils.py b/test/inductor/test_utils.py
index 9300bf1bae8..349160a1e6c 100644
--- a/test/inductor/test_utils.py
+++ b/test/inductor/test_utils.py
@@ -134,11 +134,11 @@ class TestUtils(TestCase):
                 torch.Tensor(2, 2, 3),
                 torch.Tensor(2, 2, 2),
                 torch.Tensor(2),
-                (1, 1),
-                (0, 0),
-                (1, 1),
+                (1,),
+                (0,),
+                (1,),
                 True,
-                (0, 0),
+                (0,),
                 1,
             ),
             {},
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 7cdeb0f8d03..0a5b6faab2f 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -201,6 +201,26 @@ class FakeTensorTest(TestCase):
 
         self.assertEqual(torch.ones([10]), out[0])
 
+    def test_conv_nhwc(self):
+        x = torch.randn([1, 1024, 16, 16]).to(memory_format=torch.channels_last)
+        w = torch.randn([256, 1024, 4, 4]).to(memory_format=torch.channels_last)
+        b = torch.randn([256])
+
+        class Model(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+
+            def forward(self, x, w, b):
+                return torch.ops.aten.convolution(
+                    x, w, b, [1, 1], [0, 0], [1, 1], False, [0, 0], 1
+                )
+
+        model = Model()
+        with FakeTensorMode(allow_non_fake_inputs=True) as mode:
+            fake_out = model.forward(x, w, b)
+        eager_out = model.forward(x, w, b)
+        self.assertEqual(fake_out.stride(), eager_out.stride())
+
     @unittest.skipIf(not RUN_CUDA, "requires cuda")
     def test_zero_dim(self):
         with FakeTensorMode() as mode:
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 1b067df9f4d..64976623060 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -2447,18 +2447,6 @@ def meta_conv(
     output_padding: list[int],
     groups: int,
 ):
-    def pick_memory_format():
-        if device_hint(input_tensor) == "cuda":
-            if is_channels_last(input_tensor) or is_channels_last(weight):
-                return torch.channels_last
-        else:
-            if is_channels_last(input_tensor):
-                return torch.channels_last
-        if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
-            return torch.contiguous_format
-        elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
-            return torch.preserve_format
-
     shape_out = calc_conv_nd_return_shape(
         input_tensor,
         weight,
@@ -2476,7 +2464,6 @@ def meta_conv(
         shape_out[output_channels_dim] = 0
 
     out = input_tensor.new_empty(shape_out)
-    out = out.to(memory_format=pick_memory_format())  # type: ignore[call-overload]
     return out
 
 
diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py
index cefff832c5f..4494be7361b 100644
--- a/torch/_subclasses/fake_impls.py
+++ b/torch/_subclasses/fake_impls.py
@@ -1021,8 +1021,6 @@ def conv(fake_mode, func, *args, **kwargs):
         # TODO: We can make this a little more faithful with best effort
         # channels last detection (but only if it's statically obvious!)
         mem_fmt = None
-    elif k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
-        mem_fmt = None
     else:
         if func is aten.convolution.default:
             conv_backend = torch._C._select_conv_backend(**kwargs)
@@ -1039,15 +1037,40 @@ def conv(fake_mode, func, *args, **kwargs):
                 groups=kwargs["groups"],
                 bias_sizes=kwargs["bias_sizes"],
             )
+        # Expand 1d -> 2d.
+        # Note: Avoid expanding before calling _select_conv_backend,
+        # as the function handles 2D expansion internally.
+        if k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
+            # Note: Using input.to(memory_format=contiguous) does not work.
+            kwargs["input"] = kwargs["input"].contiguous().unsqueeze(2)
+            kwargs["weight"] = kwargs["weight"].unsqueeze(2)
+            if len(kwargs["stride"]) == 1:
+                kwargs["stride"].insert(0, 1)
+                kwargs["padding"].insert(0, 0)
+                kwargs["dilation"].insert(0, 1)
+                kwargs["output_padding"].insert(0, 0)
         mem_fmt = torch._C._conv_determine_backend_memory_format(
             kwargs["input"], kwargs["weight"], conv_backend
         )
+        # Revert 2d -> 1d.
+        if k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
+            kwargs["input"] = kwargs["input"].squeeze(2)
+            kwargs["weight"] = kwargs["weight"].squeeze(2)
+            if len(kwargs["stride"]) == 2:
+                kwargs["stride"].pop(0)
+                kwargs["padding"].pop(0)
+                kwargs["dilation"].pop(0)
+                kwargs["output_padding"].pop(0)
 
     def convert(t, mem_fmt):
         if t is None:
            return t
         if mem_fmt is not None:
-            t = t.to(memory_format=mem_fmt)
+            # channels_last only supports 4d tensors, so temporarily expand a dim and squeeze it back afterwards.
+            if t.dim() == 3 and mem_fmt == torch.channels_last:
+                t = t.unsqueeze(2).to(memory_format=mem_fmt).squeeze(2)
+            else:
+                t = t.to(memory_format=mem_fmt)
         return FakeTensor(fake_mode, t, device)
 
     with in_kernel_invocation_manager(fake_mode):