fix(qwen3): deepseek distill

DeepSeek's Qwen3 distill uses a different RoPE scheme (YaRN scaling), so support both.
Michael Yang, 2025-10-13 12:09:53 -07:00 (committed by Michael Yang)
parent 6544e14735
commit 6c833d5f8d


@@ -15,11 +15,17 @@ import (
 )
 type Options struct {
-	hiddenSize, numHeads, numKVHeads int
-	eps float32
-	ropeBase, ropeScale float32
-	keyLength, valueLength int
+	hiddenSize,
+	numHeads,
+	numKVHeads,
+	keyLength,
+	valueLength int
+	eps,
+	ropeBase,
+	ropeScale float32
+	ropeType string
+	originalContextLength int
 	numExperts, numExpertsUsed int
 	normTopKProb bool
@@ -29,6 +35,19 @@ func (o Options) headDim() int {
 	return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
 }
 
+func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
+	opts := []func(*rope.Options){rope.WithTypeNeoX()}
+	if o.ropeType == "yarn" {
+		attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
+		opts = append(opts,
+			rope.WithOriginalContextLength(o.originalContextLength),
+			rope.WithExtrapolationFactor(1.),
+			rope.WithAttentionFactor(attnFactor),
+		)
+	}
+	return fast.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1./o.ropeScale, opts...)
+}
+
 type Attention struct {
 	Query *nn.Linear `gguf:"attn_q"`
 	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
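For intuition: when rope.scaling.type is "yarn", the branch above scales attention by 1/(1 + 0.1*ln(s)), where s is rope.scaling.factor, so at s = 1 the factor is exactly 1. Below is a minimal standalone sketch of just that term, separate from the commit; the scale values are illustrative, not taken from any model.

package main

import (
	"fmt"
	"math"
)

// yarnAttentionFactor mirrors the attnFactor expression in
// applyRotaryPositionEmbeddings: 1 / (1 + 0.1*ln(ropeScale)).
func yarnAttentionFactor(ropeScale float64) float64 {
	return 1.0 / (1.0 + 0.1*math.Log(ropeScale))
}

func main() {
	// Illustrative rope.scaling.factor values; s = 1 leaves attention unscaled.
	for _, s := range []float64{1, 2, 4} {
		fmt.Printf("scale=%g -> attention factor=%.4f\n", s, yarnAttentionFactor(s))
	}
}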
@@ -52,8 +71,8 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
 	query = sa.QueryNorm.Forward(ctx, query, opts.eps)
 	key = sa.KeyNorm.Forward(ctx, key, opts.eps)
 
-	query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
-	key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
+	query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
+	key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
 
 	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
 	attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
@@ -183,7 +202,7 @@ func (m *Model) forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
+	return m.Options.applyRotaryPositionEmbeddings(ctx, key, shift), nil
 }
 
 var _ model.Model = (*Model)(nil)
@@ -216,17 +235,19 @@ func New(c fs.Config) (model.Model, error) {
 		),
 		Layers: layers,
 		Options: &Options{
-			hiddenSize: int(c.Uint("embedding_length")),
-			numHeads: int(c.Uint("attention.head_count")),
-			numKVHeads: int(c.Uint("attention.head_count_kv")),
-			keyLength: int(c.Uint("attention.key_length")),
-			valueLength: int(c.Uint("attention.value_length")),
-			eps: c.Float("attention.layer_norm_rms_epsilon"),
-			ropeBase: c.Float("rope.freq_base"),
-			ropeScale: c.Float("rope.scaling.factor", 1),
-			numExperts: int(c.Uint("expert_count")),
-			numExpertsUsed: int(c.Uint("expert_used_count")),
-			normTopKProb: c.Bool("norm_top_k_prob", true),
+			hiddenSize: int(c.Uint("embedding_length")),
+			numHeads: int(c.Uint("attention.head_count")),
+			numKVHeads: int(c.Uint("attention.head_count_kv")),
+			keyLength: int(c.Uint("attention.key_length")),
+			valueLength: int(c.Uint("attention.value_length")),
+			eps: c.Float("attention.layer_norm_rms_epsilon"),
+			ropeType: c.String("rope.scaling.type"),
+			ropeBase: c.Float("rope.freq_base"),
+			ropeScale: c.Float("rope.scaling.factor", 1),
+			originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
+			numExperts: int(c.Uint("expert_count")),
+			numExpertsUsed: int(c.Uint("expert_used_count")),
+			normTopKProb: c.Bool("norm_top_k_prob", true),
 		},
 	}
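For reference, a checkpoint that declares YaRN scaling would be expected to carry metadata corresponding to the lookups above. This is an assumed, illustrative example rather than values from any specific model:

rope.freq_base = 1000000.0                        (illustrative)
rope.scaling.type = "yarn"
rope.scaling.factor = 4.0                         (illustrative)
rope.scaling.original_context_length = 32768      (illustrative)

Since rope.scaling.factor defaults to 1 and rope.scaling.type is typically absent for standard Qwen3 checkpoints, those models keep taking the plain NeoX path with an unscaled frequency; only checkpoints that declare yarn scaling pick up the extra RoPE options.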