ollamarunner: Worst case batch for token generation
We currently allocate the worst case batch for max-sized batches, which corresponds to prompt processing. However, in some cases the generated graph differs between small and large batches. To ensure that we don't need to allocate memory later, after layout has taken place, we should run the worst case batch both ways and take the larger amount of memory. This does not noticeably affect loading speed, as the most expensive part of this logic is image processing, which does not occur during token generation.
This commit is contained in:
parent 88236bc05f
commit 26465fb85f
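For context, here is a minimal, self-contained sketch of the idea behind this change, not the actual ollama implementation: build the worst case graph for both a full prompt-processing batch and a single-token generation batch, and reserve whichever needs more memory, so nothing has to be allocated after the memory layout is fixed. estimateGraphSize, maxBatch, and the numbers below are hypothetical stand-ins; only the prompt-versus-generation distinction and the reserveWorstCaseGraph idea come from the diff that follows.

// Hypothetical sketch of "run the worst case batch both ways and take
// the larger amount of memory"; not the ollama code.
package main

import "fmt"

// estimateGraphSize stands in for building the worst case compute graph
// for a given batch size and measuring the memory it would need.
func estimateGraphSize(batchSize int) int {
	return 4096 + batchSize*64 // hypothetical cost model
}

func main() {
	const maxBatch = 512 // prompt processing fills the whole batch
	const genBatch = 1   // token generation handles one token at a time

	// Reserve the larger of the two worst cases so no further
	// allocation is needed once layout has taken place.
	reserve := estimateGraphSize(maxBatch)
	if gen := estimateGraphSize(genBatch); gen > reserve {
		reserve = gen
	}
	fmt.Printf("reserving %d bytes for the worst case graph\n", reserve)
}

In the diff itself, the same effect is achieved by calling reserveWorstCaseGraph twice from allocModel, once with prompt set to true and once with it set to false.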
@@ -1009,12 +1009,17 @@ func (s *Server) health(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
-func (s *Server) reserveWorstCaseGraph() error {
+func (s *Server) reserveWorstCaseGraph(prompt bool) error {
 	ctx := s.model.Backend().NewContext()
 	defer ctx.Close()
 
 	var err error
-	inputs := make([]*input.Input, s.batchSize)
+	batchSize := 1
+	if prompt {
+		batchSize = s.batchSize
+	}
+
+	inputs := make([]*input.Input, batchSize)
 	for i := range inputs {
 		inputs[i] = &input.Input{}
 	}
@@ -1031,7 +1036,7 @@ func (s *Server) reserveWorstCaseGraph() error {
 	//  - The result may now be larger than a batch (images may not fit in a
 	//    single batch), so trim based on what will fit and must be grouped together.
 	//  - Fill out the rest of the space with text tokens.
-	if multimodalProcessor, ok := s.model.(model.MultimodalProcessor); ok {
+	if multimodalProcessor, ok := s.model.(model.MultimodalProcessor); prompt && ok {
 		mmCtx := s.model.Backend().NewContext()
 		defer mmCtx.Close()
 
@@ -1058,10 +1063,10 @@ func (s *Server) reserveWorstCaseGraph() error {
 		}
 	}
 
-	if len(inputs) < s.batchSize {
-		newInputs := make([]*input.Input, s.batchSize)
+	if len(inputs) < batchSize {
+		newInputs := make([]*input.Input, batchSize)
 		copy(newInputs, inputs)
-		for i := len(inputs); i < s.batchSize; i++ {
+		for i := len(inputs); i < batchSize; i++ {
 			newInputs[i] = &input.Input{}
 		}
 		inputs = newInputs
@@ -1160,7 +1165,12 @@ func (s *Server) allocModel(
 	s.seqs = make([]*Sequence, s.parallel)
 	s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
 
-	return s.reserveWorstCaseGraph()
+	err = s.reserveWorstCaseGraph(true)
+	if err != nil {
+		return err
+	}
+
+	return s.reserveWorstCaseGraph(false)
 }
 
 // closeModel frees all memory associated with a model