Revert "server: Consolidate embedding truncation in runner (#12730)" (#12810)

This reverts commit 5d347f6d6f.
This commit is contained in:
Patrick Devine 2025-10-28 14:49:14 -07:00 committed by GitHub
parent 3d99d9779a
commit 29f63f37c8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 84 additions and 264 deletions

View File

@ -4,9 +4,7 @@ package integration
import (
"context"
"errors"
"math"
"strings"
"testing"
"time"
@ -301,216 +299,3 @@ func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req
return client.Embed(ctx, &req)
}
func TestEmbedTruncation(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
t.Run("single input token count", func(t *testing.T) {
req := api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
}
res, err := embedTestHelper(ctx, client, t, req)
if err != nil {
t.Fatal(err)
}
if res.PromptEvalCount <= 0 {
t.Fatalf("expected positive token count, got %d", res.PromptEvalCount)
}
})
t.Run("batch parallel token counting", func(t *testing.T) {
req := api.EmbedRequest{
Model: "all-minilm",
Input: []string{"cat", "dog and mouse", "bird"},
}
res, err := embedTestHelper(ctx, client, t, req)
if err != nil {
t.Fatal(err)
}
if len(res.Embeddings) != 3 {
t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
}
if res.PromptEvalCount <= 0 {
t.Fatalf("expected positive token count, got %d", res.PromptEvalCount)
}
})
t.Run("truncation single input", func(t *testing.T) {
truncTrue := true
longInput := strings.Repeat("word ", 100)
req := api.EmbedRequest{
Model: "all-minilm",
Input: longInput,
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 50},
}
res, err := embedTestHelper(ctx, client, t, req)
if err != nil {
t.Fatal(err)
}
if res.PromptEvalCount > 50 {
t.Fatalf("expected tokens <= 50 after truncation, got %d", res.PromptEvalCount)
}
if res.PromptEvalCount == 0 {
t.Fatal("expected non-zero token count after truncation")
}
})
t.Run("truncation batch", func(t *testing.T) {
truncTrue := true
req := api.EmbedRequest{
Model: "all-minilm",
Input: []string{"short", strings.Repeat("long ", 100), "medium text"},
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 30},
}
res, err := embedTestHelper(ctx, client, t, req)
if err != nil {
t.Fatal(err)
}
if len(res.Embeddings) != 3 {
t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
}
if res.PromptEvalCount > 90 {
t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount)
}
})
t.Run("truncate false error", func(t *testing.T) {
truncFalse := false
req := api.EmbedRequest{
Model: "all-minilm",
Input: strings.Repeat("word ", 100),
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 10},
}
_, err := embedTestHelper(ctx, client, t, req)
if err == nil {
t.Fatal("expected error when truncate=false with long input")
}
if !strings.Contains(err.Error(), "exceeds maximum context length") {
t.Fatalf("expected context length error, got: %v", err)
}
})
t.Run("runner token count accuracy", func(t *testing.T) {
baseline := api.EmbedRequest{Model: "all-minilm", Input: "test"}
baseRes, err := embedTestHelper(ctx, client, t, baseline)
if err != nil {
t.Fatal(err)
}
batch := api.EmbedRequest{
Model: "all-minilm",
Input: []string{"test", "test", "test"},
}
batchRes, err := embedTestHelper(ctx, client, t, batch)
if err != nil {
t.Fatal(err)
}
expectedCount := baseRes.PromptEvalCount * 3
if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 {
t.Fatalf("expected ~%d tokens (3 × %d), got %d",
expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount)
}
})
}
// TestEmbedStatusCode tests that errors from the embedding endpoint
// properly preserve their HTTP status codes when returned to the client.
// This test specifically checks the error handling path in EmbedHandler
// where api.StatusError errors should maintain their original status code.
func TestEmbedStatusCode(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Pull the model if needed
if err := PullIfMissing(ctx, client, "all-minilm"); err != nil {
t.Fatal(err)
}
t.Run("truncation error status code", func(t *testing.T) {
truncFalse := false
longInput := strings.Repeat("word ", 100)
req := api.EmbedRequest{
Model: "all-minilm",
Input: longInput,
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 10},
}
_, err := embedTestHelper(ctx, client, t, req)
if err == nil {
t.Fatal("expected error when truncate=false with long input")
}
// Check that it's a StatusError with the correct status code
var statusErr api.StatusError
if !errors.As(err, &statusErr) {
t.Fatalf("expected api.StatusError, got %T: %v", err, err)
}
// The error should be a 4xx client error (likely 400 Bad Request)
// not a 500 Internal Server Error
if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
}
// Verify the error message is meaningful
if !strings.Contains(err.Error(), "context length") {
t.Errorf("expected error message to mention context length, got: %v", err)
}
})
t.Run("batch truncation error status code", func(t *testing.T) {
truncFalse := false
req := api.EmbedRequest{
Model: "all-minilm",
Input: []string{
"short input",
strings.Repeat("very long input ", 100),
"another short input",
},
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 10},
}
_, err := embedTestHelper(ctx, client, t, req)
if err == nil {
t.Fatal("expected error when one input exceeds context with truncate=false")
}
// Check that it's a StatusError with the correct status code
var statusErr api.StatusError
if !errors.As(err, &statusErr) {
t.Fatalf("expected api.StatusError, got %T: %v", err, err)
}
// The error should be a 4xx client error, not a 500 Internal Server Error
if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
}
})
}

View File

