mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add nn.Bilinear param validation (#149018)
Fixes #103425 ## Changes - Add doc description size value `must be > 0` - Add validation for `in1_features` param Currently, only `in1_features` will cause runtime error, if add checks for `in2_features` and `out_features` as well, might be kind of BC breaking. ```python import torch from torch import nn class lenet(nn.Module): def __init__(self): super(lenet, self).__init__() self.conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1) # Error, `in1_features=1, in2_features=0, out_features=0` no error self.linear = nn.Bilinear(in1_features=0, in2_features=0, out_features=0) def forward(self, x): # 1st block x = self.conv(x) x = self.linear(x) return x if __name__ == '__main__': net = lenet() ``` ## Test Result ```bash pytest test/test_nn.py -k test_bilinear -vv ```   Pull Request resolved: https://github.com/pytorch/pytorch/pull/149018 Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
parent
5a843f8973
commit
a7f8de2198
|
|
@ -6837,6 +6837,10 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
|
||||||
expected = m(input1.view(6, 5), input2.view(6, 6)).view(2, 3, 8)
|
expected = m(input1.view(6, 5), input2.view(6, 6)).view(2, 3, 8)
|
||||||
self.assertEqual(expected, m(input1, input2))
|
self.assertEqual(expected, m(input1, input2))
|
||||||
|
|
||||||
|
def test_bilinear_value_error(self):
|
||||||
|
with self.assertRaisesRegex(ValueError, "in1_features must be > 0"):
|
||||||
|
nn.Bilinear(0, 0, 0)
|
||||||
|
|
||||||
def test_fold_invalid_arg(self):
|
def test_fold_invalid_arg(self):
|
||||||
# input.size(1) not divisible by \prod(kernel_size)
|
# input.size(1) not divisible by \prod(kernel_size)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -151,9 +151,9 @@ class Bilinear(Module):
|
||||||
r"""Applies a bilinear transformation to the incoming data: :math:`y = x_1^T A x_2 + b`.
|
r"""Applies a bilinear transformation to the incoming data: :math:`y = x_1^T A x_2 + b`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
in1_features: size of each first input sample
|
in1_features: size of each first input sample, must be > 0
|
||||||
in2_features: size of each second input sample
|
in2_features: size of each second input sample, must be > 0
|
||||||
out_features: size of each output sample
|
out_features: size of each output sample, must be > 0
|
||||||
bias: If set to ``False``, the layer will not learn an additive bias.
|
bias: If set to ``False``, the layer will not learn an additive bias.
|
||||||
Default: ``True``
|
Default: ``True``
|
||||||
|
|
||||||
|
|
@ -202,6 +202,8 @@ class Bilinear(Module):
|
||||||
) -> None:
|
) -> None:
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
if in1_features <= 0:
|
||||||
|
raise ValueError(f"in1_features must be > 0, but got {in1_features}")
|
||||||
self.in1_features = in1_features
|
self.in1_features = in1_features
|
||||||
self.in2_features = in2_features
|
self.in2_features = in2_features
|
||||||
self.out_features = out_features
|
self.out_features = out_features
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user