embed: cleanup (#12299)

* cleanup

* use pooling.TypeNone

* pooling test
This commit is contained in:
Michael Yang 2025-09-16 09:48:42 -07:00 committed by GitHub
parent a1cff89b30
commit c253433d68
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 104 additions and 19 deletions

View File

@ -11,26 +11,32 @@ const (
TypeMean
TypeCLS
TypeLast
TypeRank
TypeUnknown = 0xFFFFFFFE
TypeUnspecified = 0xFFFFFFFF
)
func Pooling(ctx ml.Context, hiddenStates ml.Tensor, poolingType Type) ml.Tensor {
switch poolingType {
case TypeNone:
return hiddenStates
func (t Type) String() string {
switch t {
case TypeMean:
return "Mean"
case TypeCLS:
return "CLS"
case TypeLast:
return "Last"
default:
return "Unknown"
}
}
func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
switch t {
case TypeMean:
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
case TypeCLS:
return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
case TypeLast:
panic("not implemented")
case TypeRank:
panic("not implemented")
hiddenStates = hiddenStates.View(ctx, (hiddenStates.Dim(1)-1)*hiddenStates.Stride(1), hiddenStates.Dim(0))
return hiddenStates
default:
panic("not implemented")
panic("unknown pooling type")
}
}

View File

@ -0,0 +1,79 @@
package pooling_test
import (
"bytes"
"os"
"slices"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/discover"
fsggml "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/backend/ggml"
"github.com/ollama/ollama/ml/nn/pooling"
)
func setup(tb testing.TB, n int) ml.Backend {
tb.Helper()
f, err := os.CreateTemp(tb.TempDir(), "*.bin")
if err != nil {
tb.Fatal(err)
}
defer f.Close()
if err := fsggml.WriteGGUF(f, fsggml.KV{
"general.architecture": "test",
"test.block_count": uint32(1),
}, []*fsggml.Tensor{
{Name: "blk.0.weight", Shape: []uint64{1}, WriterTo: bytes.NewBuffer(make([]byte, 4))},
}); err != nil {
tb.Fatal(err)
}
var gpuLayers ml.GPULayersList
if gpus := discover.GetGPUInfo(); len(gpus) > 0 {
gpuLayers = append(gpuLayers, ml.GPULayers{
ID: gpus[0].ID,
Layers: slices.Collect(func(yield func(int) bool) {
for i := range n {
if !yield(i) {
return
}
}
}),
})
}
b, err := ggml.New(f.Name(), ml.BackendParams{AllocMemory: true, GPULayers: gpuLayers})
if err != nil {
tb.Fatal(err)
}
return b
}
func TestForward(t *testing.T) {
cases := map[pooling.Type][]float32{
pooling.TypeMean: {4, 5, 6, 7, 8, 9, 10, 11},
pooling.TypeCLS: {0, 1, 2, 3, 4, 5, 6, 7},
pooling.TypeLast: {8, 9, 10, 11, 12, 13, 14, 15},
}
for typ, want := range cases {
t.Run(typ.String(), func(t *testing.T) {
b := setup(t, 99)
defer b.Close()
ctx := b.NewContext()
defer ctx.Close()
tt := ctx.Input().Arange(0, 16, 1, ml.DTypeF32).Reshape(ctx, 8, 2)
tt = typ.Forward(ctx, tt)
ctx.Forward(tt).Compute(tt)
if diff := cmp.Diff(want, tt.Floats()); diff != "" {
t.Error(diff)
}
})
}
}

View File

@ -5,7 +5,6 @@ import (
"fmt"
_ "image/jpeg"
_ "image/png"
"math"
"os"
"reflect"
"strconv"
@ -21,6 +20,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
_ "github.com/ollama/ollama/ml/backend"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model/input"
)
@ -108,7 +108,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
}
arch := b.Config().Architecture()
if b.Config().Uint("pooling_type", math.MaxUint32) != math.MaxUint32 {
if pooling.Type(b.Config().Uint("pooling_type")) != pooling.TypeNone {
arch = arch + "_embed"
}

View File

@ -37,7 +37,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options)
}
hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType)
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
if m.normalize {
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
}

View File

@ -22,7 +22,7 @@ type embedModel struct {
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType)
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
for _, dense := range m.Dense {
hiddenStates = dense.Forward(ctx, hiddenStates)
}

View File

@ -11,7 +11,6 @@ import (
"image"
"log"
"log/slog"
"math"
"net"
"net/http"
"os"
@ -32,6 +31,7 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/runner/common"
@ -405,7 +405,7 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
func (s *Server) run(ctx context.Context) {
s.ready.Wait()
supportsAsync := s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32
supportsAsync := pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone
var activeBatch batchState
for {
@ -900,7 +900,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
if s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 {
if pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone {
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
return
}