llamarunner: Record the time for all batches during prompt processing

Currently, we only record the time for the last batch when processing
the prompt. This results in unrealistically high numbers for the
old llama runner.

Before:
total duration:       31.273112939s
load duration:        4.97054657s
prompt eval count:    32768 token(s)
prompt eval duration: 235.137439ms
prompt eval rate:     139356.80 tokens/s
eval count:           1873 token(s)
eval duration:        18.173182374s
eval rate:            103.06 tokens/s

After:
total duration:       30.024798033s
load duration:        4.758588663s
prompt eval count:    32768 token(s)
prompt eval duration: 7.779621548s
prompt eval rate:     4212.03 tokens/s
eval count:           1769 token(s)
eval duration:        17.148014223s
eval rate:            103.16 tokens/s
This commit is contained in:
Jesse Gross 2025-10-16 16:27:45 -07:00 committed by Jesse Gross
parent 0334e67ffd
commit a8d9c2648e

View File

@ -384,6 +384,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
defer s.mu.Unlock()
var batch *llama.Batch
var numOutputs int
seqIdx := s.nextSeq - 1
for range s.seqs {
@ -446,7 +447,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
break
}
batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), i+1 == len(seq.inputs), seq.cache.Id)
output := i+1 == len(seq.inputs)
batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), output, seq.cache.Id)
if output {
numOutputs++
}
seq.pendingInputs = append(seq.pendingInputs, input)
seq.iBatch = batch.NumTokens() - 1
}
@ -463,6 +469,10 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
return fmt.Errorf("failed to decode batch: %w", err)
}
if numOutputs > 0 {
s.lc.Synchronize()
}
for i, seq := range s.seqs {
if seq == nil {
continue
@ -476,10 +486,10 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
// don't sample prompt processing
if len(seq.inputs) != 0 {
seq.processingDuration += time.Since(t)
continue
}
s.lc.Synchronize()
seq.numDecoded++
if seq.numDecoded > 1 {
seq.generationDuration += time.Since(t)