ollamarunner: Worst case batch for token generation

We currently reserve memory for the worst case batch only at the
maximum batch size, which corresponds to prompt processing. However, in
some cases the generated graph differs between small and large batches.
To ensure that we don't need to allocate memory later, after layout has
already taken place, we should run the worst case batch both ways and
take the larger amount of memory.

This does not noticeably affect loading speed, as the most expensive
part of this logic is image processing, which does not occur during
token generation.
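
In isolation, the approach amounts to reserving the worst case graph
twice, once per batch shape, and letting the larger reservation win.
The sketch below illustrates that pattern with hypothetical names
(reserveGraph, maxBatchSize); it is not the runner's actual API.

package main

import "fmt"

// reserveGraph is a hypothetical stand-in for building and reserving the
// worst-case compute graph for a given batch size. In the real change this
// is reserveWorstCaseGraph, which also handles multimodal inputs when
// reserving for prompt processing.
func reserveGraph(batchSize int) error {
	fmt.Printf("reserved worst-case graph for batch size %d\n", batchSize)
	return nil
}

func main() {
	const maxBatchSize = 512 // hypothetical prompt-processing batch size

	// Prompt processing: the largest batch that will ever be scheduled.
	if err := reserveGraph(maxBatchSize); err != nil {
		panic(err)
	}

	// Token generation: a batch of one token, whose graph can differ from
	// the large-batch graph; whichever reservation needs more memory wins.
	if err := reserveGraph(1); err != nil {
		panic(err)
	}
}
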
Author:    Jesse Gross (committed by Jesse Gross)
Date:      2025-10-27 16:31:58 -07:00
Parent:    88236bc05f
Commit:    26465fb85f


@@ -1009,12 +1009,17 @@ func (s *Server) health(w http.ResponseWriter, r *http.Request) {
 	}
 }
-func (s *Server) reserveWorstCaseGraph() error {
+func (s *Server) reserveWorstCaseGraph(prompt bool) error {
 	ctx := s.model.Backend().NewContext()
 	defer ctx.Close()
 	var err error
-	inputs := make([]*input.Input, s.batchSize)
+	batchSize := 1
+	if prompt {
+		batchSize = s.batchSize
+	}
+	inputs := make([]*input.Input, batchSize)
 	for i := range inputs {
 		inputs[i] = &input.Input{}
 	}
@@ -1031,7 +1036,7 @@ func (s *Server) reserveWorstCaseGraph() error {
 	// - The result may now be larger than a batch (images may not fit in a
 	//   single batch), so trim based on what will fit and must be grouped together.
 	// - Fill out the rest of the space with text tokens.
-	if multimodalProcessor, ok := s.model.(model.MultimodalProcessor); ok {
+	if multimodalProcessor, ok := s.model.(model.MultimodalProcessor); prompt && ok {
 		mmCtx := s.model.Backend().NewContext()
 		defer mmCtx.Close()
@@ -1058,10 +1063,10 @@ func (s *Server) reserveWorstCaseGraph() error {
 		}
 	}
-	if len(inputs) < s.batchSize {
-		newInputs := make([]*input.Input, s.batchSize)
+	if len(inputs) < batchSize {
+		newInputs := make([]*input.Input, batchSize)
 		copy(newInputs, inputs)
-		for i := len(inputs); i < s.batchSize; i++ {
+		for i := len(inputs); i < batchSize; i++ {
 			newInputs[i] = &input.Input{}
 		}
 		inputs = newInputs
@@ -1160,7 +1165,12 @@ func (s *Server) allocModel(
 	s.seqs = make([]*Sequence, s.parallel)
 	s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
-	return s.reserveWorstCaseGraph()
+	err = s.reserveWorstCaseGraph(true)
+	if err != nil {
+		return err
+	}
+	return s.reserveWorstCaseGraph(false)
 }
 // closeModel frees all memory associated with a model