From 1c093e97af54b3d78d54426ea5d05ef8c4e83ca0 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Wed, 22 Oct 2025 16:00:43 -0700 Subject: [PATCH] kvcache: Remove special case for reservation mask We currently short circuit generation of the cache mask and just generate an empty tensor of the correct size. However, in some cases, this can also skip a cast operation. This can result in the worst case graph being not fully worst case. We don't actually need the fast path for mask generation, so it's better to just use the normal code path. --- kvcache/causal.go | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/kvcache/causal.go b/kvcache/causal.go index 543a65a6..c7b3595e 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -40,11 +40,6 @@ type Causal struct { // ** current forward pass ** - // curReserve indicates that this forward pass is only for - // memory reservation and we should not update our metadata - // based on it. - curReserve bool - // the active layer for Get and Put curLayer int @@ -206,13 +201,12 @@ func (c *Causal) Close() { } func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { - c.curReserve = reserve c.curBatchSize = len(batch.Positions) c.curSequences = batch.Sequences c.curPositions = batch.Positions c.opts.Except = nil - if !c.curReserve { + if !reserve { c.updateSlidingWindow() var err error @@ -379,10 +373,6 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor { length := c.curCellRange.max - c.curCellRange.min + 1 - if c.curReserve { - return ctx.Input().Empty(c.config.MaskDType, length, batchSize) - } - mask := make([]float32, batchSize*length) for i := range c.curBatchSize {