diff --git a/convert/convert_qwen25vl.go b/convert/convert_qwen25vl.go index c2d5a633..6e4c9640 100644 --- a/convert/convert_qwen25vl.go +++ b/convert/convert_qwen25vl.go @@ -65,17 +65,17 @@ func (q *qwen25VLModel) Tensors(ts []Tensor) []*ggml.Tensor { for _, t := range ts { if strings.Contains(t.Name(), "patch_embed.proj") { for t := range splitDim(t, 2, - strings.NewReplacer("patch_embed.proj", "patch_embd_0"), - strings.NewReplacer("patch_embed.proj", "patch_embd_1"), + split{Replacer: strings.NewReplacer("patch_embed.proj", "patch_embd_0")}, + split{Replacer: strings.NewReplacer("patch_embed.proj", "patch_embd_1")}, ) { t.Shape = slices.DeleteFunc(t.Shape, func(i uint64) bool { return i == 1 }) out = append(out, t) } } else if strings.Contains(t.Name(), "attn.qkv") { out = append(out, slices.Collect(splitDim(t, 0, - strings.NewReplacer("attn.qkv", "attn_q"), - strings.NewReplacer("attn.qkv", "attn_k"), - strings.NewReplacer("attn.qkv", "attn_v"), + split{Replacer: strings.NewReplacer("attn.qkv", "attn_q")}, + split{Replacer: strings.NewReplacer("attn.qkv", "attn_k")}, + split{Replacer: strings.NewReplacer("attn.qkv", "attn_v")}, ))...) } else { out = append(out, &ggml.Tensor{ diff --git a/convert/tensor.go b/convert/tensor.go index ffb22ead..9d6919e3 100644 --- a/convert/tensor.go +++ b/convert/tensor.go @@ -1,53 +1,73 @@ package convert import ( + "cmp" "iter" "slices" "strings" - "github.com/ollama/ollama/fs/ggml" "github.com/pdevine/tensor" "github.com/pdevine/tensor/native" + + "github.com/ollama/ollama/fs/ggml" ) +type split struct { + *strings.Replacer + dim int + + // fn is an optional function to apply to the tensor after slicing + fn func(tensor.Tensor) (tensor.Tensor, error) +} + // splitDim splits a tensor along a specified dimension into multiple tensors. The dimension -// is split evenly based on the number of replacers provided. 
-func splitDim(t Tensor, dim int, replacers ...*strings.Replacer) iter.Seq[*ggml.Tensor] { +// is split evenly based on the number of splits provided unless a split specifies its own size via dim. +func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] { return func(yield func(*ggml.Tensor) bool) { - for i, replacer := range replacers { + var offset int + for _, split := range splits { + t := t.Clone() shape := slices.Clone(t.Shape()) - shape[dim] = shape[dim] / uint64(len(replacers)) + shape[dim] = cmp.Or(uint64(split.dim), shape[dim]/uint64(len(splits))) slice := slices.Repeat([]tensor.Slice{nil}, len(shape)) - slice[dim] = tensor.S(i*int(shape[dim]), (i+1)*int(shape[dim])) + slice[dim] = tensor.S(offset, offset+int(shape[dim])) + offset += int(shape[dim]) - tt := t.Clone() - tt.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) { + t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) { dims := make([]int, len(shape)) for i := range shape { dims[i] = int(shape[i]) } - var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data)) - t, err := t.Slice(slice...) + var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data)) + tt, err := tt.Slice(slice...) 
if err != nil { return nil, err } - t = tensor.Materialize(t) + tt = tensor.Materialize(tt) + + if split.fn != nil { + tt, err = split.fn(tt) + if err != nil { + return nil, err + } + } + // flatten tensor so it can be written as a vector - if err := t.Reshape(t.Shape().TotalSize()); err != nil { + if err := tt.Reshape(tt.Shape().TotalSize()); err != nil { return nil, err } - return native.VectorF32(t.(*tensor.Dense)) + return native.VectorF32(tt.(*tensor.Dense)) }) if !yield(&ggml.Tensor{ - Name: replacer.Replace(t.Name()), + Name: split.Replace(t.Name()), Kind: t.Kind(), Shape: shape, - WriterTo: tt, + WriterTo: t, }) { break } diff --git a/convert/tensor_test.go b/convert/tensor_test.go new file mode 100644 index 00000000..ea12d0f5 --- /dev/null +++ b/convert/tensor_test.go @@ -0,0 +1,304 @@ +package convert + +import ( + "bytes" + "encoding/binary" + "io" + "iter" + "slices" + "strings" + "testing" + + "github.com/pdevine/tensor" +) + +type fakeTensor struct { + name string + shape []uint64 + data []float32 + + repacker Repacker +} + +func (f fakeTensor) Name() string { + return f.name +} + +func (f fakeTensor) Shape() []uint64 { + return f.shape +} + +func (f fakeTensor) Kind() uint32 { + return 0 +} + +func (f *fakeTensor) SetRepacker(fn Repacker) { + f.repacker = fn +} + +func (f fakeTensor) Clone() Tensor { + return &fakeTensor{ + name: f.name, + shape: slices.Clone(f.shape), + data: slices.Clone(f.data), + repacker: f.repacker, + } +} + +func (f fakeTensor) WriteTo(w io.Writer) (n int64, err error) { + data := f.data + if f.repacker != nil { + data, err = f.repacker(f.name, data, f.shape) + if err != nil { + return 0, err + } + } + + if err := binary.Write(w, binary.LittleEndian, data); err != nil { + return 0, err + } + + return int64(len(data) * 4), nil +} + +func mul(shape []uint64) int { + n := 1 + for _, dim := range shape { + n *= int(dim) + } + return n +} + +func TestSplitDim(t *testing.T) { + r := fakeTensor{ + name: "a.b", + shape: []uint64{3, 
4}, + data: []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + } + + t.Run("no split", func(t *testing.T) { + for tt := range splitDim(&r, 0, split{Replacer: strings.NewReplacer("a", "x")}) { + if tt.Name != "x.b" { + t.Fatalf("expected name 'x', got '%s'", tt.Name) + } + + if !slices.Equal(tt.Shape, []uint64{3, 4}) { + t.Fatalf("expected shape [3, 4], got %v", tt.Shape) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if !slices.Equal(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}) { + t.Fatalf("expected data [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], got %v", f32s) + } + } + }) + + t.Run("even split", func(t *testing.T) { + next, stop := iter.Pull(splitDim(&r, 1, + split{Replacer: strings.NewReplacer("a", "x")}, + split{Replacer: strings.NewReplacer("b", "y")}, + )) + defer stop() + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "x.b" { + t.Fatal("expected name 'x.b', got", tt.Name) + } + + if !slices.Equal(tt.Shape, []uint64{3, 2}) { + t.Fatal("expected shape [3, 2], got", tt.Shape) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if !slices.Equal(f32s, []float32{0, 1, 4, 5, 8, 9}) { + t.Fatal("expected data [0, 1, 4, 5, 8, 9], got", f32s) + } + } + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "a.y" { + t.Fatal("expected name 'a.y', got", tt.Name) + } + + if !slices.Equal(tt.Shape, []uint64{3, 2}) { + t.Fatal("expected shape [3, 2], got", tt.Shape) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, 
binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if !slices.Equal(f32s, []float32{2, 3, 6, 7, 10, 11}) { + t.Fatal("expected data [2, 3, 6, 7, 10, 11], got", f32s) + } + } + }) + + t.Run("uneven split", func(t *testing.T) { + next, stop := iter.Pull(splitDim(&r, 0, + split{Replacer: strings.NewReplacer("a", "x"), dim: 2}, + split{Replacer: strings.NewReplacer("b", "y"), dim: 1}, + )) + defer stop() + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "x.b" { + t.Fatal("expected name 'x.b', got", tt.Name) + } + + if !slices.Equal(tt.Shape, []uint64{2, 4}) { + t.Fatal("expected shape [2, 4], got", tt.Shape) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if !slices.Equal(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7}) { + t.Fatal("expected data [0, 1, 2, 3, 4, 5, 6, 7], got", f32s) + } + } + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "a.y" { + t.Fatal("expected name 'a.y', got", tt.Name) + } + + if !slices.Equal(tt.Shape, []uint64{1, 4}) { + t.Fatal("expected shape [1, 4], got", tt.Shape) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if !slices.Equal(f32s, []float32{8, 9, 10, 11}) { + t.Fatal("expected data [8, 9, 10, 11], got", f32s) + } + } + }) + + t.Run("split with transpose", func(t *testing.T) { + next, stop := iter.Pull(splitDim(&r, 1, + split{Replacer: strings.NewReplacer("a", "x")}, + split{Replacer: strings.NewReplacer("b", "y"), fn: func(tt tensor.Tensor) (tensor.Tensor, error) { + return tensor.Transpose(tt, 1, 0) + }}, + )) + defer stop() + + { + tt, ok := next() + if !ok { + t.Fatal("expected at 
least one split") + } + + if tt.Name != "x.b" { + t.Fatal("expected name 'x.b', got", tt.Name) + } + + if !slices.Equal(tt.Shape, []uint64{3, 2}) { + t.Fatal("expected shape [3, 2], got", tt.Shape) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if !slices.Equal(f32s, []float32{0, 1, 4, 5, 8, 9}) { + t.Fatal("expected data [0, 1, 4, 5, 8, 9], got", f32s) + } + } + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "a.y" { + t.Fatal("expected name 'a.y', got", tt.Name) + } + + if !slices.Equal(tt.Shape, []uint64{3, 2}) { + t.Fatal("expected shape [3, 2], got", tt.Shape) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if !slices.Equal(f32s, []float32{2, 6, 10, 3, 7, 11}) { + t.Fatal("expected data [2, 6, 10, 3, 7, 11], got", f32s) + } + } + }) +}