Revert "server: Consolidate embedding truncation in runner (#12730)" (#12810)

This reverts commit 5d347f6d6f.
This commit is contained in:
Patrick Devine 2025-10-28 14:49:14 -07:00 committed by GitHub
parent 3d99d9779a
commit 29f63f37c8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 84 additions and 264 deletions

View File

@@ -4,9 +4,7 @@ package integration
import ( import (
"context" "context"
"errors"
"math" "math"
"strings"
"testing" "testing"
"time" "time"
@@ -301,216 +299,3 @@ func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req
return client.Embed(ctx, &req) return client.Embed(ctx, &req)
} }
// TestEmbedTruncation verifies embedding token accounting and truncation
// behavior: positive token counts for single and batched inputs, clamping
// of long inputs when truncate=true, an error when truncate=false, and
// that batch token counts line up with single-input counts.
func TestEmbedTruncation(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
	defer cancel()

	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	t.Run("single input token count", func(t *testing.T) {
		res, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
			Model: "all-minilm",
			Input: "why is the sky blue?",
		})
		if err != nil {
			t.Fatal(err)
		}
		if res.PromptEvalCount <= 0 {
			t.Fatalf("expected positive token count, got %d", res.PromptEvalCount)
		}
	})

	t.Run("batch parallel token counting", func(t *testing.T) {
		res, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
			Model: "all-minilm",
			Input: []string{"cat", "dog and mouse", "bird"},
		})
		if err != nil {
			t.Fatal(err)
		}
		if got := len(res.Embeddings); got != 3 {
			t.Fatalf("expected 3 embeddings, got %d", got)
		}
		if res.PromptEvalCount <= 0 {
			t.Fatalf("expected positive token count, got %d", res.PromptEvalCount)
		}
	})

	t.Run("truncation single input", func(t *testing.T) {
		truncate := true
		res, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
			Model:    "all-minilm",
			Input:    strings.Repeat("word ", 100),
			Truncate: &truncate,
			Options:  map[string]any{"num_ctx": 50},
		})
		if err != nil {
			t.Fatal(err)
		}
		// Truncation must clamp the input to num_ctx without dropping it
		// entirely.
		switch {
		case res.PromptEvalCount > 50:
			t.Fatalf("expected tokens <= 50 after truncation, got %d", res.PromptEvalCount)
		case res.PromptEvalCount == 0:
			t.Fatal("expected non-zero token count after truncation")
		}
	})

	t.Run("truncation batch", func(t *testing.T) {
		truncate := true
		res, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
			Model:    "all-minilm",
			Input:    []string{"short", strings.Repeat("long ", 100), "medium text"},
			Truncate: &truncate,
			Options:  map[string]any{"num_ctx": 30},
		})
		if err != nil {
			t.Fatal(err)
		}
		if got := len(res.Embeddings); got != 3 {
			t.Fatalf("expected 3 embeddings, got %d", got)
		}
		// Each of the three inputs may use at most num_ctx (30) tokens.
		if res.PromptEvalCount > 90 {
			t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount)
		}
	})

	t.Run("truncate false error", func(t *testing.T) {
		truncate := false
		_, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
			Model:    "all-minilm",
			Input:    strings.Repeat("word ", 100),
			Truncate: &truncate,
			Options:  map[string]any{"num_ctx": 10},
		})
		if err == nil {
			t.Fatal("expected error when truncate=false with long input")
		}
		if !strings.Contains(err.Error(), "exceeds maximum context length") {
			t.Fatalf("expected context length error, got: %v", err)
		}
	})

	t.Run("runner token count accuracy", func(t *testing.T) {
		baseRes, err := embedTestHelper(ctx, client, t, api.EmbedRequest{Model: "all-minilm", Input: "test"})
		if err != nil {
			t.Fatal(err)
		}
		batchRes, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
			Model: "all-minilm",
			Input: []string{"test", "test", "test"},
		})
		if err != nil {
			t.Fatal(err)
		}
		// The batch count should be roughly three times the single-input
		// count; allow a small tolerance for special tokens.
		expectedCount := baseRes.PromptEvalCount * 3
		if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 {
			t.Fatalf("expected ~%d tokens (3 × %d), got %d",
				expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount)
		}
	})
}
// TestEmbedStatusCode verifies that errors from the embedding endpoint keep
// their HTTP status codes on the way back to the client: the EmbedHandler
// error path must surface api.StatusError values with their original status
// intact rather than collapsing them into a 500 Internal Server Error.
func TestEmbedStatusCode(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
	defer cancel()

	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	// Make sure the model is available before exercising the error paths.
	if err := PullIfMissing(ctx, client, "all-minilm"); err != nil {
		t.Fatal(err)
	}

	// requireClientError asserts that err is an api.StatusError carrying a
	// 4xx client status code (likely 400 Bad Request), not a 500.
	requireClientError := func(t *testing.T, err error) {
		t.Helper()
		var statusErr api.StatusError
		if !errors.As(err, &statusErr) {
			t.Fatalf("expected api.StatusError, got %T: %v", err, err)
		}
		if code := statusErr.StatusCode; code < 400 || code >= 500 {
			t.Errorf("expected 4xx status code, got %d", code)
		}
	}

	t.Run("truncation error status code", func(t *testing.T) {
		truncate := false
		_, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
			Model:    "all-minilm",
			Input:    strings.Repeat("word ", 100),
			Truncate: &truncate,
			Options:  map[string]any{"num_ctx": 10},
		})
		if err == nil {
			t.Fatal("expected error when truncate=false with long input")
		}
		requireClientError(t, err)
		// The message should still explain what went wrong.
		if !strings.Contains(err.Error(), "context length") {
			t.Errorf("expected error message to mention context length, got: %v", err)
		}
	})

	t.Run("batch truncation error status code", func(t *testing.T) {
		truncate := false
		_, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
			Model: "all-minilm",
			Input: []string{
				"short input",
				strings.Repeat("very long input ", 100),
				"another short input",
			},
			Truncate: &truncate,
			Options:  map[string]any{"num_ctx": 10},
		})
		if err == nil {
			t.Fatal("expected error when one input exceeds context with truncate=false")
		}
		requireClientError(t, err)
	})
}

