ggml: Seperate tensor load from backend creation

Currently, when the backend is created, the tensors are loaded at the
same time, which is a slow operation. This separates them to be two
steps:
 - Create backend, including enumerating tensors and memory allocation
 - Loading tensor data

This allows more flexibility in managing model loading.
This commit is contained in:
Jesse Gross 2025-04-17 13:42:40 -07:00 committed by Jesse Gross
parent d755577473
commit 94ab428e3f
13 changed files with 131 additions and 115 deletions

View File

@ -47,7 +47,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
} }
t.Cleanup(func() { r.Close() }) t.Cleanup(func() { r.Close() })
m, _, err := ggml.Decode(r, -1) m, err := ggml.Decode(r, -1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -332,7 +332,7 @@ func TestConvertAdapter(t *testing.T) {
} }
defer r.Close() defer r.Close()
m, _, err := ggml.Decode(r, -1) m, err := ggml.Decode(r, -1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -15,6 +15,7 @@ import (
type GGML struct { type GGML struct {
container container
model model
Length int64
} }
type model interface { type model interface {
@ -386,12 +387,12 @@ func DetectContentType(b []byte) string {
// //
// It collects array values for arrays with a size less than or equal to // It collects array values for arrays with a size less than or equal to
// maxArraySize. If the maxArraySize is negative, all arrays are collected. // maxArraySize. If the maxArraySize is negative, all arrays are collected.
func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) { func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
rs = bufioutil.NewBufferedSeeker(rs, 32<<10) rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
var magic uint32 var magic uint32
if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil { if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
return nil, 0, err return nil, err
} }
var c container var c container
@ -401,24 +402,25 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
case FILE_MAGIC_GGUF_BE: case FILE_MAGIC_GGUF_BE:
c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize} c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize}
default: default:
return nil, 0, errors.New("invalid file magic") return nil, errors.New("invalid file magic")
} }
model, err := c.Decode(rs) model, err := c.Decode(rs)
if err != nil { if err != nil {
return nil, 0, err return nil, err
} }
offset, err := rs.Seek(0, io.SeekCurrent) offset, err := rs.Seek(0, io.SeekCurrent)
if err != nil { if err != nil {
return nil, 0, err return nil, err
} }
// final model type // final model type
return &GGML{ return &GGML{
container: c, container: c,
model: model, model: model,
}, offset, nil Length: offset,
}, nil
} }
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) { func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {

View File

@ -35,7 +35,7 @@ func TestWriteGGUF(t *testing.T) {
} }
defer r.Close() defer r.Close()
ff, _, err := Decode(r, 0) ff, err := Decode(r, 0)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -423,7 +423,7 @@ func projectorMemoryRequirements(filename string) (weights uint64) {
} }
defer file.Close() defer file.Close()
ggml, _, err := ggml.Decode(file, 1024) ggml, err := ggml.Decode(file, 1024)
if err != nil { if err != nil {
return 0 return 0
} }

View File

@ -121,7 +121,7 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
} }
defer f.Close() defer f.Close()
ggml, _, err := ggml.Decode(f, maxArraySize) ggml, err := ggml.Decode(f, maxArraySize)
return ggml, err return ggml, err
} }

View File

