mirror of
https://github.com/zebrajr/ollama.git
synced 2025-12-06 00:19:51 +01:00
fix: conv2d bias (#12834)
This commit is contained in:
parent
93e45f0f0d
commit
0d140bd1af
|
|
@ -10,7 +10,8 @@ type Conv2D struct {
|
||||||
func (m *Conv2D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
func (m *Conv2D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
||||||
t = m.Weight.Conv2D(ctx, t, s0, s1, p0, p1, d0, d1)
|
t = m.Weight.Conv2D(ctx, t, s0, s1, p0, p1, d0, d1)
|
||||||
if m.Bias != nil {
|
if m.Bias != nil {
|
||||||
t = t.Add(ctx, m.Bias)
|
// Bias shape is (out_channels,) while t shape is (width, height, out_channels, batch)
|
||||||
|
t = t.Add(ctx, m.Bias.Reshape(ctx, 1, 1, -1))
|
||||||
}
|
}
|
||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user