diff --git a/llama/patches/0032-interleave-multi-rope.patch b/llama/patches/0032-interleave-multi-rope.patch new file mode 100644 index 00000000..eb41639e --- /dev/null +++ b/llama/patches/0032-interleave-multi-rope.patch @@ -0,0 +1,113 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Michael Yang +Date: Web, 16 Oct 2025 20:37:19 -0700 +Subject: [PATCH] interleave multi rope + +since ollama doesn't use mrope for anything else, change it to mean the +interleaved version used for qwen3vl +--- + ggml/src/ggml-cpu/ops.cpp | 7 ++----- + ggml/src/ggml-cuda/rope.cu | 12 +++--------- + ggml/src/ggml-metal/ggml-metal.metal | 10 +++------- + ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp | 12 +++--------- + 4 files changed, 11 insertions(+), 30 deletions(-) + +diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp +index 31478dd8e..4d1ed207e 100644 +--- a/ggml/src/ggml-cpu/ops.cpp ++++ b/ggml/src/ggml-cpu/ops.cpp +@@ -5509,15 +5509,12 @@ static void ggml_mrope_cache_init( + } + + float theta = theta_t; +- if (sector >= sections[0] && sector < sec_w) { ++ if (sector % 3 == 1 && sector < 1 + 3 * sections[1]) { + theta = theta_h; + } +- else if (sector >= sec_w && sector < sec_w + sections[2]) { ++ else if (sector % 3 == 2 && sector < 2 + 3 * sections[2]) { + theta = theta_w; + } +- else if (sector >= sec_w + sections[2]) { +- theta = theta_e; +- } + + rope_yarn( + theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1] +diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu +index d058504cd..287fe9d2c 100644 +--- a/ggml/src/ggml-cuda/rope.cu ++++ b/ggml/src/ggml-cuda/rope.cu +@@ -151,19 +151,13 @@ static __global__ void rope_multi( + const int sec_w = sections.v[1] + sections.v[0]; + const int sector = (i0 / 2) % sect_dims; + +- float theta_base = 0.0; +- if (sector < sections.v[0]) { +- theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); +- } +- else if (sector >= sections.v[0] && sector < sec_w) { ++ float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); ++ if (sector % 3 == 1 && sector < 1 + 3 * sections.v[1]) { + theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); + } +- else if (sector >= sec_w && sector < sec_w + sections.v[2]) { ++ else if (sector % 3 == 2 && sector < 2 + 3 * sections.v[2]) { + theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); + } +- else if (sector >= sec_w + sections.v[2]) { +- theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f); +- } + + const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; + +diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal +index 375a0c7fd..9866c96b4 100644 +--- a/ggml/src/ggml-metal/ggml-metal.metal ++++ b/ggml/src/ggml-metal/ggml-metal.metal +@@ -3858,15 +3858,11 @@ kernel void kernel_rope_multi( + const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2 + const int sector = ic % sect_dims; + +- float theta_base; +- if (sector < args.sect_0) { +- theta_base = (float) pos[i2]; +- } else if (sector < sec_w01) { ++ float theta_base = (float) pos[i2]; ++ if (sector % 3 == 1 && sector < 1 + 3 * args.sect_1) { + theta_base = (float) pos[i2 + args.ne02]; +- } else if (sector < sec_w012) { ++ } else if (sector % 3 == 2 && sector < 2 + 3 * args.sect_2) { + theta_base = (float) pos[i2 + args.ne02 * 2]; +- } else { +- theta_base = (float) pos[i2 + args.ne02 * 3]; + } + // end of mrope + +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +index 111286b49..6fc2b42f8 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +@@ -31,19 +31,13 @@ void main() { + const int sec_w = p.sections[1] + p.sections[0]; + const uint sector = (i0 / 2) % sect_dims; + +- float theta_base = 0.0; +- if (sector < p.sections[0]) { +- theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f); +- } +- else if (sector >= p.sections[0] && sector < sec_w) { ++ float theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f); ++ if (sector % 3 == 1 && sector < 1 + 3 * p.sections[1]) { + theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f); + } +- else if (sector >= sec_w && sector < sec_w + p.sections[2]) { ++ else if (sector % 3 == 2 && sector < 2 + 3 * p.sections[2]) { + theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f); + } +- else if (sector >= sec_w + p.sections[2]) { +- theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f); +- } + + const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; + diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 38b18b3e..8c782d73 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -11,6 +11,7 @@ package ggml import "C" import ( + "cmp" "context" "encoding/binary" "errors" @@ -1490,14 +1491,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase, ropeScale float32, options ...func(*rope.Options)) ml.Tensor { // Default options - opts := rope.Options{ - Factors: &Tensor{}, - OriginalContextLength: 131072, - ExtrapolationFactor: 0., - AttentionFactor: 1., - BetaFast: 32., - BetaSlow: 1., - } + opts := rope.Options{Factors: &Tensor{}} // Apply any provided options for _, option := range options { @@ -1509,24 +1503,44 @@ func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32) } - return &Tensor{ - b: t.b, - t: C.ggml_rope_ext( + var tt *C.struct_ggml_tensor + if len(opts.MRoPE.Sections) > 0 { + mropeSections := make([]C.int32_t, 4) + for i, section := range opts.MRoPE.Sections { + mropeSections[i] = C.int32_t(section) + } + + tt = C.ggml_rope_multi( ctx.(*Context).ctx, dequant, positions.(*Tensor).t, opts.Factors.(*Tensor).t, C.int(ropeDim), + unsafe.SliceData(mropeSections), C.int(opts.Type), - C.int(opts.OriginalContextLength), - C.float(ropeBase), - C.float(ropeScale), - C.float(opts.ExtrapolationFactor), - C.float(opts.AttentionFactor), - C.float(opts.BetaFast), - C.float(opts.BetaSlow), - ), + cmp.Or(C.int(opts.YaRN.OriginalContextLength), 128<<10), + C.float(ropeBase), C.float(ropeScale), + C.float(opts.YaRN.ExtrapolationFactor), + cmp.Or(C.float(opts.YaRN.AttentionFactor), 1), + cmp.Or(C.float(opts.YaRN.BetaFast), 32), + cmp.Or(C.float(opts.YaRN.BetaSlow), 1), + ) + } else { + tt = C.ggml_rope_ext( + ctx.(*Context).ctx, + dequant, + positions.(*Tensor).t, + opts.Factors.(*Tensor).t, + C.int(ropeDim), C.int(opts.Type), + cmp.Or(C.int(opts.YaRN.OriginalContextLength), 128<<10), + C.float(ropeBase), C.float(ropeScale), + C.float(opts.YaRN.ExtrapolationFactor), + cmp.Or(C.float(opts.YaRN.AttentionFactor), 1), + cmp.Or(C.float(opts.YaRN.BetaFast), 32), + cmp.Or(C.float(opts.YaRN.BetaSlow), 1), + ) } + return &Tensor{b: t.b, t: tt} } func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor { diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp index 31478dd8..4d1ed207 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp @@ -5509,15 +5509,12 @@ static void ggml_mrope_cache_init( } float theta = theta_t; - if (sector >= sections[0] && sector < sec_w) { + if (sector % 3 == 1 && sector < 1 + 3 * sections[1]) { theta = theta_h; } - else if (sector >= sec_w && sector < sec_w + sections[2]) { + else if (sector % 3 == 2 && sector < 2 + 3 * sections[2]) { theta = theta_w; } - else if (sector >= sec_w + sections[2]) { - theta = theta_e; - } rope_yarn( theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1] diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/rope.cu b/ml/backend/ggml/ggml/src/ggml-cuda/rope.cu index d058504c..287fe9d2 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/rope.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/rope.cu @@ -151,19 +151,13 @@ static __global__ void rope_multi( const int sec_w = sections.v[1] + sections.v[0]; const int sector = (i0 / 2) % sect_dims; - float theta_base = 0.0; - if (sector < sections.v[0]) { - theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); - } - else if (sector >= sections.v[0] && sector < sec_w) { + float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + if (sector % 3 == 1 && sector < 1 + 3 * sections.v[1]) { theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); } - else if (sector >= sec_w && sector < sec_w + sections.v[2]) { + else if (sector % 3 == 2 && sector < 2 + 3 * sections.v[2]) { theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); } - else if (sector >= sec_w + sections.v[2]) { - theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f); - } const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal index 9c0e0c56..f342872d 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal @@ -6523,15 +6523,11 @@ kernel void kernel_rope_multi( const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2 const int sector = ic % sect_dims; - float theta_base; - if (sector < args.sect_0) { - theta_base = (float) pos[i2]; - } else if (sector < sec_w01) { + float theta_base = (float) pos[i2]; + if (sector % 3 == 1 && sector < 1 + 3 * args.sect_1) { theta_base = (float) pos[i2 + args.ne02]; - } else if (sector < sec_w012) { + } else if (sector % 3 == 2 && sector < 2 + 3 * args.sect_2) { theta_base = (float) pos[i2 + args.ne02 * 2]; - } else { - theta_base = (float) pos[i2 + args.ne02 * 3]; } // end of mrope diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal index 375a0c7f..9866c96b 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal @@ -3858,15 +3858,11 @@ kernel void kernel_rope_multi( const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2 const int sector = ic % sect_dims; - float theta_base; - if (sector < args.sect_0) { - theta_base = (float) pos[i2]; - } else if (sector < sec_w01) { + float theta_base = (float) pos[i2]; + if (sector % 3 == 1 && sector < 1 + 3 * args.sect_1) { theta_base = (float) pos[i2 + args.ne02]; - } else if (sector < sec_w012) { + } else if (sector % 3 == 2 && sector < 2 + 3 * args.sect_2) { theta_base = (float) pos[i2 + args.ne02 * 2]; - } else { - theta_base = (float) pos[i2 + args.ne02 * 3]; } // end of mrope diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp index 111286b4..633dc20f 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp @@ -31,19 +31,13 @@ void main() { const int sec_w = p.sections[1] + p.sections[0]; const uint sector = (i0 / 2) % sect_dims; - float theta_base = 0.0; - if (sector < p.sections[0]) { - theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f); - } - else if (sector >= p.sections[0] && sector < sec_w) { + float theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f); + if (sector % 3 == 1 && sector < 1 + 3 * p.sections[1]) { theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f); } - else if (sector >= sec_w && sector < sec_w + p.sections[2]) { + else if (sector % 3 == 2 && sector < 2 + 3 * p.sections[2]) { theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f); } - else if (sector >= sec_w + p.sections[2]) { - theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f); - } const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; diff --git a/ml/nn/rope/rope.go b/ml/nn/rope/rope.go index 57dd2252..bca8058d 100644 --- a/ml/nn/rope/rope.go +++ b/ml/nn/rope/rope.go @@ -4,21 +4,21 @@ import "github.com/ollama/ollama/ml" // Options contains optional parameters for RoPE function type Options struct { - Type int - Factors ml.Tensor - OriginalContextLength int + Type int + Factors ml.Tensor // YaRN options - ExtrapolationFactor, - AttentionFactor, - BetaFast, - BetaSlow float32 -} + YaRN struct { + OriginalContextLength int + ExtrapolationFactor, + AttentionFactor, + BetaFast, + BetaSlow float32 + } -// WithOriginalContextLength sets a custom context length -func WithOriginalContextLength(n int) func(*Options) { - return func(opts *Options) { - opts.OriginalContextLength = n + // MRoPE options + MRoPE struct { + Sections []int } } @@ -38,14 +38,28 @@ func WithFactors(factors ml.Tensor) func(*Options) { } } +// WithOriginalContextLength sets a custom context length +func WithOriginalContextLength(n int) func(*Options) { + return func(opts *Options) { + opts.YaRN.OriginalContextLength = n + } +} + func WithExtrapolationFactor(extrapolationFactor float32) func(*Options) { return func(opts *Options) { - opts.ExtrapolationFactor = extrapolationFactor + opts.YaRN.ExtrapolationFactor = extrapolationFactor } } func WithAttentionFactor(attentionFactor float32) func(*Options) { return func(opts *Options) { - opts.AttentionFactor = attentionFactor + opts.YaRN.AttentionFactor = attentionFactor + } +} + +func WithMRoPESections(sections []int) func(*Options) { + return func(opts *Options) { + opts.Type |= 1 << 3 + opts.MRoPE.Sections = sections } } diff --git a/model/models/qwen3vl/model.go b/model/models/qwen3vl/model.go index 08beb37c..579863ae 100644 --- a/model/models/qwen3vl/model.go +++ b/model/models/qwen3vl/model.go @@ -112,7 +112,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positionSlice := slices.Collect(makeSlice2D[int32](3, len(batch.Positions))) + // ggml mrope requires 4 positions per token: [time, height, width, extra] + positionSlice := slices.Collect(makeSlice2D[int32](4, len(batch.Positions))) for i, id := range batch.Positions { if id < int32(len(m.positionCache)) { id = m.positionCache[id] @@ -123,6 +124,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { positionSlice[0][i] = id positionSlice[1][i] = id positionSlice[2][i] = id + // positionSlice[3] is intentionally left as zeros } hiddenStates := m.TextModel.TokenEmbedding.Forward(ctx, batch.Inputs).Duplicate(ctx) @@ -147,8 +149,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { } } - positions := ctx.Input().FromInts(slices.Concat(positionSlice...), len(positionSlice[0]), len(positionSlice)) - cos, sin := m.rotaryEmbedding(ctx, positions) + positions := ctx.Input().FromInts(slices.Concat(positionSlice...), len(positionSlice[0])*len(positionSlice)) for i, layer := range m.TextModel.Layers { if m.Cache != nil { m.Cache.SetLayer(i) @@ -159,7 +160,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { outputs = batch.Outputs } - hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, outputs, m.Cache, m.Options) + hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options) if i < len(deepstackVisualEmbeds) { hiddenStates = hiddenStates.Add(ctx, deepstackVisualEmbeds[i]) } @@ -191,9 +192,10 @@ func New(c fs.Config) (model.Model, error) { ImageProcessor: newImageProcessor(c), } - m.Cache = kvcache.NewCausalCache(func(ctx ml.Context, layer int, key, position ml.Tensor) (ml.Tensor, error) { + m.Cache = kvcache.NewCausalCache(func(ctx ml.Context, layer int, key, positions ml.Tensor) (ml.Tensor, error) { m.positionCache = nil - return nil, kvcache.ErrNotSupported + positions = positions.Repeat(ctx, 1, 4).Reshape(ctx, -1) + return m.Options.applyRotaryPositionalEmbedding(ctx, key, positions), nil }) return &m, nil } diff --git a/model/models/qwen3vl/model_text.go b/model/models/qwen3vl/model_text.go index 14e7d7dc..f5767f65 100644 --- a/model/models/qwen3vl/model_text.go +++ b/model/models/qwen3vl/model_text.go @@ -10,6 +10,8 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/fast" + "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" ) @@ -27,14 +29,18 @@ type TextOptions struct { numExperts, numExpertsUsed int normTopKProb bool - - inverseFrequenciesCache []float32 } func (o TextOptions) headDim() int { return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads) } +func (o TextOptions) applyRotaryPositionalEmbedding(ctx ml.Context, t, p ml.Tensor) ml.Tensor { + return fast.RoPE(ctx, t, p, o.headDim(), o.ropeBase, 1/float32(math.Sqrt(float64(o.ropeScale))), + rope.WithMRoPESections(o.mropeSections), + ) +} + type TextAttention struct { Query *nn.Linear `gguf:"attn_q"` QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"` @@ -44,7 +50,7 @@ type TextAttention struct { Output *nn.Linear `gguf:"attn_output"` } -func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { +func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { batchSize := hiddenStates.Dim(1) query := sa.Query.Forward(ctx, hiddenStates) @@ -58,8 +64,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tenso query = sa.QueryNorm.Forward(ctx, query, opts.eps) key = sa.KeyNorm.Forward(ctx, key, opts.eps) - query = applyRotaryPositionalEmbedding(ctx, query, cos, sin) - key = applyRotaryPositionalEmbedding(ctx, key, cos, sin) + query = opts.applyRotaryPositionalEmbedding(ctx, query, positions) + key = opts.applyRotaryPositionalEmbedding(ctx, key, positions) attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache) attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize) @@ -125,10 +131,10 @@ type TextLayer struct { TextMLP } -func (d *TextLayer) Forward(ctx ml.Context, hiddenStates, cos, sin, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { +func (d *TextLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { residual := hiddenStates hiddenStates = d.AttentionNorm.Forward(ctx, hiddenStates, opts.eps) - hiddenStates = d.TextAttention.Forward(ctx, hiddenStates, cos, sin, cache, opts) + hiddenStates = d.TextAttention.Forward(ctx, hiddenStates, positions, cache, opts) if outputs != nil { hiddenStates = hiddenStates.Rows(ctx, outputs) @@ -153,42 +159,6 @@ type TextModel struct { Options *TextOptions } -func (m *TextModel) rotaryEmbedding(ctx ml.Context, positions ml.Tensor) (_, _ ml.Tensor) { - positions = positions.Reshape(ctx, 1, positions.Dim(0), positions.Dim(1)) - if len(m.Options.inverseFrequenciesCache) == 0 { - m.Options.inverseFrequenciesCache = make([]float32, m.Options.headDim()/2) - for i := range m.Options.inverseFrequenciesCache { - frequency := float32(math.Pow(float64(m.Options.ropeBase), float64(i*2)/float64(m.Options.headDim()))) - m.Options.inverseFrequenciesCache[i] = 1 / frequency - } - } - - inverseFrequencies := ctx.Input().FromFloats(m.Options.inverseFrequenciesCache, 1, len(m.Options.inverseFrequenciesCache)) - - positions = positions.Cast(ctx, ml.DTypeF32) - frequencies := inverseFrequencies.Mulmat(ctx, positions) - - interleaved := frequencies.View(ctx, - 0, frequencies.Dim(0), - frequencies.Stride(1), frequencies.Dim(1), - ) - - for _, i := range []int{1, 2} { - args := []int{ - i * frequencies.Stride(0), 1, - 3 * frequencies.Stride(0), m.Options.mropeSections[i], - frequencies.Stride(1), frequencies.Dim(1), - } - - ctx.Forward(frequencies.View(ctx, i*frequencies.Stride(2)+args[0], args[1:]...). - Copy(ctx, interleaved.View(ctx, args[0], args[1:]...))) - } - - interleaved = interleaved.Concat(ctx, interleaved, 0) - interleaved = interleaved.Reshape(ctx, interleaved.Dim(0), 1, interleaved.Dim(1), interleaved.Dim(2)) - return interleaved.Cos(ctx), interleaved.Sin(ctx) -} - var _ model.Model = (*Model)(nil) func newTextModel(c fs.Config) *TextModel {