This reverts commit 5d347f6d6f.

parent 3d99d9779a
commit 29f63f37c8
diff --git a/integration/embed_test.go b/integration/embed_test.go
@@ -4,9 +4,7 @@ package integration

 import (
 	"context"
-	"errors"
 	"math"
-	"strings"
 	"testing"
 	"time"
 )
@@ -301,216 +299,3 @@ func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req

 	return client.Embed(ctx, &req)
 }
-
-func TestEmbedTruncation(t *testing.T) {
-	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
-	defer cancel()
-	client, _, cleanup := InitServerConnection(ctx, t)
-	defer cleanup()
-
-	t.Run("single input token count", func(t *testing.T) {
-		req := api.EmbedRequest{
-			Model: "all-minilm",
-			Input: "why is the sky blue?",
-		}
-
-		res, err := embedTestHelper(ctx, client, t, req)
-		if err != nil {
-			t.Fatal(err)
-		}
-
-		if res.PromptEvalCount <= 0 {
-			t.Fatalf("expected positive token count, got %d", res.PromptEvalCount)
-		}
-	})
-
-	t.Run("batch parallel token counting", func(t *testing.T) {
-		req := api.EmbedRequest{
-			Model: "all-minilm",
-			Input: []string{"cat", "dog and mouse", "bird"},
-		}
-
-		res, err := embedTestHelper(ctx, client, t, req)
-		if err != nil {
-			t.Fatal(err)
-		}
-
-		if len(res.Embeddings) != 3 {
-			t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
-		}
-
-		if res.PromptEvalCount <= 0 {
-			t.Fatalf("expected positive token count, got %d", res.PromptEvalCount)
-		}
-	})
-
-	t.Run("truncation single input", func(t *testing.T) {
-		truncTrue := true
-		longInput := strings.Repeat("word ", 100)
-
-		req := api.EmbedRequest{
-			Model:    "all-minilm",
-			Input:    longInput,
-			Truncate: &truncTrue,
-			Options:  map[string]any{"num_ctx": 50},
-		}
-
-		res, err := embedTestHelper(ctx, client, t, req)
-		if err != nil {
-			t.Fatal(err)
-		}
-
-		if res.PromptEvalCount > 50 {
-			t.Fatalf("expected tokens <= 50 after truncation, got %d", res.PromptEvalCount)
-		}
-
-		if res.PromptEvalCount == 0 {
-			t.Fatal("expected non-zero token count after truncation")
-		}
-	})
-
-	t.Run("truncation batch", func(t *testing.T) {
-		truncTrue := true
-		req := api.EmbedRequest{
-			Model:    "all-minilm",
-			Input:    []string{"short", strings.Repeat("long ", 100), "medium text"},
-			Truncate: &truncTrue,
-			Options:  map[string]any{"num_ctx": 30},
-		}
-
-		res, err := embedTestHelper(ctx, client, t, req)
-		if err != nil {
-			t.Fatal(err)
-		}
-
-		if len(res.Embeddings) != 3 {
-			t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
-		}
-
-		if res.PromptEvalCount > 90 {
-			t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount)
-		}
-	})
-
-	t.Run("truncate false error", func(t *testing.T) {
-		truncFalse := false
-		req := api.EmbedRequest{
-			Model:    "all-minilm",
-			Input:    strings.Repeat("word ", 100),
-			Truncate: &truncFalse,
-			Options:  map[string]any{"num_ctx": 10},
-		}
-
-		_, err := embedTestHelper(ctx, client, t, req)
-		if err == nil {
-			t.Fatal("expected error when truncate=false with long input")
-		}
-
-		if !strings.Contains(err.Error(), "exceeds maximum context length") {
-			t.Fatalf("expected context length error, got: %v", err)
-		}
-	})
-
-	t.Run("runner token count accuracy", func(t *testing.T) {
-		baseline := api.EmbedRequest{Model: "all-minilm", Input: "test"}
-		baseRes, err := embedTestHelper(ctx, client, t, baseline)
-		if err != nil {
-			t.Fatal(err)
-		}
-
-		batch := api.EmbedRequest{
-			Model: "all-minilm",
-			Input: []string{"test", "test", "test"},
-		}
-		batchRes, err := embedTestHelper(ctx, client, t, batch)
-		if err != nil {
-			t.Fatal(err)
-		}
-
-		expectedCount := baseRes.PromptEvalCount * 3
-		if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 {
-			t.Fatalf("expected ~%d tokens (3 × %d), got %d",
-				expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount)
-		}
-	})
-}
-
-// TestEmbedStatusCode tests that errors from the embedding endpoint
-// properly preserve their HTTP status codes when returned to the client.
-// This test specifically checks the error handling path in EmbedHandler
-// where api.StatusError errors should maintain their original status code.
-func TestEmbedStatusCode(t *testing.T) {
-	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
-	defer cancel()
-	client, _, cleanup := InitServerConnection(ctx, t)
-	defer cleanup()
-
-	// Pull the model if needed
-	if err := PullIfMissing(ctx, client, "all-minilm"); err != nil {
-		t.Fatal(err)
-	}
-
-	t.Run("truncation error status code", func(t *testing.T) {
-		truncFalse := false
-		longInput := strings.Repeat("word ", 100)
-
-		req := api.EmbedRequest{
-			Model:    "all-minilm",
-			Input:    longInput,
-			Truncate: &truncFalse,
-			Options:  map[string]any{"num_ctx": 10},
-		}
-
-		_, err := embedTestHelper(ctx, client, t, req)
-		if err == nil {
-			t.Fatal("expected error when truncate=false with long input")
-		}
-
-		// Check that it's a StatusError with the correct status code
-		var statusErr api.StatusError
-		if !errors.As(err, &statusErr) {
-			t.Fatalf("expected api.StatusError, got %T: %v", err, err)
-		}
-
-		// The error should be a 4xx client error (likely 400 Bad Request)
-		// not a 500 Internal Server Error
-		if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
-			t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
-		}
-
-		// Verify the error message is meaningful
-		if !strings.Contains(err.Error(), "context length") {
-			t.Errorf("expected error message to mention context length, got: %v", err)
-		}
-	})
-
-	t.Run("batch truncation error status code", func(t *testing.T) {
-		truncFalse := false
-		req := api.EmbedRequest{
-			Model: "all-minilm",
-			Input: []string{
-				"short input",
-				strings.Repeat("very long input ", 100),
-				"another short input",
-			},
-			Truncate: &truncFalse,
-			Options:  map[string]any{"num_ctx": 10},
-		}
-
-		_, err := embedTestHelper(ctx, client, t, req)
-		if err == nil {
-			t.Fatal("expected error when one input exceeds context with truncate=false")
-		}
-
-		// Check that it's a StatusError with the correct status code
-		var statusErr api.StatusError
-		if !errors.As(err, &statusErr) {
-			t.Fatalf("expected api.StatusError, got %T: %v", err, err)
-		}
-
-		// The error should be a 4xx client error, not a 500 Internal Server Error
-		if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
-			t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
-		}
-	})
-}
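Note: the two deleted tests were the integration coverage for the request-level Truncate flag and the PromptEvalCount field of the embed endpoint. For reference, a minimal standalone sketch of the client call the tests wrapped through embedTestHelper; the model name and num_ctx value mirror the test fixtures and are otherwise arbitrary:

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	// ClientFromEnvironment honors OLLAMA_HOST, defaulting to 127.0.0.1:11434.
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	truncate := true
	res, err := client.Embed(context.Background(), &api.EmbedRequest{
		Model:    "all-minilm",
		Input:    "why is the sky blue?",
		Truncate: &truncate,                     // *bool; the server treats nil as true
		Options:  map[string]any{"num_ctx": 50}, // the context window the tests truncated against
	})
	if err != nil {
		log.Fatal(err)
	}

	fmt.Printf("embeddings: %d, tokens evaluated: %d\n", len(res.Embeddings), res.PromptEvalCount)
}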
diff --git a/llm/server.go b/llm/server.go
@@ -69,7 +69,7 @@ type LlamaServer interface {
 	Ping(ctx context.Context) error
 	WaitUntilRunning(ctx context.Context) error
 	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
-	Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error)
+	Embedding(ctx context.Context, input string) ([]float32, error)
 	Tokenize(ctx context.Context, content string) ([]int, error)
 	Detokenize(ctx context.Context, tokens []int) (string, error)
 	Close() error
@@ -1545,16 +1545,14 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 }

 type EmbeddingRequest struct {
-	Content  string `json:"content"`
-	Truncate bool   `json:"truncate"`
+	Content string `json:"content"`
 }

 type EmbeddingResponse struct {
-	Embedding       []float32 `json:"embedding"`
-	PromptEvalCount int       `json:"prompt_eval_count"`
+	Embedding []float32 `json:"embedding"`
 }

-func (s *llmServer) Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error) {
+func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) {
 	logutil.Trace("embedding request", "input", input)

 	if err := s.sem.Acquire(ctx, 1); err != nil {
@@ -1563,54 +1561,51 @@ func (s *llmServer) Embedding(ctx context.Context, input string, truncate bool)
 		} else {
 			slog.Error("Failed to acquire semaphore", "error", err)
 		}
-		return nil, 0, err
+		return nil, err
 	}
 	defer s.sem.Release(1)

 	// Make sure the server is ready
 	status, err := s.getServerStatusRetry(ctx)
 	if err != nil {
-		return nil, 0, err
+		return nil, err
 	} else if status != ServerStatusReady {
-		return nil, 0, fmt.Errorf("unexpected server status: %s", status)
+		return nil, fmt.Errorf("unexpected server status: %s", status)
 	}

-	data, err := json.Marshal(EmbeddingRequest{Content: input, Truncate: truncate})
+	data, err := json.Marshal(EmbeddingRequest{Content: input})
 	if err != nil {
-		return nil, 0, fmt.Errorf("error marshaling embed data: %w", err)
+		return nil, fmt.Errorf("error marshaling embed data: %w", err)
 	}

 	r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
 	if err != nil {
-		return nil, 0, fmt.Errorf("error creating embed request: %w", err)
+		return nil, fmt.Errorf("error creating embed request: %w", err)
 	}
 	r.Header.Set("Content-Type", "application/json")

 	resp, err := http.DefaultClient.Do(r)
 	if err != nil {
-		return nil, 0, fmt.Errorf("do embedding request: %w", err)
+		return nil, fmt.Errorf("do embedding request: %w", err)
 	}
 	defer resp.Body.Close()

 	body, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return nil, 0, fmt.Errorf("error reading embed response: %w", err)
+		return nil, fmt.Errorf("error reading embed response: %w", err)
 	}

 	if resp.StatusCode >= 400 {
 		log.Printf("llm embedding error: %s", body)
-		return nil, 0, api.StatusError{
-			StatusCode:   resp.StatusCode,
-			ErrorMessage: string(body),
-		}
+		return nil, fmt.Errorf("%s", body)
 	}

 	var e EmbeddingResponse
 	if err := json.Unmarshal(body, &e); err != nil {
-		return nil, 0, fmt.Errorf("unmarshal tokenize response: %w", err)
+		return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
 	}

-	return e.Embedding, e.PromptEvalCount, nil
+	return e.Embedding, nil
 }

 type TokenizeRequest struct {
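Note: the reverted Embedding method returned the runner-reported token count and wrapped non-2xx runner responses in api.StatusError; the restored version drops the count and flattens errors into a plain fmt.Errorf. A minimal sketch of the error shape the reverted code produced, using only the exported api package:

package main

import (
	"errors"
	"fmt"
	"net/http"

	"github.com/ollama/ollama/api"
)

// statusErrorFrom mirrors the reverted error path: a non-2xx body from the
// runner becomes an api.StatusError carrying the original code, rather than
// the opaque fmt.Errorf("%s", body) the restored code returns.
func statusErrorFrom(statusCode int, body []byte) error {
	if statusCode < http.StatusBadRequest {
		return nil
	}
	return api.StatusError{StatusCode: statusCode, ErrorMessage: string(body)}
}

func main() {
	err := statusErrorFrom(http.StatusBadRequest, []byte("input exceeds maximum context length"))

	// Callers such as EmbedHandler could then recover the runner's status code.
	var serr api.StatusError
	if errors.As(err, &serr) {
		fmt.Println(serr.StatusCode, serr.ErrorMessage) // 400 input exceeds maximum context length
	}
}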
diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go
@@ -709,13 +709,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {

 	seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{
 		embedding: true,
-		truncate:  req.Truncate,
+
+		// TODO (jmorganca): this should be provided by the server via the
+		// request options and truncated here in the runner, instead of relying on
+		// the server's truncate logic
+		truncate: true,
 	})
 	if err != nil {
-		if errors.Is(err, errorInputTooLong) {
-			http.Error(w, err.Error(), http.StatusBadRequest)
-			return
-		}
 		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
 		return
 	}
@@ -758,8 +758,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	embedding := <-seq.embedding

 	if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
-		Embedding:       embedding,
-		PromptEvalCount: seq.numPromptInputs,
+		Embedding: embedding,
 	}); err != nil {
 		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 	}
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
@@ -948,13 +948,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
 	seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{
 		embedding: true,
-		truncate:  req.Truncate,
+
+		// TODO (jmorganca): this should be provided by the server via the
+		// request options and truncated here in the runner, instead of relying on
+		// the server's truncate logic
+		truncate: true,
 	})
 	if err != nil {
-		if errors.Is(err, errorInputTooLong) {
-			http.Error(w, err.Error(), http.StatusBadRequest)
-			return
-		}
 		http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError)
 		return
 	}

@@ -995,8 +995,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	}

 	if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
-		Embedding:       <-seq.embedding,
-		PromptEvalCount: seq.numPromptInputs,
+		Embedding: <-seq.embedding,
 	}); err != nil {
 		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 	}
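Note: both runner handlers (llamarunner above, ollamarunner here) lose the errorInputTooLong branch that mapped over-long inputs to 400 Bad Request, so a failed sequence falls through to 500 again; truncation is also pinned to true per the restored TODO. A sketch of the removed mapping, with errorInputTooLong stubbed since the runners' sentinel is unexported (the definition below is a stand-in for illustration):

package main

import (
	"errors"
	"fmt"
	"net/http"
	"net/http/httptest"
)

// Stand-in for the runners' unexported sentinel error; matched with
// errors.Is exactly as the deleted branch did.
var errorInputTooLong = errors.New("input exceeds maximum context length")

// writeSequenceError reproduces the branch the revert removes: the sentinel
// becomes a 400 Bad Request, anything else stays a 500.
func writeSequenceError(w http.ResponseWriter, err error) {
	if errors.Is(err, errorInputTooLong) {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError)
}

func main() {
	rec := httptest.NewRecorder()
	writeSequenceError(rec, fmt.Errorf("tokenize: %w", errorInputTooLong))
	fmt.Println(rec.Code) // 400
}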
diff --git a/server/routes.go b/server/routes.go
@@ -21,7 +21,6 @@ import (
 	"os/signal"
 	"slices"
 	"strings"
-	"sync/atomic"
 	"syscall"
 	"time"

@@ -660,7 +659,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 	}

-	r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
 		handleScheduleError(c, req.Model, err)
 		return
@@ -673,12 +672,61 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 	}

+	kvData, _, err := getModelData(m.ModelPath, false)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	var count int
+	for i, s := range input {
+		tokens, err := r.Tokenize(c.Request.Context(), s)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
+		if len(tokens) > ctxLen {
+			if !truncate {
+				c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"})
+				return
+			}
+
+			if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
+				ctxLen--
+			}
+
+			if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
+				ctxLen--
+			}
+
+			slog.Info("", "ctxLen", ctxLen, "tokenCount", len(tokens))
+			if ctxLen <= 0 {
+				// return error if the truncated input would be empty or just special tokens
+				c.JSON(http.StatusBadRequest, gin.H{"error": "input after truncation exceeds maximum context length"})
+				return
+			}
+
+			tokens = tokens[:ctxLen]
+
+			s, err = r.Detokenize(c.Request.Context(), tokens)
+			if err != nil {
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+				return
+			}
+		}
+
+		count += len(tokens)
+
+		input[i] = s
+	}
+
 	var g errgroup.Group
 	embeddings := make([][]float32, len(input))
-	var totalTokens uint64
 	for i, text := range input {
 		g.Go(func() error {
-			embedding, tokenCount, err := r.Embedding(c.Request.Context(), text, truncate)
+			embedding, err := r.Embedding(c.Request.Context(), text)
 			if err != nil {
 				return err
 			}
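Note: the restored block truncates on the server before any embedding runs: the budget is the smaller of the requested num_ctx and the model's trained context length, with one slot reserved for BOS and for EOS when the tokenizer will add them. A self-contained sketch of that arithmetic, with addBOS/addEOS standing in for the kvData.Uint/Bool lookups above:

package main

import "fmt"

// truncateTokens mirrors the restored EmbedHandler arithmetic: clamp to the
// smaller of the requested num_ctx and the model's context length, then
// reserve one slot each for BOS/EOS when the tokenizer will add them.
func truncateTokens(tokens []int, numCtx, modelCtx int, addBOS, addEOS bool) ([]int, error) {
	ctxLen := min(numCtx, modelCtx)
	if len(tokens) <= ctxLen {
		return tokens, nil // already fits, nothing to do
	}
	if addBOS {
		ctxLen--
	}
	if addEOS {
		ctxLen--
	}
	if ctxLen <= 0 {
		// the truncated input would be empty or just special tokens
		return nil, fmt.Errorf("input after truncation exceeds maximum context length")
	}
	return tokens[:ctxLen], nil
}

func main() {
	tokens := make([]int, 120) // pretend the prompt tokenized to 120 ids
	kept, err := truncateTokens(tokens, 50, 512, true, true)
	fmt.Println(len(kept), err) // 48 <nil>: 50 minus the BOS and EOS slots
}

The real handler then detokenizes the shortened slice, so the runner re-tokenizes an input that is guaranteed to fit.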
@@ -688,18 +736,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 				embedding = normalize(embedding[:req.Dimensions])
 			}
 			embeddings[i] = embedding
-			atomic.AddUint64(&totalTokens, uint64(tokenCount))
 			return nil
 		})
 	}

 	if err := g.Wait(); err != nil {
-		var serr api.StatusError
-		if errors.As(err, &serr) {
-			c.AbortWithStatusJSON(serr.StatusCode, gin.H{"error": strings.TrimSpace(serr.ErrorMessage)})
-		} else {
-			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
-		}
+		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
 		return
 	}

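Note: the deleted errors.As branch is what let a 400 produced by the runner survive the parallel fan-out; g.Wait() returns the first non-nil error from any goroutine, and the reverted handler inspected it before answering. A sketch of that recovery under the assumption that the per-input work fails with an api.StatusError:

package main

import (
	"errors"
	"fmt"
	"net/http"

	"golang.org/x/sync/errgroup"

	"github.com/ollama/ollama/api"
)

func main() {
	var g errgroup.Group
	for _, input := range []string{"short", "far too long"} {
		g.Go(func() error {
			// Stand-in for r.Embedding: pretend the runner rejected one input.
			if input == "far too long" {
				return api.StatusError{
					StatusCode:   http.StatusBadRequest,
					ErrorMessage: "input exceeds maximum context length",
				}
			}
			return nil
		})
	}

	// g.Wait returns the first non-nil error from the goroutines above.
	if err := g.Wait(); err != nil {
		var serr api.StatusError
		if errors.As(err, &serr) {
			fmt.Println(serr.StatusCode, serr.ErrorMessage) // forwarded as-is: 400
		} else {
			fmt.Println(http.StatusInternalServerError, err) // everything else flattens to 500
		}
	}
}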
@@ -708,7 +750,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		Embeddings:      embeddings,
 		TotalDuration:   time.Since(checkpointStart),
 		LoadDuration:    checkpointLoaded.Sub(checkpointStart),
-		PromptEvalCount: int(totalTokens),
+		PromptEvalCount: count,
 	}
 	c.JSON(http.StatusOK, resp)
 }
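Note: PromptEvalCount changes provenance here. The reverted handler summed runner-reported counts across goroutines with atomic.AddUint64 (hence the sync/atomic import dropped above), while the restored handler totals len(tokens) sequentially before any embedding is dispatched. Both arrive at the same sum; a small sketch of the two accounting styles, with hypothetical per-input counts:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

func main() {
	perInput := []int{7, 9, 5} // hypothetical token counts for three inputs

	// Reverted style: each embedding goroutine reports its own count, so the
	// handler accumulates with an atomic add because the writes race.
	var totalTokens uint64
	var wg sync.WaitGroup
	for _, n := range perInput {
		wg.Add(1)
		go func() {
			defer wg.Done()
			atomic.AddUint64(&totalTokens, uint64(n))
		}()
	}
	wg.Wait()

	// Restored style: tokenization happens up front on one goroutine, so a
	// plain int accumulator is enough.
	var count int
	for _, n := range perInput {
		count += n
	}

	fmt.Println(totalTokens, count) // 21 21
}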
@@ -754,7 +796,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}

-	embedding, _, err := r.Embedding(c.Request.Context(), req.Prompt, true)
+	embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
 		return
diff --git a/server/sched_test.go b/server/sched_test.go
@@ -780,8 +780,8 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn
 	return s.completionResp
 }

-func (s *mockLlm) Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error) {
-	return s.embeddingResp, 0, s.embeddingRespErr
+func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, error) {
+	return s.embeddingResp, s.embeddingRespErr
 }

 func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {