Mirror of https://github.com/zebrajr/ollama.git (synced 2025-12-06 00:19:51 +01:00)
kvcache: Remove special case for reservation mask
We currently short-circuit generation of the cache mask and just generate an empty tensor of the correct size. However, in some cases this also skips a cast operation, so the worst-case graph may not actually be fully worst case. We don't need the fast path for mask generation, so it's better to just use the normal code path.
parent a8d9c2648e
commit 1c093e97af
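As a rough illustration of that reasoning, here is a minimal, self-contained Go sketch. The names (graph, buildMaskFast, buildMaskNormal) are hypothetical and are not the ollama ml API, and the cast to the mask dtype is assumed from the commit message rather than visible in the hunks below: the reservation fast path allocates the mask directly in the target dtype, so the graph recorded during reservation never contains the cast op that a real forward pass records.

// Hypothetical sketch, not the ollama ml API: it only shows why a
// reservation-only fast path can under-build the "worst case" graph.
package main

import "fmt"

// graph stands in for the compute graph recorded during a forward pass.
type graph struct{ ops []string }

func (g *graph) record(op string) { g.ops = append(g.ops, op) }

// buildMaskFast mirrors the removed special case: the mask is allocated
// directly in the target dtype, so no cast op is ever recorded.
func buildMaskFast(g *graph, maskDType string) {
	g.record("alloc_empty_" + maskDType)
}

// buildMaskNormal mirrors the regular path: the mask is materialized as F32
// and, if the cache wants a different dtype, a cast op is recorded as well.
func buildMaskNormal(g *graph, maskDType string) {
	g.record("alloc_f32_mask")
	if maskDType != "F32" {
		g.record("cast_to_" + maskDType) // this op is absent from the fast path's graph
	}
}

func main() {
	reserveGraph, realGraph := &graph{}, &graph{}

	// Before this commit: reservation took the fast path, real passes did not.
	buildMaskFast(reserveGraph, "F16")
	buildMaskNormal(realGraph, "F16")

	fmt.Println("reserve:", reserveGraph.ops) // [alloc_empty_F16]
	fmt.Println("real:   ", realGraph.ops)    // [alloc_f32_mask cast_to_F16]
}

Running the sketch prints one op for the reservation graph and two for the real graph; always taking the normal path, as this commit does, makes the reserved graph match what a real forward pass builds.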
@@ -40,11 +40,6 @@ type Causal struct {
 
 	// ** current forward pass **
 
-	// curReserve indicates that this forward pass is only for
-	// memory reservation and we should not update our metadata
-	// based on it.
-	curReserve bool
-
 	// the active layer for Get and Put
 	curLayer int
 
@@ -206,13 +201,12 @@ func (c *Causal) Close() {
 }
 
 func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
-	c.curReserve = reserve
 	c.curBatchSize = len(batch.Positions)
 	c.curSequences = batch.Sequences
 	c.curPositions = batch.Positions
 	c.opts.Except = nil
 
-	if !c.curReserve {
+	if !reserve {
 		c.updateSlidingWindow()
 
 		var err error
@@ -379,10 +373,6 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
 
 	length := c.curCellRange.max - c.curCellRange.min + 1
 
-	if c.curReserve {
-		return ctx.Input().Empty(c.config.MaskDType, length, batchSize)
-	}
-
 	mask := make([]float32, batchSize*length)
 
 	for i := range c.curBatchSize {