mirror of
https://github.com/zebrajr/ollama.git
synced 2025-12-06 00:19:51 +01:00
43 lines
824 B
Go
43 lines
824 B
Go
package pooling
|
|
|
|
import (
|
|
"github.com/ollama/ollama/ml"
|
|
)
|
|
|
|
type Type uint32
|
|
|
|
const (
|
|
TypeNone Type = iota
|
|
TypeMean
|
|
TypeCLS
|
|
TypeLast
|
|
)
|
|
|
|
func (t Type) String() string {
|
|
switch t {
|
|
case TypeMean:
|
|
return "Mean"
|
|
case TypeCLS:
|
|
return "CLS"
|
|
case TypeLast:
|
|
return "Last"
|
|
default:
|
|
return "Unknown"
|
|
}
|
|
}
|
|
|
|
func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
|
|
switch t {
|
|
case TypeMean:
|
|
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
|
|
return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
|
case TypeCLS:
|
|
return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
|
|
case TypeLast:
|
|
hiddenStates = hiddenStates.View(ctx, (hiddenStates.Dim(1)-1)*hiddenStates.Stride(1), hiddenStates.Dim(0))
|
|
return hiddenStates
|
|
default:
|
|
panic("unknown pooling type")
|
|
}
|
|
}
|