diff --git a/convert/convert_bert.go b/convert/convert_bert.go index a9f4b8a7..6b0d0030 100644 --- a/convert/convert_bert.go +++ b/convert/convert_bert.go @@ -28,6 +28,7 @@ type bertModel struct { LayerNormEPS float32 `json:"layer_norm_eps"` LayerNormEpsilon float32 `json:"layer_norm_epsilon"` NormEpsilon float32 `json:"norm_epsilon"` + normalizeEmbeddings bool PoolingType uint32 } @@ -54,9 +55,11 @@ func (p *bertModel) parseMore(fsys fs.FS) error { var pooling string for _, m := range modules { - if m.Type == "sentence_transformers.models.Pooling" { + switch m.Type { + case "sentence_transformers.models.Pooling": pooling = m.Path - break + case "sentence_transformers.models.Normalize": + p.normalizeEmbeddings = true } } @@ -90,6 +93,7 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV { kv["general.architecture"] = "bert" kv["bert.attention.causal"] = false kv["bert.pooling_type"] = p.PoolingType + kv["bert.normalize_embeddings"] = p.normalizeEmbeddings kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer) diff --git a/ml/backend.go b/ml/backend.go index 154a0f1b..ef756478 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -416,6 +416,7 @@ type Tensor interface { AddID(ctx Context, t2, ids Tensor) Tensor Softmax(ctx Context) Tensor + L2Norm(ctx Context, eps float32) Tensor LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor RMSNorm(ctx Context, weight Tensor, eps float32) Tensor Scale(ctx Context, s float64) Tensor diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 931386d5..d5e2e9c9 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -1205,6 +1205,13 @@ func (t *Tensor) AddID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor { } } +func (t *Tensor) L2Norm(ctx ml.Context, eps float32) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_l2_norm(ctx.(*Context).ctx, t.t, C.float(eps)), + } +} + func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor { tt := 
C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps)) if w != nil { diff --git a/ml/nn/pooling/pooling.go b/ml/nn/pooling/pooling.go new file mode 100644 index 00000000..f84690c4 --- /dev/null +++ b/ml/nn/pooling/pooling.go @@ -0,0 +1,36 @@
+package pooling
+
+import (
+	"github.com/ollama/ollama/ml"
+)
+
+// Type selects how Pooling reduces a tensor of hidden states.
+type Type uint32
+
+const (
+	// The typed pooling modes are numbered from zero in declaration order.
+	TypeNone Type = iota
+	TypeMean
+	TypeCLS
+	TypeLast
+	TypeRank
+
+	// TypeUnknown and TypeUnspecified are declared without the Type type, so
+	// they are untyped integer constants.
+	TypeUnknown     = 0xFFFFFFFE
+	TypeUnspecified = 0xFFFFFFFF
+)
+
+// Pooling applies the pooling mode poolingType to hiddenStates and returns the
+// resulting tensor. Only TypeNone, TypeMean and TypeCLS are implemented; every
+// other value panics with "not implemented".
+func Pooling(ctx ml.Context, hiddenStates ml.Tensor, poolingType Type) ml.Tensor {
+	if poolingType == TypeNone {
+		// no pooling: hand the input back untouched
+		return hiddenStates
+	}
+
+	if poolingType == TypeMean {
+		// swap the first two dimensions, take the mean, then swap them back
+		swapped := hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+		mean := swapped.Mean(ctx)
+		return mean.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+	}
+
+	if poolingType == TypeCLS {
+		// a view of Dim(0) elements starting at offset 0
+		// NOTE(review): presumably the hidden state of the first token — confirm
+		return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
+	}
+
+	// TypeLast, TypeRank and any unrecognised value
+	panic("not implemented")
+}
diff --git a/model/model.go b/model/model.go index 3a72f09a..efef71d8 100644 --- a/model/model.go +++ b/model/model.go @@ -24,7 +24,11 @@ import ( "github.com/ollama/ollama/model/input" ) -var ErrNoVisionModel = errors.New("this model is missing data required for image input") +var ( + ErrNoVisionModel = errors.New("this model is missing data required for image input") + ErrUnsupportedModel = errors.New("model not supported") + ErrUnsupportedTokenizer = errors.New("tokenizer not supported") +) // Model implements a specific model architecture, defining the forward pass and any model-specific configuration type Model interface { @@ -242,7 +246,7 @@ func setPointer(base Base, v reflect.Value, tags []Tag) { vv = vv.Elem() } - vv = vv.Elem() + vv = reflect.Indirect(vv) if v.IsNil() { vv = reflect.New(v.Type().Elem()).Elem() } diff --git a/model/models/bert/model.go b/model/models/bert/model.go new file mode 100644 index 00000000..fd1dbd77 --- /dev/null +++ b/model/models/bert/model.go @@ -0,0 +1,181 @@ +package 
bert
+
+import (
+	"cmp"
+	"math"
+
+	"github.com/ollama/ollama/fs"
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/ml/nn/pooling"
+	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
+)
+
+// Model is an encoder-only embedding model: token, token-type and position
+// embeddings followed by a stack of EncoderLayers, pooling and optional L2
+// normalization.
+type Model struct {
+	model.Base
+	model.TextProcessor
+
+	TokenEmbedding     *nn.Embedding `gguf:"token_embd"`
+	TypeEmbedding      *nn.Embedding `gguf:"token_types"`
+	PositionEmbedding  *nn.Embedding `gguf:"position_embd"`
+	TokenEmbeddingNorm *nn.LayerNorm `gguf:"token_embd_norm"`
+
+	Layers []EncoderLayer `gguf:"blk"`
+
+	Options
+}
+
+// Forward implements model.Model.
+//
+// It embeds batch.Inputs, runs every encoder layer, pools the result with the
+// configured pooling type and, when Options.normalize is set, applies L2Norm.
+// The returned error is always nil.
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+	hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
+	// A view of hiddenSize elements at offset 0 of the token-type weights.
+	// NOTE(review): presumably the embedding row for token type 0 — confirm.
+	hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize))
+	hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))))
+	hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps)
+
+	for _, layer := range m.Layers {
+		hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options)
+	}
+
+	hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType)
+	if m.normalize {
+		hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
+	}
+
+	return hiddenStates, nil
+}
+
+// EncoderLayer is one encoder block: attention and an MLP, each followed by a
+// residual add and a LayerNorm.
+type EncoderLayer struct {
+	*Attention
+	AttentionNorm *nn.LayerNorm `gguf:"attn_output_norm"`
+
+	*MLP
+	MLPNorm *nn.LayerNorm `gguf:"layer_output_norm"`
+}
+
+// Forward runs the attention and MLP sub-layers. In both sub-layers the input
+// is added back to the sub-layer output before the norm is applied.
+func (e *EncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
+	// Attention
+	residual := hiddenStates
+	hiddenStates = e.Attention.Forward(ctx, hiddenStates, opts)
+	hiddenStates = hiddenStates.Add(ctx, residual)
+	hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
+
+	// MLP
+	residual = hiddenStates
+	hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
+	hiddenStates = hiddenStates.Add(ctx, residual)
+	hiddenStates = e.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
+
+	return hiddenStates
+}
+
+// Attention holds the projection weights of one attention sub-layer.
+// QueryNorm and KeyNorm are optional; Forward skips them when they are nil.
+type Attention struct {
+	Query     *nn.Linear    `gguf:"attn_q"`
+	QueryNorm *nn.LayerNorm `gguf:"attn_q_norm"`
+
+	Key     *nn.Linear    `gguf:"attn_k"`
+	KeyNorm *nn.LayerNorm `gguf:"attn_k_norm"`
+
+	Value *nn.Linear `gguf:"attn_v"`
+
+	Output *nn.Linear `gguf:"attn_output"`
+}
+
+// Forward projects hiddenStates to query, key and value, reshapes them per
+// head, runs nn.Attention scaled by 1/sqrt(headDim) and applies the output
+// projection.
+func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
+	batchSize := hiddenStates.Dim(1)
+
+	query := a.Query.Forward(ctx, hiddenStates)
+	if a.QueryNorm != nil {
+		query = a.QueryNorm.Forward(ctx, query, opts.eps)
+	}
+	query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
+
+	key := a.Key.Forward(ctx, hiddenStates)
+	if a.KeyNorm != nil {
+		key = a.KeyNorm.Forward(ctx, key, opts.eps)
+	}
+	// numKVHeads falls back to numHeads when it is zero
+	key = key.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
+
+	value := a.Value.Forward(ctx, hiddenStates)
+	value = value.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
+
+	// NOTE(review): the trailing nil is presumably the KV cache — confirm
+	attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
+	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
+	return a.Output.Forward(ctx, attention)
+}
+
+// MLP is the feed-forward sub-layer: Up projection, GELU, Down projection.
+type MLP struct {
+	Up   *nn.Linear `gguf:"ffn_up"`
+	Down *nn.Linear `gguf:"ffn_down"`
+}
+
+// Forward computes Down(GELU(Up(hiddenStates))). opts is not used.
+func (m *MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
+	return m.Down.Forward(ctx, m.Up.Forward(ctx, hiddenStates).GELU(ctx))
+}
+
+// Options carries the hyperparameters shared by every layer of the model.
+type Options struct {
+	hiddenSize,
+	numHeads,
+	numKVHeads,
+	keyLength,
+	valueLength int
+	poolingType pooling.Type
+	eps         float32
+	normalize   bool
+}
+
+// headDim returns the per-head dimension: keyLength if set, else valueLength
+// if set, else hiddenSize/numHeads.
+//
+// The division only runs when it is actually needed and numHeads is non-zero.
+// Function arguments are evaluated eagerly, so passing the quotient straight
+// to cmp.Or would divide by zero for numHeads == 0 even when an explicit
+// length is configured. With nothing to derive a size from, headDim returns 0.
+func (o Options) headDim() int {
+	if n := cmp.Or(o.keyLength, o.valueLength); n != 0 {
+		return n
+	}
+
+	if o.numHeads == 0 {
+		return 0
+	}
+
+	return o.hiddenSize / o.numHeads
+}
+
+// New builds a Model from the configuration c.
+//
+// It returns model.ErrUnsupportedTokenizer when "tokenizer.ggml.model" is not
+// "bert" and model.ErrUnsupportedModel when "attention.head_count" is zero.
+func New(c fs.Config) (model.Model, error) {
+	var processor model.TextProcessor
+	switch c.String("tokenizer.ggml.model", "bert") {
+	case "bert":
+		processor = model.NewWordPiece(
+			&model.Vocabulary{
+				Values: c.Strings("tokenizer.ggml.tokens"),
+				Scores: c.Floats("tokenizer.ggml.scores"),
+				Types:  c.Ints("tokenizer.ggml.token_type"),
+				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
+				BOS: []int32{
+					int32(cmp.Or(
+						c.Uint("tokenizer.ggml.cls_token_id"),
+						c.Uint("tokenizer.ggml.bos_token_id"),
+					)),
+				},
+				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
+				EOS: []int32{
+					int32(cmp.Or(
+						c.Uint("tokenizer.ggml.separator_token_id"),
+						//nolint:misspell
+						// NOTE: "seperator_token_id" is a typo in model metadata but we need to
+						// support it for compatibility.
+						c.Uint("tokenizer.ggml.seperator_token_id"),
+						c.Uint("tokenizer.ggml.eos_token_id"),
+					)),
+				},
+			},
+		)
+	default:
+		return nil, model.ErrUnsupportedTokenizer
+	}
+
+	// keyLength and valueLength are not populated below, so the head dimension
+	// is always hiddenSize/numHeads; reject a zero head count here instead of
+	// failing inside the first forward pass.
+	numHeads := int(c.Uint("attention.head_count"))
+	if numHeads == 0 {
+		return nil, model.ErrUnsupportedModel
+	}
+
+	return &Model{
+		TextProcessor: processor,
+		Layers:        make([]EncoderLayer, c.Uint("block_count")),
+		Options: Options{
+			hiddenSize:  int(c.Uint("embedding_length")),
+			numHeads:    numHeads,
+			numKVHeads:  int(c.Uint("attention.head_count_kv")),
+			eps:         c.Float("attention.layer_norm_epsilon"),
+			poolingType: pooling.Type(c.Uint("pooling_type")),
+			normalize:   c.Bool("normalize_embeddings", true),
+		},
+	}, nil
+}
+
+// init registers New for the "bert" and "bert_embed" architectures.
+func init() {
+	model.Register("bert", New)
+	model.Register("bert_embed", New)
+}
diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index 84c89e1f..8ccb9f92 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -24,7 +24,7 @@ type Options struct { type Model struct { model.Base - model.SentencePieceModel + model.SentencePiece TokenEmbedding *nn.Embedding `gguf:"token_embd"` Layers []Layer `gguf:"blk"` @@ -40,7 +40,7 @@ const ( func New(c fs.Config) (model.Model, error) { m := Model{ - SentencePieceModel: model.NewSentencePieceModel( + SentencePiece: model.NewSentencePiece( &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), diff --git a/model/models/gemma3/embed.go b/model/models/gemma3/embed.go index 
7d1e269f..395a0344 100644 --- a/model/models/gemma3/embed.go +++ b/model/models/gemma3/embed.go @@ -1,48 +1,38 @@ package gemma3 import ( - "errors" - "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/pooling" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" ) type embedModel struct { model.Base - model.SentencePieceModel + model.SentencePiece *TextModel - PoolingType uint32 + poolingType pooling.Type Dense [2]*nn.Linear `gguf:"dense"` } func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache) - - switch m.PoolingType { - case 0: // None - case 1: // Mean - hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx) - hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) - default: - return nil, errors.New("unsupported pooling type") - } - + hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType) for _, dense := range m.Dense { hiddenStates = dense.Forward(ctx, hiddenStates) } - + hiddenStates = hiddenStates.L2Norm(ctx, 1e-12) return hiddenStates, nil } func newEmbedModel(c fs.Config) (model.Model, error) { m := &embedModel{ - SentencePieceModel: model.NewSentencePieceModel( + SentencePiece: model.NewSentencePiece( &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), @@ -60,7 +50,7 @@ func newEmbedModel(c fs.Config) (model.Model, error) { }, ), TextModel: newTextModel(c), - PoolingType: c.Uint("pooling_type", 0), + poolingType: pooling.Type(c.Uint("pooling_type", 0)), } m.Cache = kvcache.NewWrapperCache( diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index 5c92b6bf..27da889e 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -16,7 +16,7 @@ import ( type Model struct { model.Base - 
model.SentencePieceModel + model.SentencePiece *VisionModel `gguf:"v"` *TextModel @@ -55,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i func New(c fs.Config) (model.Model, error) { m := Model{ - SentencePieceModel: model.NewSentencePieceModel( + SentencePiece: model.NewSentencePiece( &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), diff --git a/model/models/gemma3n/model.go b/model/models/gemma3n/model.go index 6e83a972..e59e3193 100644 --- a/model/models/gemma3n/model.go +++ b/model/models/gemma3n/model.go @@ -10,7 +10,7 @@ import ( type Model struct { model.Base - model.SentencePieceModel + model.SentencePiece *TextModel } @@ -23,7 +23,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func New(c fs.Config) (model.Model, error) { m := Model{ TextModel: newTextModel(c), - SentencePieceModel: model.NewSentencePieceModel( + SentencePiece: model.NewSentencePiece( &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), diff --git a/model/models/models.go b/model/models/models.go index c880a472..cc998078 100644 --- a/model/models/models.go +++ b/model/models/models.go @@ -1,6 +1,7 @@ package models import ( + _ "github.com/ollama/ollama/model/models/bert" _ "github.com/ollama/ollama/model/models/gemma2" _ "github.com/ollama/ollama/model/models/gemma3" _ "github.com/ollama/ollama/model/models/gemma3n" diff --git a/model/sentencepiece.go b/model/sentencepiece.go index 827ce00d..db07beee 100644 --- a/model/sentencepiece.go +++ b/model/sentencepiece.go @@ -12,18 +12,18 @@ import ( const spmWhitespaceSep = "▁" -type SentencePieceModel struct { +type SentencePiece struct { maxTokenLen int vocab *Vocabulary } -var _ TextProcessor = (*SentencePieceModel)(nil) +var _ TextProcessor = (*SentencePiece)(nil) -func (spm SentencePieceModel) Vocabulary() *Vocabulary { +func (spm SentencePiece) 
Vocabulary() *Vocabulary { return spm.vocab } -func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel { +func NewSentencePiece(vocab *Vocabulary) SentencePiece { logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5]) counter := map[int]int{} @@ -42,17 +42,17 @@ func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel { "user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE], "max token len", maxTokenLen) - return SentencePieceModel{ + return SentencePiece{ maxTokenLen: maxTokenLen, vocab: vocab, } } -func (spm SentencePieceModel) Is(id int32, special Special) bool { +func (spm SentencePiece) Is(id int32, special Special) bool { return spm.vocab.Is(id, special) } -func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) { +func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) { fragments := []fragment{{value: s}} for _, special := range spm.vocab.SpecialVocabulary() { id := spm.vocab.Encode(special) @@ -218,7 +218,7 @@ func (q *queue) Pop() interface{} { return item } -func (spm SentencePieceModel) Decode(ids []int32) (string, error) { +func (spm SentencePiece) Decode(ids []int32) (string, error) { var sb strings.Builder for _, id := range ids { data := spm.vocab.Decode(id) diff --git a/model/sentencepiece_test.go b/model/sentencepiece_test.go index 50ac2678..8f4570c1 100644 --- a/model/sentencepiece_test.go +++ b/model/sentencepiece_test.go @@ -12,7 +12,7 @@ import ( "github.com/ollama/ollama/convert/sentencepiece" ) -func loadSentencePieceVocab(t *testing.T) SentencePieceModel { +func loadSentencePieceVocab(t *testing.T) SentencePiece { t.Helper() bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model")) @@ -45,7 +45,7 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel { } } - return NewSentencePieceModel(&v) + return 
NewSentencePiece(&v) } func TestSentencePieceEncode(t *testing.T) { @@ -115,7 +115,7 @@ func TestSentencePieceEncode(t *testing.T) { }) } -func TestSentencePieceModelDecodeByteTokens(t *testing.T) { +func TestSentencePieceDecodeByteTokens(t *testing.T) { vocab := &Vocabulary{ Values: []string{ "normal", @@ -134,7 +134,7 @@ func TestSentencePieceModelDecodeByteTokens(t *testing.T) { Scores: []float32{0, 0, 0, 0, 0}, } - spm := NewSentencePieceModel(vocab) + spm := NewSentencePiece(vocab) tests := []struct { name string diff --git a/model/wordpiece.go b/model/wordpiece.go new file mode 100644 index 00000000..e8d5e848 --- /dev/null +++ b/model/wordpiece.go @@ -0,0 +1,167 @@ +package model + +import ( + "fmt" + "iter" + "strings" + "unicode" + + "github.com/ollama/ollama/logutil" +) + +type WordPiece struct { + vocab *Vocabulary +} + +// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries. +// this differs from original word piece which uses "##" to indicate subwords. +const ggmlPrefix = "▁" + +var wordPieceReplacer = strings.NewReplacer( + " .", ".", + " ?", "?", + " !", "!", + " ,", ",", + " ' ", "'", + " n't", "n't", + " 'm", "'m", + " do not", " don't", + " 's", "'s", + " 've", "'ve", + " 're", "'re", +) + +// Decode implements TextProcessor. +func (wpm WordPiece) Decode(ids []int32) (string, error) { + var sb strings.Builder + for i, id := range ids { + if id < 0 || int(id) >= len(wpm.vocab.Values) { + return "", fmt.Errorf("invalid token id: %d", id) + } + + var separator string + piece := wpm.vocab.Values[id] + if i > 0 && + (strings.HasPrefix(piece, ggmlPrefix) || + (strings.HasPrefix(piece, "[") && strings.HasSuffix(piece, "]"))) { + separator = " " + } + + sb.WriteString(wordPieceReplacer.Replace(separator + strings.TrimPrefix(piece, ggmlPrefix))) + } + + return sb.String(), nil +} + +// words splits a string into words, treating CJK characters as separate words. 
+// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models.
+//
+// The input is split on Unicode whitespace after padding every CJK ideograph
+// with spaces. Each field is then split around punctuation: every punctuation
+// rune is yielded as its own word and the text between punctuation runes is
+// yielded unchanged. Iteration stops as soon as yield returns false.
+func (wpm WordPiece) words(s string) iter.Seq[string] {
+	return func(yield func(string) bool) {
+		// len(s) bounds the output length: a rune that is padded to three
+		// runes is itself at least three bytes long, and every other rune
+		// maps to exactly one output rune.
+		runes := make([]rune, 0, len(s))
+		for _, r := range s {
+			switch {
+			case r >= 0x4E00 && r <= 0x9FFF,
+				r >= 0x3400 && r <= 0x4DBF,
+				r >= 0x20000 && r <= 0x2A6DF,
+				r >= 0x2A700 && r <= 0x2B73F,
+				r >= 0x2B740 && r <= 0x2B81F,
+				r >= 0x2B820 && r <= 0x2CEAF,
+				r >= 0xF900 && r <= 0xFAFF,
+				r >= 0x2F800 && r <= 0x2FA1F:
+				runes = append(runes, ' ', r, ' ')
+			default:
+				runes = append(runes, r)
+			}
+		}
+
+		for w := range strings.FieldsFuncSeq(string(runes), unicode.IsSpace) {
+			// split on but keep punctuation
+			var start int
+			for start < len(w) {
+				end := strings.IndexFunc(w[start:], unicode.IsPunct)
+				switch {
+				case end < 0:
+					// no punctuation left: the rest of the field is one word
+					end = len(w) - start
+				case end == 0:
+					// w[start:] begins with a punctuation rune. Take the whole
+					// rune: punctuation such as "。" or "—" is more than one
+					// byte long and must not be cut in the middle. Ranging
+					// over a string yields the byte offset of each rune, so
+					// the first non-zero offset is the width of the first one.
+					end = len(w) - start
+					for i := range w[start:] {
+						if i > 0 {
+							end = i
+							break
+						}
+					}
+				}
+
+				if !yield(w[start : start+end]) {
+					return
+				}
+
+				start += end
+			}
+		}
+	}
+}
+
+// Encode implements TextProcessor.
+//
+// Every word produced by words is matched greedily against the vocabulary:
+// the longest lowercased prefix that has an id is taken, then the longest
+// match of the remainder, and so on. Only the first piece of a word carries
+// ggmlPrefix. A word with any unmatched remainder is encoded as the single
+// "[UNK]" id. When addSpecial is set and at least one id was produced, the
+// vocabulary's special tokens are added. The returned error is always nil.
+func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
+	var ids []int32
+
+	// TODO: use [UNK] from config
+	unk := wpm.vocab.Encode("[UNK]")
+	for word := range wpm.words(s) {
+		var start int
+		var pieces []int32
+		for start < len(word) {
+			end := len(word)
+
+			var piece int32
+			for start < end {
+				subword := word[start:end]
+				if start == 0 {
+					subword = ggmlPrefix + subword
+				}
+
+				// TODO: some models might not want [ToLower]
+				piece = wpm.vocab.Encode(strings.ToLower(subword))
+				if piece >= 0 {
+					break
+				}
+
+				// Shrink the candidate by one whole rune: skip back over UTF-8
+				// continuation bytes (10xxxxxx) so the slice never ends inside
+				// a multi-byte sequence. A truncated sequence would be
+				// lowercased to U+FFFD and could spuriously match a vocabulary
+				// entry containing that rune.
+				for end--; end > start && word[end]&0xC0 == 0x80; end-- {
+				}
+			}
+
+			if piece < 0 {
+				// Unknown token
+				pieces = pieces[:0]
+				break
+			}
+
+			pieces = append(pieces, piece)
+			start = end
+		}
+
+		if len(pieces) > 0 {
+			ids = append(ids, pieces...)
+		} else {
+			ids = append(ids, unk)
+		}
+	}
+
+	if addSpecial && len(ids) > 0 {
+		ids = wpm.vocab.addSpecials(ids)
+	}
+
+	logutil.Trace("encoded", "string", s, "ids", ids)
+	return ids, nil
+}
+
+// Is implements TextProcessor.
+func (wpm WordPiece) Is(id int32, special Special) bool { + return wpm.vocab.Is(id, special) +} + +// Vocabulary implements TextProcessor. +func (wpm WordPiece) Vocabulary() *Vocabulary { + return wpm.vocab +} + +var _ TextProcessor = (*WordPiece)(nil) + +func NewWordPiece(vocab *Vocabulary) WordPiece { + return WordPiece{ + vocab: vocab, + } +} diff --git a/model/wordpiece_test.go b/model/wordpiece_test.go new file mode 100644 index 00000000..258fbffc --- /dev/null +++ b/model/wordpiece_test.go @@ -0,0 +1,51 @@ +package model + +import ( + "slices" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestWordPiece(t *testing.T) { + wpm := NewWordPiece( + &Vocabulary{ + Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"}, + AddBOS: true, + AddEOS: true, + BOS: []int32{1}, + EOS: []int32{2}, + }) + + ids, err := wpm.Encode("Hello world!", true) + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" { + t.Errorf("unexpected ids (-want +got):\n%s", diff) + } + + words, err := wpm.Decode(ids) + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" { + t.Errorf("unexpected words (-want +got):\n%s", diff) + } +} + +func TestWordPieceWords(t *testing.T) { + var wpm WordPiece + + basic := slices.Collect(wpm.words("Hey friend! 
How are you?!?")) + if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" { + t.Errorf("unexpected words (-want +got):\n%s", diff) + } + + chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika")) + if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" { + t.Errorf("unexpected words (-want +got):\n%s", diff) + } +} diff --git a/server/routes.go b/server/routes.go index 5114cb74..739ce69d 100644 --- a/server/routes.go +++ b/server/routes.go @@ -488,7 +488,6 @@ func (s *Server) EmbedHandler(c *gin.Context) { } truncate := true - if req.Truncate != nil && !*req.Truncate { truncate = false } @@ -555,7 +554,16 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } + if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) { + ctxLen-- + } + + if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) { + ctxLen-- + } + tokens = tokens[:ctxLen] + s, err = r.Detokenize(c.Request.Context(), tokens) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})