View File

@@ -69,7 +69,7 @@ type LlamaServer interface {
Ping(ctx context.Context) error Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error) Embedding(ctx context.Context, input string) ([]float32, error)
Tokenize(ctx context.Context, content string) ([]int, error) Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error) Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error Close() error
@@ -1545,16 +1545,14 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
} }
type EmbeddingRequest struct { type EmbeddingRequest struct {
Content string `json:"content"` Content string `json:"content"`
Truncate bool `json:"truncate"`
} }
type EmbeddingResponse struct { type EmbeddingResponse struct {
Embedding []float32 `json:"embedding"` Embedding []float32 `json:"embedding"`
PromptEvalCount int `json:"prompt_eval_count"`
} }
func (s *llmServer) Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error) { func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) {
logutil.Trace("embedding request", "input", input) logutil.Trace("embedding request", "input", input)
if err := s.sem.Acquire(ctx, 1); err != nil { if err := s.sem.Acquire(ctx, 1); err != nil {
@@ -1563,54 +1561,51 @@ func (s *llmServer) Embedding(ctx context.Context, input string, truncate bool)
} else { } else {
slog.Error("Failed to acquire semaphore", "error", err) slog.Error("Failed to acquire semaphore", "error", err)
} }
return nil, 0, err return nil, err
} }
defer s.sem.Release(1) defer s.sem.Release(1)
// Make sure the server is ready // Make sure the server is ready
status, err := s.getServerStatusRetry(ctx) status, err := s.getServerStatusRetry(ctx)
if err != nil { if err != nil {
return nil, 0, err return nil, err
} else if status != ServerStatusReady { } else if status != ServerStatusReady {
return nil, 0, fmt.Errorf("unexpected server status: %s", status) return nil, fmt.Errorf("unexpected server status: %s", status)
} }
data, err := json.Marshal(EmbeddingRequest{Content: input, Truncate: truncate}) data, err := json.Marshal(EmbeddingRequest{Content: input})
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("error marshaling embed data: %w", err) return nil, fmt.Errorf("error marshaling embed data: %w", err)
} }
r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data)) r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("error creating embed request: %w", err) return nil, fmt.Errorf("error creating embed request: %w", err)
} }
r.Header.Set("Content-Type", "application/json") r.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(r) resp, err := http.DefaultClient.Do(r)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("do embedding request: %w", err) return nil, fmt.Errorf("do embedding request: %w", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("error reading embed response: %w", err) return nil, fmt.Errorf("error reading embed response: %w", err)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
log.Printf("llm embedding error: %s", body) log.Printf("llm embedding error: %s", body)
return nil, 0, api.StatusError{ return nil, fmt.Errorf("%s", body)
StatusCode: resp.StatusCode,
ErrorMessage: string(body),
}
} }
var e EmbeddingResponse var e EmbeddingResponse
if err := json.Unmarshal(body, &e); err != nil { if err := json.Unmarshal(body, &e); err != nil {
return nil, 0, fmt.Errorf("unmarshal tokenize response: %w", err) return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
} }
return e.Embedding, e.PromptEvalCount, nil return e.Embedding, nil
} }
type TokenizeRequest struct { type TokenizeRequest struct {

View File

@@ -709,13 +709,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{ seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{
embedding: true, embedding: true,
truncate: req.Truncate,
// TODO (jmorganca): this should be provided by the server via the
// request options and truncated here in the runner, instead of relying on
// the server's truncate logic
truncate: true,
}) })
if err != nil { if err != nil {
if errors.Is(err, errorInputTooLong) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
return return
} }
@@ -758,8 +758,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
embedding := <-seq.embedding embedding := <-seq.embedding
if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{ if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
Embedding: embedding, Embedding: embedding,
PromptEvalCount: seq.numPromptInputs,
}); err != nil { }); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
} }

