diff --git a/convert/convert_llama.go b/convert/convert_llama.go index e491a9d8..43969749 100644 --- a/convert/convert_llama.go +++ b/convert/convert_llama.go @@ -139,7 +139,8 @@ func (p *llamaModel) Tensors(ts []Tensor) []*ggml.Tensor { } for _, t := range ts { - if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") { + if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") || + strings.HasSuffix(t.Name(), "attn_q_proj.weight") || strings.HasSuffix(t.Name(), "attn_k_proj.weight") { if !p.skipRepack { t.SetRepacker(p.repack) } @@ -181,9 +182,9 @@ func (p *llamaModel) repack(name string, data []float32, shape []uint64) ([]floa } var heads uint32 - if strings.HasSuffix(name, "attn_q.weight") { + if strings.HasSuffix(name, "attn_q.weight") || strings.HasSuffix(name, "attn_q_proj.weight") { heads = p.NumAttentionHeads - } else if strings.HasSuffix(name, "attn_k.weight") { + } else if strings.HasSuffix(name, "attn_k.weight") || strings.HasSuffix(name, "attn_k_proj.weight") { heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads) } else { return nil, fmt.Errorf("unknown tensor for repack: %s", name)