From ec9eb28f4c3481d58c6da38ee488cb8cd5379256 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 27 Oct 2025 19:54:08 -0700 Subject: [PATCH] gemma3: make embedding non-causal (#12297) --- model/models/gemma3/embed.go | 6 ------ model/models/gemma3/model_text.go | 20 +++++++++++--------- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/model/models/gemma3/embed.go b/model/models/gemma3/embed.go index 52554776..9251111c 100644 --- a/model/models/gemma3/embed.go +++ b/model/models/gemma3/embed.go @@ -2,7 +2,6 @@ package gemma3 import ( "github.com/ollama/ollama/fs" - "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn/pooling" @@ -53,10 +52,5 @@ func newEmbedModel(c fs.Config) (model.Model, error) { poolingType: pooling.Type(c.Uint("pooling_type", 0)), } - m.Cache = kvcache.NewWrapperCache( - kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift), - kvcache.NewCausalCache(m.Shift), - ) - return m, nil } diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index 631baecc..d5bdd410 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -182,16 +182,18 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac for i, layer := range m.Layers { // gemma alternates between the sliding window (local) and causal (global) // kv cache every 6 layers - cacheType := cacheTypeSWA - if (i+1)%gemmaGlobalCacheCount == 0 { - cacheType = cacheTypeCausal - } - cache.SetLayer(i) - wc := cache.(*kvcache.WrapperCache) - wc.SetLayerType(cacheType) + if cache != nil { + cacheType := cacheTypeSWA + if (i+1)%gemmaGlobalCacheCount == 0 { + cacheType = cacheTypeCausal + } + cache.SetLayer(i) + wc := cache.(*kvcache.WrapperCache) + wc.SetLayerType(cacheType) - if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok { - causal.SetCausal(ctx, kvcache.CausalOptions{Except: except}) + if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok { + causal.SetCausal(ctx, kvcache.CausalOptions{Except: except}) + } } var lastLayerOutputs ml.Tensor