convert: convert bf16 vision weights to fp16 (#12324)

This change moves back to converting bf16 vision weights to fp16,
specifically if they start with the name "v." (such as v.blk.0.attn_k.weight).

This fixes a bug where converted images are failing because they are trying
to call `im2col` which doesn't have a bf16 kernel in ggml.
This commit is contained in:
Patrick Devine 2025-09-17 17:43:17 -07:00 committed by GitHub
parent 9b8187b487
commit 2717dce6fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 63 additions and 1 deletions

View File

@ -96,7 +96,7 @@ type safetensor struct {
func (st safetensor) Kind() uint32 {
kind := st.tensorBase.Kind()
if st.dtype == "BF16" && kind != tensorKindFP32 {
if !strings.HasPrefix(st.name, "v.") && st.dtype == "BF16" && kind != tensorKindFP32 {
kind = tensorKindBF16
}

View File

@ -230,3 +230,65 @@ func TestSafetensors(t *testing.T) {
})
}
}
func TestSafetensorKind(t *testing.T) {
tests := []struct {
name string
st safetensor
expected uint32
}{
{
name: "BF16 dtype with non-v. prefix and non-FP32 base kind should return BF16",
st: safetensor{
tensorBase: &tensorBase{
name: "weight.matrix",
shape: []uint64{10, 10}, // will default to FP16
},
dtype: "BF16",
},
expected: tensorKindBF16,
},
{
name: "BF16 dtype with v. prefix should return base kind",
st: safetensor{
tensorBase: &tensorBase{
name: "v.weight.matrix",
shape: []uint64{10, 10}, // will default to FP16
},
dtype: "BF16",
},
expected: tensorKindFP16,
},
{
name: "BF16 dtype with FP32 base kind should return FP32",
st: safetensor{
tensorBase: &tensorBase{
name: "weight.matrix",
shape: []uint64{10}, // will default to FP32
},
dtype: "BF16",
},
expected: tensorKindFP32,
},
{
name: "Non-BF16 dtype should return base kind",
st: safetensor{
tensorBase: &tensorBase{
name: "weight.matrix",
shape: []uint64{10, 10}, // will default to FP16
},
dtype: "FP16",
},
expected: tensorKindFP16,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.st.Kind()
if result != tt.expected {
t.Errorf("Kind() = %d, expected %d", result, tt.expected)
}
})
}
}