ollama/ml/nn/pooling/pooling.go
Michael Yang c253433d68
embed: cleanup (#12299)
* cleanup

* use pooling.TypeNone

* pooling test
2025-09-16 09:48:42 -07:00

43 lines
824 B
Go

package pooling
import (
"github.com/ollama/ollama/ml"
)
// Type selects the pooling strategy that Forward applies to a tensor of
// hidden states.
type Type uint32
const (
// TypeNone is the zero value. Forward has no case for it and panics, and
// String reports it as "Unknown".
TypeNone Type = iota
// TypeMean makes Forward reduce the hidden states with Tensor.Mean after
// swapping the first two axes.
TypeMean
// TypeCLS makes Forward return a view at offset 0 sized by
// hiddenStates.Dim(0) — presumably the first token's embedding; confirm
// against the tensor layout used by callers.
TypeCLS
// TypeLast makes Forward return a view at offset
// (hiddenStates.Dim(1)-1)*hiddenStates.Stride(1) sized by
// hiddenStates.Dim(0) — presumably the last token's embedding; confirm
// against the tensor layout used by callers.
TypeLast
)
// String returns the display name of the pooling type: "Mean", "CLS" or
// "Last". Every other value, including TypeNone, yields "Unknown".
func (t Type) String() string {
	// Index-keyed table; slots without an entry (TypeNone) stay empty and
	// fall through to the "Unknown" result below.
	names := [...]string{
		TypeMean: "Mean",
		TypeCLS:  "CLS",
		TypeLast: "Last",
	}

	if t < Type(len(names)) && names[t] != "" {
		return names[t]
	}

	return "Unknown"
}
// Forward applies the pooling strategy selected by t to hiddenStates and
// returns the pooled tensor.
//
// NOTE(review): the comments below assume dimension 0 is the embedding axis
// and dimension 1 the sequence axis — confirm against callers.
//
// Forward panics with "unknown pooling type" for any value other than
// TypeMean, TypeCLS and TypeLast (this includes TypeNone).
func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
	if t == TypeMean {
		// Swap the first two axes, take the mean, then swap the axes back so
		// the result keeps the original axis order.
		swapped := hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
		averaged := swapped.Mean(ctx)
		return averaged.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	}

	if t == TypeCLS {
		// View starting at the very beginning of the tensor.
		return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
	}

	if t == TypeLast {
		// View starting at the final index along dimension 1.
		offset := (hiddenStates.Dim(1) - 1) * hiddenStates.Stride(1)
		return hiddenStates.View(ctx, offset, hiddenStates.Dim(0))
	}

	panic("unknown pooling type")
}