Fix Tril Triu SymInt (#166627)

Fixes #165613

### Summary:

- This PR fixes an issue where `torch.tril` and `torch.triu` with dynamic diagonal values caused `torch.export` to incorrectly infer unnecessary constraints between dynamic dimensions (a minimal repro sketch follows below).
- Ensured proper `SymInt` type annotations for the `diagonal` parameter.
- Updated the C++ implementation to correctly handle `SymInt` diagonal values.
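
As a rough illustration of the failure mode, here is a minimal repro sketch distilled from the regression test added in this PR (the `Mask` module name is only for illustration). The `diagonal` argument is a symbolic expression of two independent dynamic dimensions; before this fix, exporting such a module could make `torch.export` infer an unnecessary constraint relating them.

```python
import torch
from torch.export import Dim, export

class Mask(torch.nn.Module):
    def forward(self, x, y):
        x_len = x.shape[0]
        y_len = y.shape[0]
        # The diagonal offset is itself a symbolic expression of the two dims.
        mask = torch.ones(x_len, y_len, dtype=torch.bool, device=x.device)
        return mask.tril(diagonal=y_len - x_len)

ep = export(
    Mask(),
    (torch.randn(3, 4), torch.randn(5, 4)),
    dynamic_shapes={
        "x": {0: Dim("x_len", min=1, max=64)},
        "y": {0: Dim("y_len", min=1, max=64)},
    },
)
# With the SymInt-typed schema, the exported program runs at other sizes
# without tying x_len and y_len to each other.
print(ep.module()(torch.randn(4, 4), torch.randn(7, 4)).shape)  # torch.Size([4, 7])
```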

### Impacts:
module: dynamic shapes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166627
Approved by: https://github.com/ezyang, https://github.com/Skylion007
Authored by Parshant Sharma on 2025-10-31 21:53:16 +00:00, committed by PyTorch MergeBot
parent dfebdcab86
commit 9970fb97ff
4 changed files with 84 additions and 16 deletions

View File

@@ -534,20 +534,20 @@ Tensor trace_decomp(const Tensor& tensor) {
 std::tuple<Tensor, std::optional<int64_t>> tril_batch_rule(
     const Tensor& self,
     std::optional<int64_t> self_bdim,
-    int64_t diagonal = 0) {
+    c10::SymInt diagonal = 0) {
   TORCH_CHECK(self.dim() >= 2, "tril: The input tensor must have at least 2 dimensions.");
   auto self_ = moveBatchDimToFront(self, self_bdim);
-  auto result = at::tril(self_, diagonal);
+  auto result = at::tril_symint(self_, std::move(diagonal));
   return std::make_tuple(std::move(result), 0);
 }

 std::tuple<Tensor, std::optional<int64_t>> triu_batch_rule(
     const Tensor& self,
     std::optional<int64_t> self_bdim,
-    int64_t diagonal = 0) {
+    c10::SymInt diagonal = 0) {
   TORCH_CHECK(self.dim() >= 2, "triu: The input tensor must have at least 2 dimensions.");
   auto self_ = moveBatchDimToFront(self, self_bdim);
-  auto result = at::triu(self_, diagonal);
+  auto result = at::triu_symint(self_, std::move(diagonal));
   return std::make_tuple(std::move(result), 0);
 }

View File

@@ -8865,11 +8865,11 @@
   autogen: bitwise_right_shift.Scalar_Tensor_out
   tags: pointwise

-- func: tril_(Tensor(a!) self, int diagonal=0) -> Tensor(a!)
+- func: tril_(Tensor(a!) self, SymInt diagonal=0) -> Tensor(a!)
   structured_delegate: tril.out
   variants: method

-- func: triu_(Tensor(a!) self, int diagonal=0) -> Tensor(a!)
+- func: triu_(Tensor(a!) self, SymInt diagonal=0) -> Tensor(a!)
   structured_delegate: triu.out
   variants: method

@@ -8993,25 +8993,25 @@
 - func: cross(Tensor self, Tensor other, int? dim=None) -> Tensor
   variants: method, function

-- func: triu.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
+- func: triu.out(Tensor self, SymInt diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
   structured: True
   dispatch:
     CPU: triu_cpu
     CUDA: triu_cuda
     MPS: triu_mps_out

-- func: triu(Tensor self, int diagonal=0) -> Tensor
+- func: triu(Tensor self, SymInt diagonal=0) -> Tensor
   structured_delegate: triu.out
   variants: method, function

-- func: tril.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
+- func: tril.out(Tensor self, SymInt diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
   structured: True
   dispatch:
     CPU: tril_cpu
     CUDA: tril_cuda
     MPS: tril_mps_out

-- func: tril(Tensor self, int diagonal=0) -> Tensor
+- func: tril(Tensor self, SymInt diagonal=0) -> Tensor
   structured_delegate: tril.out
   variants: method, function

View File

@@ -16745,6 +16745,74 @@ def forward(self, q, k, v):
         self.assertEqual(result_non_strict, result_strict)

+    def test_tril_dynamic_diagonal(self):
+        class Module(torch.nn.Module):
+            def forward(self, x, y):
+                x_len = x.shape[0]
+                y_len = y.shape[0]
+                mask = torch.ones(x_len, y_len, dtype=torch.bool, device=x.device)
+                mask = mask.tril(diagonal=y_len - x_len)
+                return mask
+
+        x = torch.randn(3, 4)
+        y = torch.randn(5, 4)
+        x_len = Dim("x_len", min=1, max=64)
+        y_len = Dim("y_len", min=1, max=64)
+        ep = export(
+            Module(),
+            (x, y),
+            dynamic_shapes={
+                "x": {0: x_len},
+                "y": {0: y_len},
+            },
+        )
+        eager_out = Module()(x, y)
+        exported_out = ep.module()(x, y)
+        self.assertEqual(eager_out, exported_out)
+        self.assertEqual(exported_out.shape, (3, 5))
+
+        x2 = torch.randn(4, 4)
+        y2 = torch.randn(7, 4)
+        eager_out2 = Module()(x2, y2)
+        exported_out2 = ep.module()(x2, y2)
+        self.assertEqual(eager_out2, exported_out2)
+        self.assertEqual(exported_out2.shape, (4, 7))
+
+        expected_mask = torch.ones(3, 5, dtype=torch.bool).tril(diagonal=2)
+        self.assertEqual(eager_out, expected_mask)
+
+    def test_triu_dynamic_diagonal(self):
+        class Module(torch.nn.Module):
+            def forward(self, x, y):
+                x_len = x.shape[0]
+                y_len = y.shape[0]
+                mask = torch.ones(x_len, y_len, dtype=torch.bool, device=x.device)
+                mask = mask.triu(diagonal=y_len - x_len)
+                return mask
+
+        x = torch.randn(3, 4)
+        y = torch.randn(5, 4)
+        x_len = Dim("x_len", min=1, max=64)
+        y_len = Dim("y_len", min=1, max=64)
+        ep = export(
+            Module(),
+            (x, y),
+            dynamic_shapes={
+                "x": {0: x_len},
+                "y": {0: y_len},
+            },
+        )
+        eager_out = Module()(x, y)
+        exported_out = ep.module()(x, y)
+        self.assertEqual(eager_out, exported_out)
+        self.assertEqual(exported_out.shape, (3, 5))
+
+        x2 = torch.randn(4, 4)
+        y2 = torch.randn(7, 4)
+        eager_out2 = Module()(x2, y2)
+        exported_out2 = ep.module()(x2, y2)
+        self.assertEqual(eager_out2, exported_out2)
+        self.assertEqual(exported_out2.shape, (4, 7))
+
+        expected_mask = torch.ones(3, 5, dtype=torch.bool).triu(diagonal=2)
+        self.assertEqual(eager_out, expected_mask)
+

 @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
 class TestOneOffModelExportResult(TestCase):

View File

@@ -1073,8 +1073,8 @@
 - name: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U)
   LU_data: lu_unpack_backward(grad_L, grad_U, LU_data.sym_size(-2), LU_data.sym_size(-1))
   LU_pivots: non_differentiable
-  L: "LU_data_t.sym_size(-2) >= LU_data_t.sym_size(-1) ? LU_data_t.tril(-1) : LU_data_t.narrow_symint(-1, 0, LU_data_t.sym_size(-2)).tril(-1)"
-  U: "LU_data_t.sym_size(-1) >= LU_data_t.sym_size(-2) ? LU_data_t.triu() : LU_data_t.narrow_symint(-2, 0, LU_data_t.sym_size(-1)).triu()"
+  L: "LU_data_t.sym_size(-2) >= LU_data_t.sym_size(-1) ? LU_data_t.tril_symint(-1) : LU_data_t.narrow_symint(-1, 0, LU_data_t.sym_size(-2)).tril_symint(-1)"
+  U: "LU_data_t.sym_size(-1) >= LU_data_t.sym_size(-2) ? LU_data_t.triu_symint() : LU_data_t.narrow_symint(-2, 0, LU_data_t.sym_size(-1)).triu_symint()"
   output_differentiability: [False, True, True]

 - name: masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor

@@ -1782,12 +1782,12 @@
   self, B: linalg_solve_triangular_backward(grad, self, result, upper, left, unitriangular, grad_input_mask)
   result: linalg_solve_triangular_forward_AD(self_t, B_t, self_p, result, upper, left, unitriangular)

-- name: tril(Tensor self, int diagonal=0) -> Tensor
-  self: grad.tril(diagonal)
+- name: tril(Tensor self, SymInt diagonal=0) -> Tensor
+  self: grad.tril_symint(diagonal)
   result: auto_linear

-- name: triu(Tensor self, int diagonal=0) -> Tensor
-  self: grad.triu(diagonal)
+- name: triu(Tensor self, SymInt diagonal=0) -> Tensor
+  self: grad.triu_symint(diagonal)
   result: auto_linear

 - name: trunc(Tensor self) -> Tensor