diff --git a/aten/src/ATen/functorch/BatchRulesViews.cpp b/aten/src/ATen/functorch/BatchRulesViews.cpp
index 08db1d202b4..fb470e568ce 100644
--- a/aten/src/ATen/functorch/BatchRulesViews.cpp
+++ b/aten/src/ATen/functorch/BatchRulesViews.cpp
@@ -534,20 +534,20 @@ Tensor trace_decomp(const Tensor& tensor) {
 std::tuple<Tensor, std::optional<int64_t>> tril_batch_rule(
     const Tensor& self,
     std::optional<int64_t> self_bdim,
-    int64_t diagonal = 0) {
+    c10::SymInt diagonal = 0) {
   TORCH_CHECK(self.dim() >= 2, "tril: The input tensor must have at least 2 dimensions.");
   auto self_ = moveBatchDimToFront(self, self_bdim);
-  auto result = at::tril(self_, diagonal);
+  auto result = at::tril_symint(self_, std::move(diagonal));
   return std::make_tuple(std::move(result), 0);
 }
 
 std::tuple<Tensor, std::optional<int64_t>> triu_batch_rule(
     const Tensor& self,
     std::optional<int64_t> self_bdim,
-    int64_t diagonal = 0) {
+    c10::SymInt diagonal = 0) {
   TORCH_CHECK(self.dim() >= 2, "triu: The input tensor must have at least 2 dimensions.");
   auto self_ = moveBatchDimToFront(self, self_bdim);
-  auto result = at::triu(self_, diagonal);
+  auto result = at::triu_symint(self_, std::move(diagonal));
   return std::make_tuple(std::move(result), 0);
 }
 
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index ad3c75f2097..d76a7385907 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -8865,11 +8865,11 @@
   autogen: bitwise_right_shift.Scalar_Tensor_out
   tags: pointwise
 
-- func: tril_(Tensor(a!) self, int diagonal=0) -> Tensor(a!)
+- func: tril_(Tensor(a!) self, SymInt diagonal=0) -> Tensor(a!)
   structured_delegate: tril.out
   variants: method
 
-- func: triu_(Tensor(a!) self, int diagonal=0) -> Tensor(a!)
+- func: triu_(Tensor(a!) self, SymInt diagonal=0) -> Tensor(a!)
   structured_delegate: triu.out
   variants: method
 
@@ -8993,25 +8993,25 @@
 - func: cross(Tensor self, Tensor other, int? dim=None) -> Tensor
   variants: method, function
 
-- func: triu.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
+- func: triu.out(Tensor self, SymInt diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
   structured: True
   dispatch:
     CPU: triu_cpu
     CUDA: triu_cuda
     MPS: triu_mps_out
 
-- func: triu(Tensor self, int diagonal=0) -> Tensor
+- func: triu(Tensor self, SymInt diagonal=0) -> Tensor
   structured_delegate: triu.out
   variants: method, function
 
-- func: tril.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
+- func: tril.out(Tensor self, SymInt diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
   structured: True
   dispatch:
     CPU: tril_cpu
     CUDA: tril_cuda
     MPS: tril_mps_out
 
-- func: tril(Tensor self, int diagonal=0) -> Tensor
+- func: tril(Tensor self, SymInt diagonal=0) -> Tensor
   structured_delegate: tril.out
   variants: method, function
 
diff --git a/test/export/test_export.py b/test/export/test_export.py
index 762ad512ae3..61b7b886a71 100755
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -16745,6 +16745,74 @@ def forward(self, q, k, v):
         self.assertEqual(result_non_strict, result_strict)
 
+    def test_tril_dynamic_diagonal(self):
+        class Module(torch.nn.Module):
+            def forward(self, x, y):
+                x_len = x.shape[0]
+                y_len = y.shape[0]
+                mask = torch.ones(x_len, y_len, dtype=torch.bool, device=x.device)
+                mask = mask.tril(diagonal=y_len - x_len)
+                return mask
+
+        x = torch.randn(3, 4)
+        y = torch.randn(5, 4)
+        x_len = Dim("x_len", min=1, max=64)
+        y_len = Dim("y_len", min=1, max=64)
+        ep = export(
+            Module(),
+            (x, y),
+            dynamic_shapes={
+                "x": {0: x_len},
+                "y": {0: y_len},
+            },
+        )
+        eager_out = Module()(x, y)
+        exported_out = ep.module()(x, y)
+        self.assertEqual(eager_out, exported_out)
+        self.assertEqual(exported_out.shape, (3, 5))
+        x2 = torch.randn(4, 4)
+        y2 = torch.randn(7, 4)
+        eager_out2 = Module()(x2, y2)
+        exported_out2 = ep.module()(x2, y2)
+        self.assertEqual(eager_out2, exported_out2)
+        self.assertEqual(exported_out2.shape, (4, 7))
+        expected_mask = torch.ones(3, 5, dtype=torch.bool).tril(diagonal=2)
+        self.assertEqual(eager_out, expected_mask)
+
+    def test_triu_dynamic_diagonal(self):
+        class Module(torch.nn.Module):
+            def forward(self, x, y):
+                x_len = x.shape[0]
+                y_len = y.shape[0]
+                mask = torch.ones(x_len, y_len, dtype=torch.bool, device=x.device)
+                mask = mask.triu(diagonal=y_len - x_len)
+                return mask
+
+        x = torch.randn(3, 4)
+        y = torch.randn(5, 4)
+        x_len = Dim("x_len", min=1, max=64)
+        y_len = Dim("y_len", min=1, max=64)
+        ep = export(
+            Module(),
+            (x, y),
+            dynamic_shapes={
+                "x": {0: x_len},
+                "y": {0: y_len},
+            },
+        )
+        eager_out = Module()(x, y)
+        exported_out = ep.module()(x, y)
+        self.assertEqual(eager_out, exported_out)
+        self.assertEqual(exported_out.shape, (3, 5))
+        x2 = torch.randn(4, 4)
+        y2 = torch.randn(7, 4)
+        eager_out2 = Module()(x2, y2)
+        exported_out2 = ep.module()(x2, y2)
+        self.assertEqual(eager_out2, exported_out2)
+        self.assertEqual(exported_out2.shape, (4, 7))
+        expected_mask = torch.ones(3, 5, dtype=torch.bool).triu(diagonal=2)
+        self.assertEqual(eager_out, expected_mask)
+
 
 @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
 class TestOneOffModelExportResult(TestCase):
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index ca098689026..4cd02ed35e9 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -1073,8 +1073,8 @@
 - name: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U)
   LU_data: lu_unpack_backward(grad_L, grad_U, LU_data.sym_size(-2), LU_data.sym_size(-1))
   LU_pivots: non_differentiable
-  L: "LU_data_t.sym_size(-2) >= LU_data_t.sym_size(-1) ? LU_data_t.tril(-1) : LU_data_t.narrow_symint(-1, 0, LU_data_t.sym_size(-2)).tril(-1)"
-  U: "LU_data_t.sym_size(-1) >= LU_data_t.sym_size(-2) ? LU_data_t.triu() : LU_data_t.narrow_symint(-2, 0, LU_data_t.sym_size(-1)).triu()"
+  L: "LU_data_t.sym_size(-2) >= LU_data_t.sym_size(-1) ? LU_data_t.tril_symint(-1) : LU_data_t.narrow_symint(-1, 0, LU_data_t.sym_size(-2)).tril_symint(-1)"
+  U: "LU_data_t.sym_size(-1) >= LU_data_t.sym_size(-2) ? LU_data_t.triu_symint() : LU_data_t.narrow_symint(-2, 0, LU_data_t.sym_size(-1)).triu_symint()"
   output_differentiability: [False, True, True]
 
 - name: masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor
@@ -1782,12 +1782,12 @@
   self, B: linalg_solve_triangular_backward(grad, self, result, upper, left, unitriangular, grad_input_mask)
   result: linalg_solve_triangular_forward_AD(self_t, B_t, self_p, result, upper, left, unitriangular)
 
-- name: tril(Tensor self, int diagonal=0) -> Tensor
-  self: grad.tril(diagonal)
+- name: tril(Tensor self, SymInt diagonal=0) -> Tensor
+  self: grad.tril_symint(diagonal)
   result: auto_linear
 
-- name: triu(Tensor self, int diagonal=0) -> Tensor
-  self: grad.triu(diagonal)
+- name: triu(Tensor self, SymInt diagonal=0) -> Tensor
+  self: grad.triu_symint(diagonal)
   result: auto_linear
 
 - name: trunc(Tensor self) -> Tensor
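
For reviewers, a minimal usage sketch of what the SymInt `diagonal` enables, mirroring the new export tests above (the `CausalMask` module and variable names are illustrative, not part of the patch): the diagonal passed to `tril`/`triu` can now be a symbolic expression derived from dynamic input shapes rather than a concrete Python int, so the exported program re-evaluates it for new shapes instead of baking in the traced value.

# Sketch only; assumes the same torch.export API used in the new tests.
import torch
from torch.export import Dim, export


class CausalMask(torch.nn.Module):
    def forward(self, x, y):
        x_len, y_len = x.shape[0], y.shape[0]
        mask = torch.ones(x_len, y_len, dtype=torch.bool, device=x.device)
        # Under export, y_len - x_len is a SymInt expression; with the schema
        # change it can be passed directly as the diagonal.
        return mask.tril(diagonal=y_len - x_len)


x_len, y_len = Dim("x_len", min=1, max=64), Dim("y_len", min=1, max=64)
ep = export(
    CausalMask(),
    (torch.randn(3, 4), torch.randn(5, 4)),
    dynamic_shapes={"x": {0: x_len}, "y": {0: y_len}},
)
# The exported program recomputes the diagonal for new input sizes.
print(ep.module()(torch.randn(4, 4), torch.randn(7, 4)).shape)  # torch.Size([4, 7])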