diff --git a/ml/nn/pooling/pooling.go b/ml/nn/pooling/pooling.go index f84690c4..63b63b3a 100644 --- a/ml/nn/pooling/pooling.go +++ b/ml/nn/pooling/pooling.go @@ -11,26 +11,32 @@ const ( TypeMean TypeCLS TypeLast - TypeRank - - TypeUnknown = 0xFFFFFFFE - TypeUnspecified = 0xFFFFFFFF ) -func Pooling(ctx ml.Context, hiddenStates ml.Tensor, poolingType Type) ml.Tensor { - switch poolingType { - case TypeNone: - return hiddenStates +func (t Type) String() string { + switch t { + case TypeMean: + return "Mean" + case TypeCLS: + return "CLS" + case TypeLast: + return "Last" + default: + return "Unknown" + } +} + +func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor { + switch t { case TypeMean: hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx) return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) case TypeCLS: return hiddenStates.View(ctx, 0, hiddenStates.Dim(0)) case TypeLast: - panic("not implemented") - case TypeRank: - panic("not implemented") + hiddenStates = hiddenStates.View(ctx, (hiddenStates.Dim(1)-1)*hiddenStates.Stride(1), hiddenStates.Dim(0)) + return hiddenStates default: - panic("not implemented") + panic("unknown pooling type") } } diff --git a/ml/nn/pooling/pooling_test.go b/ml/nn/pooling/pooling_test.go new file mode 100644 index 00000000..c8001945 --- /dev/null +++ b/ml/nn/pooling/pooling_test.go @@ -0,0 +1,79 @@ +package pooling_test + +import ( + "bytes" + "os" + "slices" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ollama/ollama/discover" + fsggml "github.com/ollama/ollama/fs/ggml" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/backend/ggml" + "github.com/ollama/ollama/ml/nn/pooling" +) + +func setup(tb testing.TB, n int) ml.Backend { + tb.Helper() + + f, err := os.CreateTemp(tb.TempDir(), "*.bin") + if err != nil { + tb.Fatal(err) + } + defer f.Close() + + if err := fsggml.WriteGGUF(f, fsggml.KV{ + "general.architecture": "test", + "test.block_count": uint32(1), + }, []*fsggml.Tensor{ + {Name: "blk.0.weight", Shape: []uint64{1}, WriterTo: bytes.NewBuffer(make([]byte, 4))}, + }); err != nil { + tb.Fatal(err) + } + + var gpuLayers ml.GPULayersList + if gpus := discover.GetGPUInfo(); len(gpus) > 0 { + gpuLayers = append(gpuLayers, ml.GPULayers{ + ID: gpus[0].ID, + Layers: slices.Collect(func(yield func(int) bool) { + for i := range n { + if !yield(i) { + return + } + } + }), + }) + } + b, err := ggml.New(f.Name(), ml.BackendParams{AllocMemory: true, GPULayers: gpuLayers}) + if err != nil { + tb.Fatal(err) + } + + return b +} + +func TestForward(t *testing.T) { + cases := map[pooling.Type][]float32{ + pooling.TypeMean: {4, 5, 6, 7, 8, 9, 10, 11}, + pooling.TypeCLS: {0, 1, 2, 3, 4, 5, 6, 7}, + pooling.TypeLast: {8, 9, 10, 11, 12, 13, 14, 15}, + } + for typ, want := range cases { + t.Run(typ.String(), func(t *testing.T) { + b := setup(t, 99) + defer b.Close() + + ctx := b.NewContext() + defer ctx.Close() + + tt := ctx.Input().Arange(0, 16, 1, ml.DTypeF32).Reshape(ctx, 8, 2) + tt = typ.Forward(ctx, tt) + + ctx.Forward(tt).Compute(tt) + if diff := cmp.Diff(want, tt.Floats()); diff != "" { + t.Error(diff) + } + }) + } +} diff --git a/model/model.go b/model/model.go index efef71d8..5493a4e6 100644 --- a/model/model.go +++ b/model/model.go @@ -5,7 +5,6 @@ import ( "fmt" _ "image/jpeg" _ "image/png" - "math" "os" "reflect" "strconv" @@ -21,6 +20,7 @@ import ( "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" _ "github.com/ollama/ollama/ml/backend" + "github.com/ollama/ollama/ml/nn/pooling" "github.com/ollama/ollama/model/input" ) @@ -108,7 +108,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) { } arch := b.Config().Architecture() - if b.Config().Uint("pooling_type", math.MaxUint32) != math.MaxUint32 { + if pooling.Type(b.Config().Uint("pooling_type")) != pooling.TypeNone { arch = arch + "_embed" } diff --git a/model/models/bert/model.go b/model/models/bert/embed.go similarity index 98% rename from model/models/bert/model.go rename to model/models/bert/embed.go index fd1dbd77..166c11e1 100644 --- a/model/models/bert/model.go +++ b/model/models/bert/embed.go @@ -37,7 +37,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options) } - hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType) + hiddenStates = m.poolingType.Forward(ctx, hiddenStates) if m.normalize { hiddenStates = hiddenStates.L2Norm(ctx, 1e-12) } diff --git a/model/models/gemma3/embed.go b/model/models/gemma3/embed.go index 395a0344..52554776 100644 --- a/model/models/gemma3/embed.go +++ b/model/models/gemma3/embed.go @@ -22,7 +22,7 @@ type embedModel struct { func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache) - hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType) + hiddenStates = m.poolingType.Forward(ctx, hiddenStates) for _, dense := range m.Dense { hiddenStates = dense.Forward(ctx, hiddenStates) } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 3a32384f..480cfc19 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -11,7 +11,6 @@ import ( "image" "log" "log/slog" - "math" "net" "net/http" "os" @@ -32,6 +31,7 @@ import ( "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn/pooling" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" "github.com/ollama/ollama/runner/common" @@ -405,7 +405,7 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) { func (s *Server) run(ctx context.Context) { s.ready.Wait() - supportsAsync := s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 + supportsAsync := pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone var activeBatch batchState for { @@ -900,7 +900,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { } func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { - if s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 { + if pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone { http.Error(w, "this model does not support embeddings", http.StatusNotImplemented) return }