Mirror of https://github.com/zebrajr/ollama.git, synced 2025-12-06 00:19:51 +01:00
llamarunner: Record the time for all batches during prompt processing
Currently, we only record the time for the last batch when processing the prompt. This results in unrealistically high numbers for the old llama runner.

Before:
total duration:       31.273112939s
load duration:        4.97054657s
prompt eval count:    32768 token(s)
prompt eval duration: 235.137439ms
prompt eval rate:     139356.80 tokens/s
eval count:           1873 token(s)
eval duration:        18.173182374s
eval rate:            103.06 tokens/s

After:
total duration:       30.024798033s
load duration:        4.758588663s
prompt eval count:    32768 token(s)
prompt eval duration: 7.779621548s
prompt eval rate:     4212.03 tokens/s
eval count:           1769 token(s)
eval duration:        17.148014223s
eval rate:            103.16 tokens/s
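(The rates follow directly from prompt eval count divided by prompt eval duration: timing only the final batch gives 32768 tokens / 235.137439 ms ≈ 139,357 tokens/s, an implausible figure, while timing every batch gives 32768 tokens / 7.779621548 s ≈ 4,212 tokens/s.)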
This commit is contained in:
parent 0334e67ffd
commit a8d9c2648e
@@ -384,6 +384,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 	defer s.mu.Unlock()
 
 	var batch *llama.Batch
+	var numOutputs int
 
 	seqIdx := s.nextSeq - 1
 	for range s.seqs {
@@ -446,7 +447,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 				break
 			}
 
-			batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), i+1 == len(seq.inputs), seq.cache.Id)
+			output := i+1 == len(seq.inputs)
+			batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), output, seq.cache.Id)
+			if output {
+				numOutputs++
+			}
+
 			seq.pendingInputs = append(seq.pendingInputs, input)
 			seq.iBatch = batch.NumTokens() - 1
 		}
@@ -463,6 +469,10 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		return fmt.Errorf("failed to decode batch: %w", err)
 	}
 
+	if numOutputs > 0 {
+		s.lc.Synchronize()
+	}
+
 	for i, seq := range s.seqs {
 		if seq == nil {
 			continue
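Why synchronize before reading the per-sequence timers, and only when the batch produced outputs? If the decode call can return before the device has finished, a timestamp taken immediately afterwards measures only submission; the guard presumably skips a needless device sync when there is nothing to sample. A minimal, runnable sketch of that sync-then-measure pattern, with a goroutine standing in for asynchronous device work (all names below are illustrative, not the ollama API):

package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	var wg sync.WaitGroup

	t := time.Now()

	// "Submit" work that completes asynchronously, like a GPU decode.
	wg.Add(1)
	go func() {
		defer wg.Done()
		time.Sleep(50 * time.Millisecond) // stand-in for device work
	}()

	// Without synchronizing, the timer sees only the submission cost.
	fmt.Println("without sync:", time.Since(t)) // ~0s

	// The analogue of s.lc.Synchronize(): wait for the work to finish.
	wg.Wait()
	fmt.Println("with sync:", time.Since(t)) // ~50ms: the real work
}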
@@ -476,10 +486,10 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 
 		// don't sample prompt processing
 		if len(seq.inputs) != 0 {
+			seq.processingDuration += time.Since(t)
 			continue
 		}
 
-		s.lc.Synchronize()
 		seq.numDecoded++
 		if seq.numDecoded > 1 {
 			seq.generationDuration += time.Since(t)
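Taken together, the change flags which batch entries produce output, synchronizes once after decode when any do, and adds time.Since(t) to seq.processingDuration for every prompt batch rather than just the last one. A self-contained sketch of that accumulation pattern under assumed names (the sequence type, batch size, and decode stand-in below are hypothetical, not the runner's types):

package main

import (
	"fmt"
	"time"
)

// sequence is a hypothetical stand-in for the runner's per-request state.
type sequence struct {
	promptTokens       int
	remaining          int           // prompt tokens not yet decoded
	processingDuration time.Duration // accumulated across all prompt batches
}

// decode simulates one batch of model work.
func decode(n int) { time.Sleep(time.Duration(n) * time.Microsecond) }

func main() {
	const batchSize = 512
	seq := &sequence{promptTokens: 2048, remaining: 2048}

	for seq.remaining > 0 {
		n := min(batchSize, seq.remaining)

		t := time.Now()
		decode(n)
		// The fix: add every batch's elapsed time to the total.
		// Recording only the final batch here is what produced the
		// inflated pre-fix prompt eval rate.
		seq.processingDuration += time.Since(t)

		seq.remaining -= n
	}

	rate := float64(seq.promptTokens) / seq.processingDuration.Seconds()
	fmt.Printf("prompt eval duration: %v\n", seq.processingDuration)
	fmt.Printf("prompt eval rate: %.2f tokens/s\n", rate)
}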