@ -6,7 +6,6 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"math" "math"
"os"
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
@ -15,6 +14,7 @@ import (
) )
type Backend interface { type Backend interface {
Load(ctx context.Context, progress func(float32)) error
Config() fs.Config Config() fs.Config
Get(name string) Tensor Get(name string) Tensor
NewContext() Context NewContext() Context
@ -52,10 +52,6 @@ type CacheConfig struct {
// BackendParams controls how the backend loads and executes models // BackendParams controls how the backend loads and executes models
type BackendParams struct { type BackendParams struct {
// Progress is a callback function that allows reporting percentage completion
// of model loading
Progress func(float32)
// NumThreads sets the number of threads to use if running on the CPU // NumThreads sets the number of threads to use if running on the CPU
NumThreads int NumThreads int
@ -72,9 +68,9 @@ type BackendParams struct {
FlashAttention bool FlashAttention bool
} }
var backends = make(map[string]func(context.Context, *os.File, BackendParams) (Backend, error)) var backends = make(map[string]func(string, BackendParams) (Backend, error))
func RegisterBackend(name string, f func(context.Context, *os.File, BackendParams) (Backend, error)) { func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {
if _, ok := backends[name]; ok { if _, ok := backends[name]; ok {
panic("backend: backend already registered") panic("backend: backend already registered")
} }
@ -82,9 +78,9 @@ func RegisterBackend(name string, f func(context.Context, *os.File, BackendParam
backends[name] = f backends[name] = f
} }
func NewBackend(ctx context.Context, f *os.File, params BackendParams) (Backend, error) { func NewBackend(modelPath string, params BackendParams) (Backend, error) {
if backend, ok := backends["ggml"]; ok { if backend, ok := backends["ggml"]; ok {
return backend(ctx, f, params) return backend(modelPath, params)
} }
return nil, fmt.Errorf("unsupported backend") return nil, fmt.Errorf("unsupported backend")

View File

@ -44,8 +44,15 @@ func devices() []*C.struct_ggml_backend_device {
} }
type Backend struct { type Backend struct {
// modelPath is the location of the model data
modelPath string
meta *fsggml.GGML meta *fsggml.GGML
// tensorLoadTargets maps from the name of the tensor in the file
// to the name that is used by the model definition
tensorLoadTargets map[string][]string
sched *C.struct_ggml_backend_sched sched *C.struct_ggml_backend_sched
schedBackends []*C.struct_ggml_backend schedBackends []*C.struct_ggml_backend
schedBufts []*C.struct_ggml_backend_buffer_type schedBufts []*C.struct_ggml_backend_buffer_type
@ -64,8 +71,14 @@ type Backend struct {
maxGraphNodes int maxGraphNodes int
} }
func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) { func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
meta, n, err := fsggml.Decode(r, -1) r, err := os.Open(modelPath)
if err != nil {
return nil, err
}
defer r.Close()
meta, err := fsggml.Decode(r, -1)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -307,73 +320,6 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
} }
} }
var doneBytes atomic.Uint64
totalBytes := uint64(n) - meta.Tensors().Offset
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(runtime.GOMAXPROCS(0))
for _, t := range meta.Tensors().Items() {
t := t
g.Go(func() error {
tts := make([]*C.struct_ggml_tensor, max(1, len(targets[t.Name])))
for i := range tts {
target := targets[t.Name][i]
if target == "" {
target = t.Name
}
tt, ok := tensors[target]
if !ok {
return fmt.Errorf("unassigned tensor: %s", t.Name)
}
tts[i] = tt
}
// Create a new FD for each goroutine so that each FD is read sequentially, rather than
// seeking around within an FD shared between all goroutines.
file, err := os.Open(r.Name())
if err != nil {
slog.Warn("file open error", "file", r.Name(), "error", err)
return err
}
defer file.Close()
sr := io.NewSectionReader(file, int64(meta.Tensors().Offset+t.Offset), int64(t.Size()))
bts := make([]byte, 128*format.KibiByte)
var s uint64
for s < t.Size() {
// Stop if either the parent context has been canceled or if any of the other tensors returned an error
if err := ctx.Err(); err != nil {
return err
}
n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
if err != nil {
slog.Warn("file read error", "file", r.Name(), "error", err)
return err
}
for _, tt := range tts {
C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
}
s += uint64(n)
if params.Progress != nil {
done := doneBytes.Add(uint64(n))
params.Progress(float32(done) / float32(totalBytes))
}
}
return nil
})
}
if err := g.Wait(); err != nil {
return nil, err
}
// map devices to backend buffer types so new tensors can be assigned to the correct device // map devices to backend buffer types so new tensors can be assigned to the correct device
deviceBufferTypes := make(map[*C.struct_ggml_backend_device]*C.struct_ggml_backend_buffer_type) deviceBufferTypes := make(map[*C.struct_ggml_backend_device]*C.struct_ggml_backend_buffer_type)
@ -397,9 +343,11 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
maxGraphNodes := max(8192, len(meta.Tensors().Items())*5) maxGraphNodes := max(8192, len(meta.Tensors().Items())*5)
return &Backend{ return &Backend{
flashAttention: params.FlashAttention, modelPath: modelPath,
meta: meta, flashAttention: params.FlashAttention,
tensors: tensors, meta: meta,
tensorLoadTargets: targets,
tensors: tensors,
sched: C.ggml_backend_sched_new( sched: C.ggml_backend_sched_new(
(*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])), (*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])),
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])), (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
@ -426,6 +374,77 @@ func init() {
ml.RegisterBackend("ggml", New) ml.RegisterBackend("ggml", New)
} }
func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
var doneBytes atomic.Uint64
totalBytes := uint64(b.meta.Length) - b.meta.Tensors().Offset
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(runtime.GOMAXPROCS(0))
for _, t := range b.meta.Tensors().Items() {
t := t
g.Go(func() error {
tts := make([]*C.struct_ggml_tensor, max(1, len(b.tensorLoadTargets[t.Name])))
for i := range tts {
target := b.tensorLoadTargets[t.Name][i]
if target == "" {
target = t.Name
}
tt, ok := b.tensors[target]
if !ok {
return fmt.Errorf("unassigned tensor: %s", t.Name)
}
tts[i] = tt
}
// Create a new FD for each goroutine so that each FD is read sequentially, rather than
// seeking around within an FD shared between all goroutines.
file, err := os.Open(b.modelPath)
if err != nil {
slog.Warn("file open error", "file", b.modelPath, "error", err)
return err
}
defer file.Close()
sr := io.NewSectionReader(file, int64(b.meta.Tensors().Offset+t.Offset), int64(t.Size()))
bts := make([]byte, 128*format.KibiByte)
var s uint64
for s < t.Size() {
// Stop if either the parent context has been canceled or if any of the other tensors returned an error
if err := ctx.Err(); err != nil {
return err
}
n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
if err != nil {
slog.Warn("file read error", "file", b.modelPath, "error", err)
return err
}
for _, tt := range tts {
C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
}
s += uint64(n)
if progress != nil {
done := doneBytes.Add(uint64(n))
progress(float32(done) / float32(totalBytes))
}
}
return nil
})
}
if err := g.Wait(); err != nil {
return err
}
return nil
}
func (b *Backend) Config() fs.Config { func (b *Backend) Config() fs.Config {
return b.meta.KV() return b.meta.KV()
} }