View File

@@ -948,13 +948,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{ seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{
embedding: true, embedding: true,
truncate: req.Truncate,
// TODO (jmorganca): this should be provided by the server via the
// request options and truncated here in the runner, instead of relying on
// the server's truncate logic
truncate: true,
}) })
if err != nil { if err != nil {
if errors.Is(err, errorInputTooLong) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError)
return return
} }
@@ -995,8 +995,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
} }
if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{ if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
Embedding: <-seq.embedding, Embedding: <-seq.embedding,
PromptEvalCount: seq.numPromptInputs,
}); err != nil { }); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
} }

View File

@@ -21,7 +21,6 @@ import (
"os/signal" "os/signal"
"slices" "slices"
"strings" "strings"
"sync/atomic"
"syscall" "syscall"
"time" "time"
@@ -660,7 +659,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return return
} }
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive) r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
if err != nil { if err != nil {
handleScheduleError(c, req.Model, err) handleScheduleError(c, req.Model, err)
return return
@@ -673,12 +672,61 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return return
} }
kvData, _, err := getModelData(m.ModelPath, false)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var count int
for i, s := range input {
tokens, err := r.Tokenize(c.Request.Context(), s)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
if len(tokens) > ctxLen {
if !truncate {
c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"})
return
}
if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
ctxLen--
}
if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
ctxLen--
}
slog.Info("", "ctxLen", ctxLen, "tokenCount", len(tokens))
if ctxLen <= 0 {
// return error if the truncated input would be empty or just special tokens
c.JSON(http.StatusBadRequest, gin.H{"error": "input after truncation exceeds maximum context length"})
return
}
tokens = tokens[:ctxLen]
s, err = r.Detokenize(c.Request.Context(), tokens)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
count += len(tokens)
input[i] = s
}
var g errgroup.Group var g errgroup.Group
embeddings := make([][]float32, len(input)) embeddings := make([][]float32, len(input))
var totalTokens uint64
for i, text := range input { for i, text := range input {
g.Go(func() error { g.Go(func() error {
embedding, tokenCount, err := r.Embedding(c.Request.Context(), text, truncate) embedding, err := r.Embedding(c.Request.Context(), text)
if err != nil { if err != nil {
return err return err
} }
@@ -688,18 +736,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
embedding = normalize(embedding[:req.Dimensions]) embedding = normalize(embedding[:req.Dimensions])
} }
embeddings[i] = embedding embeddings[i] = embedding
atomic.AddUint64(&totalTokens, uint64(tokenCount))
return nil return nil
}) })
} }
if err := g.Wait(); err != nil { if err := g.Wait(); err != nil {
var serr api.StatusError c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
if errors.As(err, &serr) {
c.AbortWithStatusJSON(serr.StatusCode, gin.H{"error": strings.TrimSpace(serr.ErrorMessage)})
} else {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
}
return return
} }
@@ -708,7 +750,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
Embeddings: embeddings, Embeddings: embeddings,
TotalDuration: time.Since(checkpointStart), TotalDuration: time.Since(checkpointStart),
LoadDuration: checkpointLoaded.Sub(checkpointStart), LoadDuration: checkpointLoaded.Sub(checkpointStart),
PromptEvalCount: int(totalTokens), PromptEvalCount: count,
} }
c.JSON(http.StatusOK, resp) c.JSON(http.StatusOK, resp)
} }
@@ -754,7 +796,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return return
} }
embedding, _, err := r.Embedding(c.Request.Context(), req.Prompt, true) embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
if err != nil { if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())}) c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
return return

View File

@@ -780,8 +780,8 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn
return s.completionResp return s.completionResp
} }
func (s *mockLlm) Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error) { func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, error) {
return s.embeddingResp, 0, s.embeddingRespErr return s.embeddingResp, s.embeddingRespErr
} }
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) { func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {