Fix Tril Triu SymInt (#166627)

Fixes #165613

### Summary
- This PR fixes an issue where `torch.tril` and `torch.triu` with dynamic diagonal values caused `torch.export` to incorrectly infer unnecessary constraints between dynamic dimensions.
- Ensures proper `SymInt` type annotations for the `diagonal` parameter.
- Updates the C++ implementation to correctly handle `SymInt` diagonal values.

### Impacts
module: dynamic shapes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166627
Approved by: https://github.com/ezyang, https://github.com/Skylion007
This commit is contained in:

parent dfebdcab86
commit 9970fb97ff
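For context, a minimal repro sketch of the failure mode this commit fixes, mirroring the tests added below (the `Dim` bounds are illustrative): the diagonal is computed from two independent dynamic dimensions, and before this change tracing it as a plain `int` could specialize `y_len - x_len` and tie `x_len` to `y_len`.

```python
# Minimal repro sketch (mirrors the tests added in this commit).
import torch
from torch.export import Dim, export


class Module(torch.nn.Module):
    def forward(self, x, y):
        x_len = x.shape[0]
        y_len = y.shape[0]
        # Diagonal depends on two dynamic dims; with SymInt support the
        # exported program keeps them independent instead of relating them.
        mask = torch.ones(x_len, y_len, dtype=torch.bool, device=x.device)
        return mask.tril(diagonal=y_len - x_len)


ep = export(
    Module(),
    (torch.randn(3, 4), torch.randn(5, 4)),
    dynamic_shapes={
        "x": {0: Dim("x_len", min=1, max=64)},
        "y": {0: Dim("y_len", min=1, max=64)},
    },
)
# With the fix, the exported program runs at other sizes without re-export.
print(ep.module()(torch.randn(4, 4), torch.randn(7, 4)).shape)  # (4, 7)
```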
```diff
@@ -534,20 +534,20 @@ Tensor trace_decomp(const Tensor& tensor) {
 std::tuple<Tensor, std::optional<int64_t>> tril_batch_rule(
     const Tensor& self,
     std::optional<int64_t> self_bdim,
-    int64_t diagonal = 0) {
+    c10::SymInt diagonal = 0) {
   TORCH_CHECK(self.dim() >= 2, "tril: The input tensor must have at least 2 dimensions.");
   auto self_ = moveBatchDimToFront(self, self_bdim);
-  auto result = at::tril(self_, diagonal);
+  auto result = at::tril_symint(self_, std::move(diagonal));
   return std::make_tuple(std::move(result), 0);
 }
 
 std::tuple<Tensor, std::optional<int64_t>> triu_batch_rule(
     const Tensor& self,
     std::optional<int64_t> self_bdim,
-    int64_t diagonal = 0) {
+    c10::SymInt diagonal = 0) {
   TORCH_CHECK(self.dim() >= 2, "triu: The input tensor must have at least 2 dimensions.");
   auto self_ = moveBatchDimToFront(self, self_bdim);
-  auto result = at::triu(self_, diagonal);
+  auto result = at::triu_symint(self_, std::move(diagonal));
   return std::make_tuple(std::move(result), 0);
 }
 
```
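These batch rules are what `torch.vmap` dispatches to; switching `diagonal` to `c10::SymInt` keeps the vmapped path consistent with the new schema. A small eager illustration of the path being exercised (not part of the diff):

```python
import torch

# vmap over the leading dim routes through tril_batch_rule: the batch
# dimension is moved to the front and tril is applied once to the batch.
batch = torch.randn(4, 3, 5)
out = torch.vmap(lambda t: torch.tril(t, diagonal=1))(batch)
# tril on a 3-D tensor masks the last two dims, so the results agree.
assert torch.equal(out, torch.tril(batch, diagonal=1))
```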
```diff
@@ -8865,11 +8865,11 @@
   autogen: bitwise_right_shift.Scalar_Tensor_out
   tags: pointwise
 
-- func: tril_(Tensor(a!) self, int diagonal=0) -> Tensor(a!)
+- func: tril_(Tensor(a!) self, SymInt diagonal=0) -> Tensor(a!)
   structured_delegate: tril.out
   variants: method
 
-- func: triu_(Tensor(a!) self, int diagonal=0) -> Tensor(a!)
+- func: triu_(Tensor(a!) self, SymInt diagonal=0) -> Tensor(a!)
   structured_delegate: triu.out
   variants: method
 
```
```diff
@@ -8993,25 +8993,25 @@
 - func: cross(Tensor self, Tensor other, int? dim=None) -> Tensor
   variants: method, function
 
-- func: triu.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
+- func: triu.out(Tensor self, SymInt diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
   structured: True
   dispatch:
     CPU: triu_cpu
     CUDA: triu_cuda
     MPS: triu_mps_out
 
-- func: triu(Tensor self, int diagonal=0) -> Tensor
+- func: triu(Tensor self, SymInt diagonal=0) -> Tensor
   structured_delegate: triu.out
   variants: method, function
 
-- func: tril.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
+- func: tril.out(Tensor self, SymInt diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
   structured: True
   dispatch:
     CPU: tril_cpu
     CUDA: tril_cuda
     MPS: tril_mps_out
 
-- func: tril(Tensor self, int diagonal=0) -> Tensor
+- func: tril(Tensor self, SymInt diagonal=0) -> Tensor
   structured_delegate: tril.out
   variants: method, function
 
```
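The `int diagonal=0` to `SymInt diagonal=0` schema change in `native_functions.yaml` is visible from Python through the dispatcher. A quick sanity check (the exact printed form may vary slightly between releases):

```python
import torch

# After this commit, the overload schemas carry SymInt for the diagonal.
print(torch.ops.aten.tril.default._schema)
# tril(Tensor self, SymInt diagonal=0) -> Tensor
print(torch.ops.aten.triu.out._schema)
# triu.out(Tensor self, SymInt diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
```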
```diff
@@ -16745,6 +16745,74 @@ def forward(self, q, k, v):
 
         self.assertEqual(result_non_strict, result_strict)
 
+    def test_tril_dynamic_diagonal(self):
+        class Module(torch.nn.Module):
+            def forward(self, x, y):
+                x_len = x.shape[0]
+                y_len = y.shape[0]
+                mask = torch.ones(x_len, y_len, dtype=torch.bool, device=x.device)
+                mask = mask.tril(diagonal=y_len - x_len)
+                return mask
+
+        x = torch.randn(3, 4)
+        y = torch.randn(5, 4)
+        x_len = Dim("x_len", min=1, max=64)
+        y_len = Dim("y_len", min=1, max=64)
+        ep = export(
+            Module(),
+            (x, y),
+            dynamic_shapes={
+                "x": {0: x_len},
+                "y": {0: y_len},
+            },
+        )
+        eager_out = Module()(x, y)
+        exported_out = ep.module()(x, y)
+        self.assertEqual(eager_out, exported_out)
+        self.assertEqual(exported_out.shape, (3, 5))
+        x2 = torch.randn(4, 4)
+        y2 = torch.randn(7, 4)
+        eager_out2 = Module()(x2, y2)
+        exported_out2 = ep.module()(x2, y2)
+        self.assertEqual(eager_out2, exported_out2)
+        self.assertEqual(exported_out2.shape, (4, 7))
+        expected_mask = torch.ones(3, 5, dtype=torch.bool).tril(diagonal=2)
+        self.assertEqual(eager_out, expected_mask)
+
+    def test_triu_dynamic_diagonal(self):
+        class Module(torch.nn.Module):
+            def forward(self, x, y):
+                x_len = x.shape[0]
+                y_len = y.shape[0]
+                mask = torch.ones(x_len, y_len, dtype=torch.bool, device=x.device)
+                mask = mask.triu(diagonal=y_len - x_len)
+                return mask
+
+        x = torch.randn(3, 4)
+        y = torch.randn(5, 4)
+        x_len = Dim("x_len", min=1, max=64)
+        y_len = Dim("y_len", min=1, max=64)
+        ep = export(
+            Module(),
+            (x, y),
+            dynamic_shapes={
+                "x": {0: x_len},
+                "y": {0: y_len},
+            },
+        )
+        eager_out = Module()(x, y)
+        exported_out = ep.module()(x, y)
+        self.assertEqual(eager_out, exported_out)
+        self.assertEqual(exported_out.shape, (3, 5))
+        x2 = torch.randn(4, 4)
+        y2 = torch.randn(7, 4)
+        eager_out2 = Module()(x2, y2)
+        exported_out2 = ep.module()(x2, y2)
+        self.assertEqual(eager_out2, exported_out2)
+        self.assertEqual(exported_out2.shape, (4, 7))
+        expected_mask = torch.ones(3, 5, dtype=torch.bool).triu(diagonal=2)
+        self.assertEqual(eager_out, expected_mask)
+
 
 @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
 class TestOneOffModelExportResult(TestCase):
```
```diff
@@ -1073,8 +1073,8 @@
 - name: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U)
   LU_data: lu_unpack_backward(grad_L, grad_U, LU_data.sym_size(-2), LU_data.sym_size(-1))
   LU_pivots: non_differentiable
-  L: "LU_data_t.sym_size(-2) >= LU_data_t.sym_size(-1) ? LU_data_t.tril(-1) : LU_data_t.narrow_symint(-1, 0, LU_data_t.sym_size(-2)).tril(-1)"
-  U: "LU_data_t.sym_size(-1) >= LU_data_t.sym_size(-2) ? LU_data_t.triu() : LU_data_t.narrow_symint(-2, 0, LU_data_t.sym_size(-1)).triu()"
+  L: "LU_data_t.sym_size(-2) >= LU_data_t.sym_size(-1) ? LU_data_t.tril_symint(-1) : LU_data_t.narrow_symint(-1, 0, LU_data_t.sym_size(-2)).tril_symint(-1)"
+  U: "LU_data_t.sym_size(-1) >= LU_data_t.sym_size(-2) ? LU_data_t.triu_symint() : LU_data_t.narrow_symint(-2, 0, LU_data_t.sym_size(-1)).triu_symint()"
   output_differentiability: [False, True, True]
 
 - name: masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor
```
```diff
@@ -1782,12 +1782,12 @@
   self, B: linalg_solve_triangular_backward(grad, self, result, upper, left, unitriangular, grad_input_mask)
   result: linalg_solve_triangular_forward_AD(self_t, B_t, self_p, result, upper, left, unitriangular)
 
-- name: tril(Tensor self, int diagonal=0) -> Tensor
-  self: grad.tril(diagonal)
+- name: tril(Tensor self, SymInt diagonal=0) -> Tensor
+  self: grad.tril_symint(diagonal)
   result: auto_linear
 
-- name: triu(Tensor self, int diagonal=0) -> Tensor
-  self: grad.triu(diagonal)
+- name: triu(Tensor self, SymInt diagonal=0) -> Tensor
+  self: grad.triu_symint(diagonal)
   result: auto_linear
 
 - name: trunc(Tensor self) -> Tensor
```
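The derivatives entries above say the backward of `tril`/`triu` is the same masking applied to the incoming gradient (`auto_linear` forward, `grad.tril_symint(diagonal)` backward). A small eager sanity check of that identity (not part of the diff):

```python
import torch

x = torch.randn(3, 5, requires_grad=True)
torch.tril(x, diagonal=2).sum().backward()
# d(sum(tril(x, d)))/dx is the tril mask itself, i.e. grad.tril(diagonal)
# with grad = ones here.
assert torch.equal(x.grad, torch.tril(torch.ones(3, 5), diagonal=2))
```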