View File

@ -98,14 +98,8 @@ func Register(name string, f func(fs.Config) (Model, error)) {
} }
// New initializes a new model instance with the provided configuration based on the metadata in the model file // New initializes a new model instance with the provided configuration based on the metadata in the model file
func New(ctx context.Context, modelPath string, params ml.BackendParams) (Model, error) { func New(modelPath string, params ml.BackendParams) (Model, error) {
r, err := os.Open(modelPath) b, err := ml.NewBackend(modelPath, params)
if err != nil {
return nil, err
}
defer r.Close()
b, err := ml.NewBackend(ctx, r, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -134,7 +128,7 @@ func NewTextProcessor(s string) (TextProcessor, error) {
return nil, err return nil, err
} }
defer r.Close() defer r.Close()
meta, _, err := fsggml.Decode(r, -1) meta, err := fsggml.Decode(r, -1)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -845,7 +845,7 @@ func (s *Server) loadModel(
multiUserCache bool, multiUserCache bool,
) { ) {
var err error var err error
s.model, err = model.New(ctx, mpath, params) s.model, err = model.New(mpath, params)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -874,6 +874,14 @@ func (s *Server) loadModel(
panic(err) panic(err)
} }
err = s.model.Backend().Load(ctx,
func(progress float32) {
s.progress = progress
})
if err != nil {
panic(err)
}
s.status = llm.ServerStatusReady s.status = llm.ServerStatusReady
s.ready.Done() s.ready.Done()
} }
@ -928,9 +936,6 @@ func Execute(args []string) error {
} }
params := ml.BackendParams{ params := ml.BackendParams{
Progress: func(progress float32) {
server.progress = progress
},
NumThreads: *threads, NumThreads: *threads,
NumGPULayers: *numGPULayers, NumGPULayers: *numGPULayers,
MainGPU: *mainGPU, MainGPU: *mainGPU,

View File

@ -295,7 +295,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
} }
defer bin.Close() defer bin.Close()
f, _, err := ggml.Decode(bin, -1) f, err := ggml.Decode(bin, -1)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -467,7 +467,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
return nil, err return nil, err
} }
f, _, err := ggml.Decode(temp, 1024) f, err := ggml.Decode(temp, 1024)
if err != nil { if err != nil {
slog.Error(fmt.Sprintf("error decoding ggml: %s\n", err)) slog.Error(fmt.Sprintf("error decoding ggml: %s\n", err))
return nil, err return nil, err
@ -508,7 +508,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
var offset int64 var offset int64
for offset < stat.Size() { for offset < stat.Size() {
f, n, err := ggml.Decode(blob, 1024) f, err := ggml.Decode(blob, 1024)
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
break break
} else if err != nil { } else if err != nil {
@ -523,7 +523,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
} }
var layer Layer var layer Layer
if digest != "" && n == stat.Size() && offset == 0 { if digest != "" && f.Length == stat.Size() && offset == 0 {
layer, err = NewLayerFromLayer(digest, mediatype, blob.Name()) layer, err = NewLayerFromLayer(digest, mediatype, blob.Name())
if err != nil { if err != nil {
slog.Debug("could not create new layer from layer", "error", err) slog.Debug("could not create new layer from layer", "error", err)
@ -533,14 +533,14 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
// Fallback to creating layer from file copy (either NewLayerFromLayer failed, or digest empty/n != stat.Size()) // Fallback to creating layer from file copy (either NewLayerFromLayer failed, or digest empty/n != stat.Size())
if layer.Digest == "" { if layer.Digest == "" {
layer, err = NewLayer(io.NewSectionReader(blob, offset, n), mediatype) layer, err = NewLayer(io.NewSectionReader(blob, offset, f.Length), mediatype)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
layers = append(layers, &layerGGML{layer, f}) layers = append(layers, &layerGGML{layer, f})
offset = n offset = f.Length
} }
return detectChatTemplate(layers) return detectChatTemplate(layers)

View File

@ -75,7 +75,7 @@ func (m *Model) Capabilities() []model.Capability {
if err == nil { if err == nil {
defer r.Close() defer r.Close()
f, _, err := ggml.Decode(r, 1024) f, err := ggml.Decode(r, 1024)
if err == nil { if err == nil {
if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok { if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
capabilities = append(capabilities, model.CapabilityEmbedding) capabilities = append(capabilities, model.CapabilityEmbedding)

View File

@ -64,7 +64,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
} }
defer blob.Close() defer blob.Close()
f, _, err := ggml.Decode(blob, -1) f, err := ggml.Decode(blob, -1)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -271,7 +271,7 @@ func TestQuantizeModel(t *testing.T) {
t.Fatal(err.Error()) t.Fatal(err.Error())
} }
defer fp.Close() defer fp.Close()
meta, _, err := fsggml.Decode(fp, -1) meta, err := fsggml.Decode(fp, -1)
if err != nil { if err != nil {
t.Fatal(err.Error()) t.Fatal(err.Error())
} }
@ -303,7 +303,7 @@ func TestQuantizeModel(t *testing.T) {
t.Fatalf("failed to load the quantized model %s: %s", tmp.Name(), err) t.Fatalf("failed to load the quantized model %s: %s", tmp.Name(), err)
} }
defer fpNew.Close() defer fpNew.Close()
newMeta, _, err := fsggml.Decode(fpNew, -1) newMeta, err := fsggml.Decode(fpNew, -1)
if err != nil { if err != nil {
t.Fatalf("failed to load the quantized model %s: %s", tmp.Name(), err) t.Fatalf("failed to load the quantized model %s: %s", tmp.Name(), err)
} }