@ -69,7 +69,7 @@ type LlamaServer interface {
Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error)
Embedding(ctx context.Context, input string) ([]float32, error)
Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error
@ -1546,15 +1546,13 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
type EmbeddingRequest struct {
Content string `json:"content"`
Truncate bool `json:"truncate"`
}
type EmbeddingResponse struct {
Embedding []float32 `json:"embedding"`
PromptEvalCount int `json:"prompt_eval_count"`
}
func (s *llmServer) Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error) {
func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) {
logutil.Trace("embedding request", "input", input)
if err := s.sem.Acquire(ctx, 1); err != nil {
@ -1563,54 +1561,51 @@ func (s *llmServer) Embedding(ctx context.Context, input string, truncate bool)
} else {
slog.Error("Failed to acquire semaphore", "error", err)
}
return nil, 0, err
return nil, err
}
defer s.sem.Release(1)
// Make sure the server is ready
status, err := s.getServerStatusRetry(ctx)
if err != nil {
return nil, 0, err
return nil, err
} else if status != ServerStatusReady {
return nil, 0, fmt.Errorf("unexpected server status: %s", status)
return nil, fmt.Errorf("unexpected server status: %s", status)
}
data, err := json.Marshal(EmbeddingRequest{Content: input, Truncate: truncate})
data, err := json.Marshal(EmbeddingRequest{Content: input})
if err != nil {
return nil, 0, fmt.Errorf("error marshaling embed data: %w", err)
return nil, fmt.Errorf("error marshaling embed data: %w", err)
}
r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
if err != nil {
return nil, 0, fmt.Errorf("error creating embed request: %w", err)
return nil, fmt.Errorf("error creating embed request: %w", err)
}
r.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(r)
if err != nil {
return nil, 0, fmt.Errorf("do embedding request: %w", err)
return nil, fmt.Errorf("do embedding request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, 0, fmt.Errorf("error reading embed response: %w", err)
return nil, fmt.Errorf("error reading embed response: %w", err)
}
if resp.StatusCode >= 400 {
log.Printf("llm embedding error: %s", body)
return nil, 0, api.StatusError{
StatusCode: resp.StatusCode,
ErrorMessage: string(body),
}
return nil, fmt.Errorf("%s", body)
}
var e EmbeddingResponse
if err := json.Unmarshal(body, &e); err != nil {
return nil, 0, fmt.Errorf("unmarshal tokenize response: %w", err)
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
}
return e.Embedding, e.PromptEvalCount, nil
return e.Embedding, nil
}
type TokenizeRequest struct {

View File

@ -709,13 +709,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{
embedding: true,
truncate: req.Truncate,
// TODO (jmorganca): this should be provided by the server via the
// request options and truncated here in the runner, instead of relying on
// the server's truncate logic
truncate: true,
})
if err != nil {
if errors.Is(err, errorInputTooLong) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
return
}
@ -759,7 +759,6 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
Embedding: embedding,
PromptEvalCount: seq.numPromptInputs,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
}

View File

@ -948,13 +948,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{
embedding: true,
truncate: req.Truncate,
// TODO (jmorganca): this should be provided by the server via the
// request options and truncated here in the runner, instead of relying on
// the server's truncate logic
truncate: true,
})
if err != nil {
if errors.Is(err, errorInputTooLong) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError)
return
}
@ -996,7 +996,6 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
Embedding: <-seq.embedding,
PromptEvalCount: seq.numPromptInputs,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
}

View File

@ -21,7 +21,6 @@ import (
"os/signal"
"slices"
"strings"
"sync/atomic"
"syscall"
"time"
@ -660,7 +659,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
@ -673,12 +672,61 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
kvData, _, err := getModelData(m.ModelPath, false)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var count int
for i, s := range input {
tokens, err := r.Tokenize(c.Request.Context(), s)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
if len(tokens) > ctxLen {
if !truncate {
c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"})
return
}
if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
ctxLen--
}
if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
ctxLen--
}
slog.Info("", "ctxLen", ctxLen, "tokenCount", len(tokens))
if ctxLen <= 0 {
// return error if the truncated input would be empty or just special tokens
c.JSON(http.StatusBadRequest, gin.H{"error": "input after truncation exceeds maximum context length"})
return
}
tokens = tokens[:ctxLen]
s, err = r.Detokenize(c.Request.Context(), tokens)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
count += len(tokens)
input[i] = s
}
var g errgroup.Group
embeddings := make([][]float32, len(input))
var totalTokens uint64
for i, text := range input {
g.Go(func() error {
embedding, tokenCount, err := r.Embedding(c.Request.Context(), text, truncate)
embedding, err := r.Embedding(c.Request.Context(), text)
if err != nil {
return err
}
@ -688,18 +736,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
embedding = normalize(embedding[:req.Dimensions])
}
embeddings[i] = embedding
atomic.AddUint64(&totalTokens, uint64(tokenCount))
return nil
})
}
if err := g.Wait(); err != nil {
var serr api.StatusError
if errors.As(err, &serr) {
c.AbortWithStatusJSON(serr.StatusCode, gin.H{"error": strings.TrimSpace(serr.ErrorMessage)})
} else {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
}
return
}
@ -708,7 +750,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
Embeddings: embeddings,
TotalDuration: time.Since(checkpointStart),
LoadDuration: checkpointLoaded.Sub(checkpointStart),
PromptEvalCount: int(totalTokens),
PromptEvalCount: count,
}
c.JSON(http.StatusOK, resp)
}
@ -754,7 +796,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
embedding, _, err := r.Embedding(c.Request.Context(), req.Prompt, true)
embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
return

View File

@ -780,8 +780,8 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn
return s.completionResp
}
func (s *mockLlm) Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error) {
return s.embeddingResp, 0, s.embeddingRespErr
func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, error) {
return s.embeddingResp, s.embeddingRespErr
}
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {