fix mllama conversion (#10716)

Cross-attention Q and K projections need to have their heads swapped, similar to the non-cross-attention Q and K tensors.
Author: Michael Yang, 2025-05-15 12:15:01 -07:00 (committed by GitHub)
Parent: bd68d3ae50
Commit: 55760195e6


@@ -139,7 +139,8 @@ func (p *llamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
 	}
 
 	for _, t := range ts {
-		if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") {
+		if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") ||
+			strings.HasSuffix(t.Name(), "attn_q_proj.weight") || strings.HasSuffix(t.Name(), "attn_k_proj.weight") {
 			if !p.skipRepack {
 				t.SetRepacker(p.repack)
 			}
@@ -181,9 +182,9 @@ func (p *llamaModel) repack(name string, data []float32, shape []uint64) ([]floa
 	}
 
 	var heads uint32
-	if strings.HasSuffix(name, "attn_q.weight") {
+	if strings.HasSuffix(name, "attn_q.weight") || strings.HasSuffix(name, "attn_q_proj.weight") {
 		heads = p.NumAttentionHeads
-	} else if strings.HasSuffix(name, "attn_k.weight") {
+	} else if strings.HasSuffix(name, "attn_k.weight") || strings.HasSuffix(name, "attn_k_proj.weight") {
 		heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
 	} else {
 		return nil, fmt.Errorf("unknown tensor for repack: %s", name)
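
For context, the head swap that repack performs on these Q/K weights is the usual HF-to-GGUF rotary permutation: view each rows x cols weight as (heads, 2, headDim/2, cols) and swap the two middle axes, so the two rotary halves of each head end up interleaved. Below is a minimal plain-Go sketch of that row permutation; permuteQK, its explicit loops, and the shapes it checks are illustrative assumptions, not the converter's actual implementation.

package main

import "fmt"

// permuteQK reorders the rows of a row-major rows x cols Q or K projection so
// that, within each attention head, the two rotary halves are interleaved.
// It mirrors a reshape to (heads, 2, headDim/2, cols), a swap of the two
// middle axes, and a flatten back to rows x cols.
// Hypothetical helper for illustration only.
func permuteQK(data []float32, rows, cols, nHeads int) ([]float32, error) {
	if len(data) != rows*cols || rows%nHeads != 0 || (rows/nHeads)%2 != 0 {
		return nil, fmt.Errorf("incompatible shape: rows=%d cols=%d heads=%d", rows, cols, nHeads)
	}
	headDim := rows / nHeads
	half := headDim / 2

	out := make([]float32, len(data))
	for h := 0; h < nHeads; h++ {
		for i := 0; i < half; i++ {
			for j := 0; j < 2; j++ {
				src := (h*2+j)*half + i // row index in the (heads, 2, half) view
				dst := (h*half+i)*2 + j // row index in the (heads, half, 2) view
				copy(out[dst*cols:(dst+1)*cols], data[src*cols:(src+1)*cols])
			}
		}
	}
	return out, nil
}

func main() {
	// One head, headDim=4, cols=1: rows [0 1 2 3] become [0 2 1 3].
	out, err := permuteQK([]float32{0, 1, 2, 3}, 4, 1, 1)
	if err != nil {
		panic(err)
	}
	fmt.Println(out) // [0 2 1 3]
}

With this change, the same repack path is selected for the cross-attention attn_q_proj.weight and attn_k_proj.weight tensors, using NumAttentionHeads for Q and NumKeyValueHeads (falling back to NumAttentionHeads) for K, as the extended suffix checks in the diff show.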