Add nn.Bilinear param validation (#149018)

Fixes #103425

## Changes

- Add a note to the docs that the size parameters `must be > 0`
- Add validation for `in1_features` param

Currently, only `in1_features` causes a runtime error; adding checks for `in2_features` and `out_features` as well might be somewhat BC-breaking.

```python
import torch
from torch import nn

class lenet(nn.Module):
    """Minimal model reproducing the nn.Bilinear parameter-validation bug."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1)

        # This raises because `in1_features=0`; with `in1_features=1,
        # in2_features=0, out_features=0` no error is reported.
        self.linear = nn.Bilinear(in1_features=0, in2_features=0, out_features=0)

    def forward(self, x):
        # 1st block: convolution followed by the (misconfigured) bilinear layer.
        out = self.conv(x)
        out = self.linear(out)
        return out

if __name__ == "__main__":
    # Merely constructing the model triggers the Bilinear error.
    model = lenet()

```

## Test Result

```bash
pytest test/test_nn.py -k test_bilinear -vv
```

![image](https://github.com/user-attachments/assets/20617ba9-bac5-4db2-aecc-1831dbc8eb43)

![image](https://github.com/user-attachments/assets/401e4e1f-051a-4e1c-952b-48e85de64b0b)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149018
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
zeshengzong 2025-03-14 19:26:08 +00:00 committed by PyTorch MergeBot
parent 5a843f8973
commit a7f8de2198
2 changed files with 9 additions and 3 deletions

View File

@ -6837,6 +6837,10 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
expected = m(input1.view(6, 5), input2.view(6, 6)).view(2, 3, 8) expected = m(input1.view(6, 5), input2.view(6, 6)).view(2, 3, 8)
self.assertEqual(expected, m(input1, input2)) self.assertEqual(expected, m(input1, input2))
def test_bilinear_value_error(self):
with self.assertRaisesRegex(ValueError, "in1_features must be > 0"):
nn.Bilinear(0, 0, 0)
def test_fold_invalid_arg(self): def test_fold_invalid_arg(self):
# input.size(1) not divisible by \prod(kernel_size) # input.size(1) not divisible by \prod(kernel_size)

View File

@ -151,9 +151,9 @@ class Bilinear(Module):
r"""Applies a bilinear transformation to the incoming data: :math:`y = x_1^T A x_2 + b`. r"""Applies a bilinear transformation to the incoming data: :math:`y = x_1^T A x_2 + b`.
Args: Args:
in1_features: size of each first input sample in1_features: size of each first input sample, must be > 0
in2_features: size of each second input sample in2_features: size of each second input sample, must be > 0
out_features: size of each output sample out_features: size of each output sample, must be > 0
bias: If set to ``False``, the layer will not learn an additive bias. bias: If set to ``False``, the layer will not learn an additive bias.
Default: ``True`` Default: ``True``
@ -202,6 +202,8 @@ class Bilinear(Module):
) -> None: ) -> None:
factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {"device": device, "dtype": dtype}
super().__init__() super().__init__()
if in1_features <= 0:
raise ValueError(f"in1_features must be > 0, but got {in1_features}")
self.in1_features = in1_features self.in1_features = in1_features
self.in2_features = in2_features self.in2_features = in2_features
self.out_features = out_features self.out_features = out_features