llm: Allow overriding flash attention setting
As we automatically enable flash attention for more models, there will likely be cases where we get it wrong. This change allows setting OLLAMA_FLASH_ATTENTION=0 to disable flash attention even for models that would otherwise enable it by default; leaving the variable unset keeps the automatic behavior.
parent 05a43e078a
commit fdb109469f
envconfig/config.go

@@ -145,8 +145,8 @@ func Remotes() []string {
 	return r
 }
 
-func Bool(k string) func() bool {
-	return func() bool {
+func BoolWithDefault(k string) func(defaultValue bool) bool {
+	return func(defaultValue bool) bool {
 		if s := Var(k); s != "" {
 			b, err := strconv.ParseBool(s)
 			if err != nil {
@@ -156,7 +156,14 @@ func Bool(k string) func() bool {
 			return b
 		}
 
-		return false
+		return defaultValue
+	}
+}
+
+func Bool(k string) func() bool {
+	withDefault := BoolWithDefault(k)
+	return func() bool {
+		return withDefault(false)
 	}
 }
 
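For context, here is a minimal standalone sketch of the resolution order the new helper implements. It is a hypothetical reimplementation, not the patched package: it uses os.Getenv where the real code uses the package's Var helper, and since the hunk above elides the parse-error branch, this sketch simply falls back to the default in that case.

package main

import (
	"fmt"
	"os"
	"strconv"
)

// boolWithDefault mirrors BoolWithDefault above: an explicitly set,
// parseable env var wins; otherwise the caller-supplied default
// (e.g. the model's own flash-attention preference) is returned.
func boolWithDefault(k string) func(defaultValue bool) bool {
	return func(defaultValue bool) bool {
		if s := os.Getenv(k); s != "" {
			if b, err := strconv.ParseBool(s); err == nil {
				return b
			}
		}
		return defaultValue
	}
}

func main() {
	fa := boolWithDefault("OLLAMA_FLASH_ATTENTION")

	os.Unsetenv("OLLAMA_FLASH_ATTENTION")
	fmt.Println(fa(true)) // true: unset, so the model default applies

	os.Setenv("OLLAMA_FLASH_ATTENTION", "0")
	fmt.Println(fa(true)) // false: an explicit 0 now overrides the model

	os.Setenv("OLLAMA_FLASH_ATTENTION", "1")
	fmt.Println(fa(false)) // true: an explicit 1 still forces it on
}

Note that the AsMap hunk below passes FlashAttention(false): no model is in scope when listing settings, so the display falls back to the old default.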
@@ -177,7 +184,7 @@ func LogLevel() slog.Level {
 
 var (
 	// FlashAttention enables the experimental flash attention feature.
-	FlashAttention = Bool("OLLAMA_FLASH_ATTENTION")
+	FlashAttention = BoolWithDefault("OLLAMA_FLASH_ATTENTION")
 	// KvCacheType is the quantization type for the K/V cache.
 	KvCacheType = String("OLLAMA_KV_CACHE_TYPE")
 	// NoHistory disables readline history.
@@ -263,7 +270,7 @@ type EnvVar struct {
 func AsMap() map[string]EnvVar {
 	ret := map[string]EnvVar{
 		"OLLAMA_DEBUG":             {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
-		"OLLAMA_FLASH_ATTENTION":   {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"},
+		"OLLAMA_FLASH_ATTENTION":   {"OLLAMA_FLASH_ATTENTION", FlashAttention(false), "Enabled flash attention"},
 		"OLLAMA_KV_CACHE_TYPE":     {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
 		"OLLAMA_GPU_OVERHEAD":      {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
 		"OLLAMA_HOST":              {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
llm/memory.go

@@ -195,7 +195,7 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
 		slog.Warn("model missing blk.0 layer size")
 	}
 
-	useFlashAttention := (envconfig.FlashAttention() || f.FlashAttention()) &&
+	useFlashAttention := envconfig.FlashAttention(f.FlashAttention()) &&
 		(discover.GpuInfoList)(gpus).FlashAttentionSupported() &&
 		f.SupportsFlashAttention()
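This call-site change is what makes the override effective. Previously the environment variable and the model preference were OR'ed together, so OLLAMA_FLASH_ATTENTION=0 could never turn flash attention off for a model that requests it. A small hypothetical comparison, where resolve stands in for the new envconfig.FlashAttention(f.FlashAttention()) call:

package main

import "fmt"

// resolve models the new behavior: an explicitly set env var wins,
// otherwise the model's own default is used.
func resolve(envSet, envVal, modelDefault bool) bool {
	if envSet {
		return envVal
	}
	return modelDefault
}

func main() {
	// The model enables flash attention by default, but the user has
	// set OLLAMA_FLASH_ATTENTION=0.
	envSet, envVal, modelDefault := true, false, true

	// Old code: envconfig.FlashAttention() || f.FlashAttention()
	oldBehavior := envVal || modelDefault
	newBehavior := resolve(envSet, envVal, modelDefault)

	fmt.Println(oldBehavior) // true:  the env var could not disable it
	fmt.Println(newBehavior) // false: the explicit setting now wins
}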
llm/server.go

@@ -196,14 +196,10 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
 		loadRequest.ProjectorPath = projectors[0]
 	}
 
+	fa := envconfig.FlashAttention(f.FlashAttention())
+
 	// This will disable flash attention unless all GPUs on the system support it, even if we end up selecting a subset
 	// that can handle it.
-	fa := envconfig.FlashAttention()
-	if f.FlashAttention() {
-		slog.Info("model wants flash attention")
-		fa = true
-	}
-
 	if fa && !gpus.FlashAttentionSupported() {
 		slog.Warn("flash attention enabled but not supported by gpu")
 		fa = false