kvcache: Remove special case for reservation mask

We currently short circuit generation of the cache mask and just
generate an empty tensor of the correct size. However, in some
cases, this can also skip a cast operation. This can result in the
worst case graph being not fully worst case.

We don't actually need the fast path for mask generation, so it's
better to just use the normal code path.
This commit is contained in:
Jesse Gross 2025-10-22 16:00:43 -07:00 committed by Jesse Gross
parent a8d9c2648e
commit 1c093e97af

View File

@ -40,11 +40,6 @@ type Causal struct {
// ** current forward pass **
// curReserve indicates that this forward pass is only for
// memory reservation and we should not update our metadata
// based on it.
curReserve bool
// the active layer for Get and Put
curLayer int
@ -206,13 +201,12 @@ func (c *Causal) Close() {
}
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
c.curReserve = reserve
c.curBatchSize = len(batch.Positions)
c.curSequences = batch.Sequences
c.curPositions = batch.Positions
c.opts.Except = nil
if !c.curReserve {
if !reserve {
c.updateSlidingWindow()
var err error
@ -379,10 +373,6 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
length := c.curCellRange.max - c.curCellRange.min + 1
if c.curReserve {
return ctx.Input().Empty(c.config.MaskDType, length, batchSize)
}
mask := make([]float32, batchSize*length)
for i := range c.curBatchSize {