gemma3: make embedding non-causal (#12297)

This commit is contained in:
Michael Yang 2025-10-27 19:54:08 -07:00 committed by GitHub
parent 5d347f6d6f
commit ec9eb28f4c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 15 deletions

View File

@ -2,7 +2,6 @@ package gemma3
import (
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/pooling"
@ -53,10 +52,5 @@ func newEmbedModel(c fs.Config) (model.Model, error) {
poolingType: pooling.Type(c.Uint("pooling_type", 0)),
}
m.Cache = kvcache.NewWrapperCache(
kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
kvcache.NewCausalCache(m.Shift),
)
return m, nil
}

View File

@ -182,16 +182,18 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
for i, layer := range m.Layers {
// gemma alternates between the sliding window (local) and causal (global)
// kv cache every 6 layers
cacheType := cacheTypeSWA
if (i+1)%gemmaGlobalCacheCount == 0 {
cacheType = cacheTypeCausal
}
cache.SetLayer(i)
wc := cache.(*kvcache.WrapperCache)
wc.SetLayerType(cacheType)
if cache != nil {
cacheType := cacheTypeSWA
if (i+1)%gemmaGlobalCacheCount == 0 {
cacheType = cacheTypeCausal
}
cache.SetLayer(i)
wc := cache.(*kvcache.WrapperCache)
wc.SetLayerType(cacheType)
if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
}
}
var lastLayerOutputs ml.Tensor