diff --git a/server/model.go b/server/model.go index 6b5439a4..401547e4 100644 --- a/server/model.go +++ b/server/model.go @@ -10,9 +10,6 @@ import ( "log/slog" "net/http" "os" - "slices" - "strings" - "text/template/parse" "github.com/ollama/ollama/api" "github.com/ollama/ollama/fs/ggml" @@ -128,124 +125,3 @@ func detectContentType(r io.Reader) (string, error) { return "unknown", nil } - -func parseObjects(s string) []map[string]any { - var objs []map[string]any - for offset := 0; offset < len(s); { - var obj map[string]any - decoder := json.NewDecoder(strings.NewReader(s[offset:])) - if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { - break - } else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) { - // skip over any syntax errors - offset += int(syntax.Offset) - } else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) { - // skip over any unmarshalable types - offset += int(unmarshalType.Offset) - } else if err != nil { - return nil - } else { - offset += int(decoder.InputOffset()) - objs = append(objs, obj) - } - } - - return objs -} - -// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls. -// mxyng: this only really works if the input contains tool calls in some JSON format -func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { - // create a subtree from the node that ranges over .ToolCalls - tmpl := m.Template.Subtree(func(n parse.Node) bool { - if t, ok := n.(*parse.RangeNode); ok { - return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls") - } - - return false - }) - - if tmpl == nil { - return nil, false - } - - var b bytes.Buffer - if err := tmpl.Execute(&b, map[string][]api.ToolCall{ - "ToolCalls": { - { - Function: api.ToolCallFunction{ - Name: "@@name@@", - Arguments: api.ToolCallFunctionArguments{ - "@@argument@@": 1, - }, - }, - }, - }, - }); err != nil { - return nil, false - } - - templateObjects := parseObjects(b.String()) - if len(templateObjects) == 0 { - return nil, false - } - - // find the keys that correspond to the name and arguments fields - var name, arguments string - for k, v := range templateObjects[0] { - switch v.(type) { - case string: - name = k - case map[string]any: - arguments = k - } - } - - if name == "" || arguments == "" { - return nil, false - } - - responseObjects := parseObjects(s) - if len(responseObjects) == 0 { - return nil, false - } - - // collect all nested objects - var collect func(any) []map[string]any - collect = func(obj any) (all []map[string]any) { - switch o := obj.(type) { - case map[string]any: - all = append(all, o) - for _, v := range o { - all = append(all, collect(v)...) - } - case []any: - for _, v := range o { - all = append(all, collect(v)...) - } - } - - return all - } - - var objs []map[string]any - for _, p := range responseObjects { - objs = append(objs, collect(p)...) - } - - var toolCalls []api.ToolCall - for _, kv := range objs { - n, nok := kv[name].(string) - a, aok := kv[arguments].(map[string]any) - if nok && aok { - toolCalls = append(toolCalls, api.ToolCall{ - Function: api.ToolCallFunction{ - Name: n, - Arguments: a, - }, - }) - } - } - - return toolCalls, len(toolCalls) > 0 -} diff --git a/server/model_test.go b/server/model_test.go deleted file mode 100644 index e5c2f2bb..00000000 --- a/server/model_test.go +++ /dev/null @@ -1,179 +0,0 @@ -package server - -import ( - "bytes" - "encoding/json" - "fmt" - "os" - "path/filepath" - "testing" - - "github.com/google/go-cmp/cmp" - - "github.com/ollama/ollama/api" - "github.com/ollama/ollama/template" -) - -func readFile(t *testing.T, base, name string) *bytes.Buffer { - t.Helper() - - bts, err := os.ReadFile(filepath.Join(base, name)) - if err != nil { - t.Fatal(err) - } - - return bytes.NewBuffer(bts) -} - -func TestExecuteWithTools(t *testing.T) { - p := filepath.Join("testdata", "tools") - cases := []struct { - model string - output string - ok bool - }{ - {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, - {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] - -The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true}, - {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false}, - {"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: - - [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, - {"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, - {"command-r-plus", "Action: ```json" + ` -[ - { - "tool_name": "get_current_weather", - "parameters": { - "format": "fahrenheit", - "location": "San Francisco, CA" - } - }, - { - "tool_name": "get_current_weather", - "parameters": { - "format": "celsius", - "location": "Toronto, Canada" - } - } -] -` + "```", true}, - {"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, - {"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, - {"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, - {"llama3-groq-tool-use", ` -{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} -{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}} -`, true}, - {"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true}, - {"nemotron", `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]} `, true}, - } - - var tools []api.Tool - if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil { - t.Fatal(err) - } - - var messages []api.Message - if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil { - t.Fatal(err) - } - - calls := []api.ToolCall{ - { - Function: api.ToolCallFunction{ - Name: "get_current_weather", - Arguments: api.ToolCallFunctionArguments{ - "format": "fahrenheit", - "location": "San Francisco, CA", - }, - }, - }, - { - Function: api.ToolCallFunction{ - Name: "get_current_weather", - Arguments: api.ToolCallFunctionArguments{ - "format": "celsius", - "location": "Toronto, Canada", - }, - }, - }, - } - - for _, tt := range cases { - t.Run(tt.model, func(t *testing.T) { - tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String()) - if err != nil { - t.Fatal(err) - } - - t.Run("template", func(t *testing.T) { - var actual bytes.Buffer - if err := tmpl.Execute(&actual, template.Values{Tools: tools, Messages: messages}); err != nil { - t.Fatal(err) - } - - if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" { - t.Errorf("mismatch (-got +want):\n%s", diff) - } - }) - - t.Run("parse", func(t *testing.T) { - m := &Model{Template: tmpl} - actual, ok := m.parseToolCalls(tt.output) - if ok != tt.ok { - t.Fatalf("expected %t, got %t", tt.ok, ok) - } - - if tt.ok { - if diff := cmp.Diff(actual, calls); diff != "" { - t.Errorf("mismatch (-got +want):\n%s", diff) - } - } - }) - }) - } -} - -func TestParseObjects(t *testing.T) { - tests := []struct { - input string - want []map[string]any - }{ - { - input: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - want: []map[string]any{ - {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}}, - {"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, Canada"}}, - }, - }, - { - input: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, - want: []map[string]any{ - {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}}, - }, - }, - { - input: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, ON"}} `, - want: []map[string]any{ - {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}}, - {"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, ON"}}, - }, - }, - { - input: `{"name": "get_current_weather", "arguments": `, - want: nil, - }, - } - - for _, tc := range tests { - t.Run(tc.input, func(t *testing.T) { - got := parseObjects(tc.input) - - if diff := cmp.Diff(got, tc.want); diff != "" { - t.Errorf("mismatch (-got +want):\n%s", diff) - } - }) - } -} diff --git a/server/routes.go b/server/routes.go index d0b8f487..42e8cdd1 100644 --- a/server/routes.go +++ b/server/routes.go @@ -38,6 +38,7 @@ import ( "github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/registry" "github.com/ollama/ollama/template" + "github.com/ollama/ollama/tools" "github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" @@ -1482,11 +1483,20 @@ func (s *Server) ChatHandler(c *gin.Context) { return } + var toolParser *tools.Parser + if len(req.Tools) > 0 { + toolParser, err = tools.NewParser(m.Template.Template) + if err != nil { + slog.Error("failed to create tool parser", "error", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + ch := make(chan any) go func() { defer close(ch) - var sb strings.Builder - var toolCallIndex int = 0 + if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ Prompt: prompt, Images: images, @@ -1512,37 +1522,21 @@ func (s *Server) ChatHandler(c *gin.Context) { res.LoadDuration = checkpointLoaded.Sub(checkpointStart) } - // TODO: tool call checking and filtering should be moved outside of this callback once streaming - // however this was a simple change for now without reworking streaming logic of this (and other) - // handlers - if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 { - ch <- res - return - } - - // Streaming tool calls: - // If tools are recognized, use a flag to track the sending of a tool downstream - // This ensures that content is cleared from the message on the last chunk sent - sb.WriteString(r.Content) - if toolCalls, ok := m.parseToolCalls(sb.String()); ok { - res.Message.ToolCalls = toolCalls - for i := range toolCalls { - toolCalls[i].Function.Index = toolCallIndex - toolCallIndex++ + if len(req.Tools) > 0 { + toolCalls, content := toolParser.Add(r.Content) + if len(content) > 0 { + res.Message.Content = content + } else if len(toolCalls) > 0 { + res.Message.ToolCalls = toolCalls + res.Message.Content = "" + } else { + if r.Done { + ch <- res + } + return } - res.Message.Content = "" - sb.Reset() - ch <- res - return - } - - if r.Done { - // Send any remaining content if no tool calls were detected - if toolCallIndex == 0 { - res.Message.Content = sb.String() - } - ch <- res } + ch <- res }); err != nil { ch <- gin.H{"error": err.Error()} } @@ -1551,11 +1545,15 @@ func (s *Server) ChatHandler(c *gin.Context) { if req.Stream != nil && !*req.Stream { var resp api.ChatResponse var sb strings.Builder + var toolCalls []api.ToolCall for rr := range ch { switch t := rr.(type) { case api.ChatResponse: sb.WriteString(t.Message.Content) resp = t + if len(req.Tools) > 0 { + toolCalls = append(toolCalls, t.Message.ToolCalls...) + } case gin.H: msg, ok := t["error"].(string) if !ok { @@ -1571,12 +1569,8 @@ func (s *Server) ChatHandler(c *gin.Context) { } resp.Message.Content = sb.String() - - if len(req.Tools) > 0 { - if toolCalls, ok := m.parseToolCalls(sb.String()); ok { - resp.Message.ToolCalls = toolCalls - resp.Message.Content = "" - } + if len(toolCalls) > 0 { + resp.Message.ToolCalls = toolCalls } c.JSON(http.StatusOK, resp) diff --git a/server/testdata/tools/command-r-plus.gotmpl b/tools/testdata/command-r-plus.gotmpl similarity index 100% rename from server/testdata/tools/command-r-plus.gotmpl rename to tools/testdata/command-r-plus.gotmpl diff --git a/server/testdata/tools/command-r-plus.out b/tools/testdata/command-r-plus.out similarity index 100% rename from server/testdata/tools/command-r-plus.out rename to tools/testdata/command-r-plus.out diff --git a/server/testdata/tools/firefunction.gotmpl b/tools/testdata/firefunction.gotmpl similarity index 100% rename from server/testdata/tools/firefunction.gotmpl rename to tools/testdata/firefunction.gotmpl diff --git a/server/testdata/tools/firefunction.out b/tools/testdata/firefunction.out similarity index 100% rename from server/testdata/tools/firefunction.out rename to tools/testdata/firefunction.out diff --git a/server/testdata/tools/llama3-groq-tool-use.gotmpl b/tools/testdata/llama3-groq-tool-use.gotmpl similarity index 100% rename from server/testdata/tools/llama3-groq-tool-use.gotmpl rename to tools/testdata/llama3-groq-tool-use.gotmpl diff --git a/server/testdata/tools/llama3-groq-tool-use.out b/tools/testdata/llama3-groq-tool-use.out similarity index 100% rename from server/testdata/tools/llama3-groq-tool-use.out rename to tools/testdata/llama3-groq-tool-use.out diff --git a/tools/testdata/llama3.2.gotmpl b/tools/testdata/llama3.2.gotmpl new file mode 100644 index 00000000..b132423e --- /dev/null +++ b/tools/testdata/llama3.2.gotmpl @@ -0,0 +1,44 @@ +<|start_header_id|>system<|end_header_id|> + +Cutting Knowledge Date: December 2023 + +{{ if .System }}{{ .System }} +{{- end }} +{{- if .Tools }}When you receive a tool call response, use the output to format an answer to the orginal user question. + +You are a helpful assistant with tool calling capabilities. +{{- end }}<|eot_id|> +{{- range $i, $_ := .Messages }} +{{- $last := eq (len (slice $.Messages $i)) 1 }} +{{- if eq .Role "user" }}<|start_header_id|>user<|end_header_id|> +{{- if and $.Tools $last }} + +Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. + +Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables. + +{{ range $.Tools }} +{{- . }} +{{ end }} +{{ .Content }}<|eot_id|> +{{- else }} + +{{ .Content }}<|eot_id|> +{{- end }}{{ if $last }}<|start_header_id|>assistant<|end_header_id|> + +{{ end }} +{{- else if eq .Role "assistant" }}<|start_header_id|>assistant<|end_header_id|> +{{- if .ToolCalls }} +{{ range .ToolCalls }} +{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }} +{{- else }} + +{{ .Content }} +{{- end }}{{ if not $last }}<|eot_id|>{{ end }} +{{- else if eq .Role "tool" }}<|start_header_id|>ipython<|end_header_id|> + +{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|> + +{{ end }} +{{- end }} +{{- end }} \ No newline at end of file diff --git a/tools/testdata/llama3.2.out b/tools/testdata/llama3.2.out new file mode 100644 index 00000000..a27c6eaf --- /dev/null +++ b/tools/testdata/llama3.2.out @@ -0,0 +1,24 @@ +<|start_header_id|>system<|end_header_id|> + +Cutting Knowledge Date: December 2023 + +You are a knowledgeable assistant. You can answer questions and perform tasks.When you receive a tool call response, use the output to format an answer to the orginal user question. + +You are a helpful assistant with tool calling capabilities.<|eot_id|><|start_header_id|>user<|end_header_id|> + +What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +{"name": "get_current_weather", "parameters": {"format":"celsius","location":"Paris, France"}}<|eot_id|><|start_header_id|>ipython<|end_header_id|> + +22<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|> + +Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. + +Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables. + +{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}} + +What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|> + diff --git a/server/testdata/tools/messages.json b/tools/testdata/messages.json similarity index 100% rename from server/testdata/tools/messages.json rename to tools/testdata/messages.json diff --git a/server/testdata/tools/mistral.gotmpl b/tools/testdata/mistral.gotmpl similarity index 100% rename from server/testdata/tools/mistral.gotmpl rename to tools/testdata/mistral.gotmpl diff --git a/server/testdata/tools/mistral.out b/tools/testdata/mistral.out similarity index 100% rename from server/testdata/tools/mistral.out rename to tools/testdata/mistral.out diff --git a/server/testdata/tools/nemotron.gotmpl b/tools/testdata/nemotron.gotmpl similarity index 100% rename from server/testdata/tools/nemotron.gotmpl rename to tools/testdata/nemotron.gotmpl diff --git a/server/testdata/tools/nemotron.out b/tools/testdata/nemotron.out similarity index 100% rename from server/testdata/tools/nemotron.out rename to tools/testdata/nemotron.out diff --git a/tools/testdata/qwen2.5.gotmpl b/tools/testdata/qwen2.5.gotmpl new file mode 100644 index 00000000..cbd7302c --- /dev/null +++ b/tools/testdata/qwen2.5.gotmpl @@ -0,0 +1,51 @@ +{{- if .Suffix }}<|fim_prefix|>{{ .Prompt }}<|fim_suffix|>{{ .Suffix }}<|fim_middle|> +{{- else if .Messages }} +{{- if or .System .Tools }}<|im_start|>system +{{- if .System }} +{{ .System }} +{{- end }} +{{- if .Tools }} + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{{- range .Tools }} +{"type": "function", "function": {{ .Function }}} +{{- end }} + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } + +{{- end }}<|im_end|> +{{ end }} +{{- range $i, $_ := .Messages }} +{{- $last := eq (len (slice $.Messages $i)) 1 -}} +{{- if eq .Role "user" }}<|im_start|>user +{{ .Content }}<|im_end|> +{{ else if eq .Role "assistant" }}<|im_start|>assistant +{{ if .Content }}{{ .Content }} +{{- else if .ToolCalls }} +{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} +{{ end }} +{{- end }}{{ if not $last }}<|im_end|> +{{ end }} +{{- else if eq .Role "tool" }}<|im_start|>user + +{{ .Content }} +<|im_end|> +{{ end }} +{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant +{{ end }} +{{- end }} +{{- else }} +{{- if .System }}<|im_start|>system +{{ .System }}<|im_end|> +{{ end }}{{ if .Prompt }}<|im_start|>user +{{ .Prompt }}<|im_end|> +{{ end }}<|im_start|>assistant +{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }} \ No newline at end of file diff --git a/tools/testdata/qwen2.5.out b/tools/testdata/qwen2.5.out new file mode 100644 index 00000000..76bfbfa9 --- /dev/null +++ b/tools/testdata/qwen2.5.out @@ -0,0 +1,31 @@ +<|im_start|>system +You are a knowledgeable assistant. You can answer questions and perform tasks. + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{"type": "function", "function": {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}} + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } +<|im_end|> +<|im_start|>user +What's the weather like today in Paris?<|im_end|> +<|im_start|>assistant + +{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}} +<|im_end|> +<|im_start|>user + +22 +<|im_end|> +<|im_start|>assistant +The current temperature in Paris, France is 22 degrees Celsius.<|im_end|> +<|im_start|>user +What's the weather like today in San Francisco and Toronto?<|im_end|> +<|im_start|>assistant diff --git a/tools/testdata/qwen3.gotmpl b/tools/testdata/qwen3.gotmpl new file mode 100644 index 00000000..26f6656f --- /dev/null +++ b/tools/testdata/qwen3.gotmpl @@ -0,0 +1,50 @@ +{{- if .Messages }} +{{- if or .System .Tools }}<|im_start|>system +{{- if .System }} +{{ .System }} +{{- end }} +{{- if .Tools }} + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{{- range .Tools }} +{"type": "function", "function": {{ .Function }}} +{{- end }} + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } + +{{- end }}<|im_end|> +{{ end }} +{{- range $i, $_ := .Messages }} +{{- $last := eq (len (slice $.Messages $i)) 1 -}} +{{- if eq .Role "user" }}<|im_start|>user +{{ .Content }}<|im_end|> +{{ else if eq .Role "assistant" }}<|im_start|>assistant +{{ if .Content }}{{ .Content }} +{{- else if .ToolCalls }} +{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} +{{ end }} +{{- end }}{{ if not $last }}<|im_end|> +{{ end }} +{{- else if eq .Role "tool" }}<|im_start|>user + +{{ .Content }} +<|im_end|> +{{ end }} +{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant +{{ end }} +{{- end }} +{{- else }} +{{- if .System }}<|im_start|>system +{{ .System }}<|im_end|> +{{ end }}{{ if .Prompt }}<|im_start|>user +{{ .Prompt }}<|im_end|> +{{ end }}<|im_start|>assistant +{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }} \ No newline at end of file diff --git a/tools/testdata/qwen3.out b/tools/testdata/qwen3.out new file mode 100644 index 00000000..76bfbfa9 --- /dev/null +++ b/tools/testdata/qwen3.out @@ -0,0 +1,31 @@ +<|im_start|>system +You are a knowledgeable assistant. You can answer questions and perform tasks. + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{"type": "function", "function": {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}} + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } +<|im_end|> +<|im_start|>user +What's the weather like today in Paris?<|im_end|> +<|im_start|>assistant + +{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}} +<|im_end|> +<|im_start|>user + +22 +<|im_end|> +<|im_start|>assistant +The current temperature in Paris, France is 22 degrees Celsius.<|im_end|> +<|im_start|>user +What's the weather like today in San Francisco and Toronto?<|im_end|> +<|im_start|>assistant diff --git a/server/testdata/tools/tools.json b/tools/testdata/tools.json similarity index 100% rename from server/testdata/tools/tools.json rename to tools/testdata/tools.json diff --git a/server/testdata/tools/xlam.gotmpl b/tools/testdata/xlam.gotmpl similarity index 100% rename from server/testdata/tools/xlam.gotmpl rename to tools/testdata/xlam.gotmpl diff --git a/server/testdata/tools/xlam.out b/tools/testdata/xlam.out similarity index 100% rename from server/testdata/tools/xlam.out rename to tools/testdata/xlam.out diff --git a/tools/tools.go b/tools/tools.go new file mode 100644 index 00000000..509ca90a --- /dev/null +++ b/tools/tools.go @@ -0,0 +1,271 @@ +package tools + +import ( + "encoding/json" + "errors" + "log/slog" + "strings" + gotmpl "text/template" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/template" +) + +var ( + errInvalidToolCall = errors.New("invalid tool call format") + errAccumulateMore = errors.New("need to accumulate more content") +) + +type Parser struct { + parseLeadingJSON bool + prefix string + prefixFound bool + tmpl gotmpl.Template + sb strings.Builder + index int + name string + arguments string + done bool +} + +// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls. +// +// Parameters: +// - s: The string to parse +// - name: The field name from template that identifies the tool call name +// - arguments: The field name from template that identifies the tool call arguments +// +// Returns: +// - []api.ToolCall: The parsed tool calls if successful +// - error: ErrAccumulateMore if braces unbalanced, ErrInvalidToolCall if invalid, or nil if successful +func parseJSONToolCalls(s string, name, arguments string, prefix string) ([]api.ToolCall, error) { + // Check for balanced braces before attempting to parse + braceCount := 0 + squareCount := 0 + startIndex := -1 + var rawToolCalls []string + s = strings.TrimSpace(s) + + // Only track these if we don't have a prefix as it will be cut off from the prefix. Also track in the parseLeadingJSON case. + trackSquareBrackets := prefix == "" || !strings.HasSuffix(prefix, "[") || strings.HasPrefix(s, "[") + for i, c := range s { + switch c { + case '{': + braceCount++ + if startIndex == -1 { + startIndex = i + } + case '}': + braceCount-- + if braceCount == 0 { + rawToolCalls = append(rawToolCalls, s[startIndex:i+1]) + startIndex = -1 + } + case '[': + if trackSquareBrackets { + squareCount++ + } + case ']': + if trackSquareBrackets { + squareCount-- + } + } + + // Negative means we have an extra closing brace/bracket + if braceCount < 0 || squareCount < 0 { + return nil, errInvalidToolCall + } + } + + // If braces/brackets aren't balanced, need more input + if braceCount > 0 || squareCount > 0 { + return nil, errAccumulateMore + } + + t := strings.TrimSpace(s) + if len(t) == 0 { + return nil, errAccumulateMore + } + // If the input is a single square bracket, it's not a valid tool call + if t[0] == '[' && len(t) == 1 { + return nil, errAccumulateMore + } + + // Attempt full unmarshal of the JSON + var toolCalls []api.ToolCall + for _, rawToolCall := range rawToolCalls { + var resp map[string]any + if err := json.Unmarshal([]byte(rawToolCall), &resp); err != nil { + continue + } + + // Collect nested objects that could contain tool calls + objs := collect(resp) + if len(objs) == 0 { + continue + } + + // Extract tool calls from objects + for _, kv := range objs { + n, nok := kv[name].(string) + a, aok := kv[arguments].(map[string]any) + if nok && aok { + toolCalls = append(toolCalls, api.ToolCall{ + Function: api.ToolCallFunction{ + Name: n, + Arguments: a, + }, + }) + } else { + slog.Debug("No valid tool call found in object.", "object", kv) + } + } + } + + // Valid JSON, no tool calls found + if len(toolCalls) == 0 { + slog.Debug("No valid tool calls found in any raw tool calls.", "rawToolCalls", rawToolCalls) + return nil, errInvalidToolCall + } + + return toolCalls, nil +} + +// checkPrefix processes a string to find and handle a prefix pattern. +// +// Returns: +// - The processed string with prefix removed if found +// - error: ErrAccumulateMore if prefix is incomplete, or nil if successful +func (p *Parser) checkPrefix(s string) (string, error) { + original := s + if strings.ContainsRune(s, '\n') { + s = strings.ReplaceAll(s, "\n", " ") + } + + if s == "" || p.prefix == "" { + return s, nil + } + + // Check for prefix at start of string + if cut, hasPrefix := strings.CutPrefix(s, p.prefix); hasPrefix { + // Found prefix at start - accumulate for potential tool + p.prefixFound = true + return cut, nil + } + + // Check if prefix overlaps end of string + if idx := suffixOverlap(s, p.prefix); idx != -1 { + // Return everything except overlapping portion + p.sb.Reset() + p.sb.WriteString(s[idx:]) + return original[:idx], errAccumulateMore + } + + // Check if prefix appears in middle of string + if idx := strings.Index(s, p.prefix); idx != -1 { + // Save remainder starting at prefix for next pass + p.sb.Reset() + p.sb.WriteString(strings.TrimSpace(s[idx:])) + // Return everything before prefix + return original[:idx], errAccumulateMore + } + + // No partial prefix found + return s, nil +} + +// Add processes a string input to parse tool calls and content. +// It handles prefix detection and JSON parsing to extract tool calls. +// +// Returns: +// - tools: Any parsed tool calls +// - content: Non-tool call content +func (p *Parser) Add(s string) (tools []api.ToolCall, content string) { + if strings.TrimSpace(s) == "" { + return nil, s + } + if p.done { + if p.index == 0 { + // Return original string if no tool calls found at start + return nil, s + } + // Return empty if no tool calls found after start + return nil, "" + } + p.sb.WriteString(s) + s = p.sb.String() + + // Check for prefix pattern in input + s, err := p.checkPrefix(s) + if err != nil { + // Need more input to complete prefix + return nil, s + } + + // Exit if prefix exists in template, greedy parsing is off, and prefix not found + if !p.parseLeadingJSON && !p.prefixFound { + p.sb.Reset() + return nil, s + } + + toolCalls, err := parseJSONToolCalls(s, p.name, p.arguments, p.prefix) + if err != nil { + if errors.Is(err, errAccumulateMore) { + return nil, "" + } + p.sb.Reset() + // Do not try parsing leading JSON if JSON not found + p.parseLeadingJSON = false + if p.prefix == "" { + p.done = true + } + if p.index != 0 && p.prefix == "" { + return nil, "" + } + if p.prefixFound { + // Drop tokens since prefix was found + return nil, "" + } + return nil, s + } + + for _, tc := range toolCalls { + tc.Function.Index = p.index + p.index++ + } + + p.sb.Reset() + return toolCalls, "" +} + +// NewParser creates a new tool call parser from a template. It extracts the tool call format, +// prefix, and field names from the template to use for parsing tool calls from model output. +// +// Returns an error if the template does not contain valid tool call formatting. +func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) { + parsed, err := template.Parse(templateToProcess.Root.String()) + if err != nil { + return nil, err + } + + tt, err := toolTemplate(parsed) + if err != nil { + return nil, err + } + + tp := toolPrefix(templateToProcess) + + name, arguments, err := extractToolArgs(tt) + if err != nil { + return nil, err + } + + return &Parser{ + tmpl: *tt, + sb: strings.Builder{}, + prefix: tp, + parseLeadingJSON: true, + name: name, + arguments: arguments, + }, nil +} diff --git a/tools/tools_test.go b/tools/tools_test.go new file mode 100644 index 00000000..1ae3bff8 --- /dev/null +++ b/tools/tools_test.go @@ -0,0 +1,644 @@ +package tools + +import ( + "bytes" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/template" +) + +func readFile(t *testing.T, base, name string) *bytes.Buffer { + t.Helper() + + bts, err := os.ReadFile(filepath.Join(base, name)) + if err != nil { + t.Fatal(err) + } + + return bytes.NewBuffer(bts) +} + +func TestParseJSONToolCalls(t *testing.T) { + tests := []struct { + name string + input string + nameField string + argsField string + wantToolCalls []api.ToolCall + wantErr error + prefix string + }{ + { + name: "valid single tool call", + input: `{"name": "test_tool", "arguments": {"arg1": "value1"}}`, + nameField: "name", + argsField: "arguments", + wantToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "test_tool", + Arguments: map[string]any{ + "arg1": "value1", + }, + }, + }, + }, + wantErr: nil, + prefix: "", + }, + { + name: "incomplete JSON", + input: `{"name": "test_tool", "arguments": {"arg1": `, + nameField: "name", + argsField: "arguments", + wantToolCalls: nil, + wantErr: errAccumulateMore, + prefix: "", + }, + { + name: "invalid JSON", + input: `not json at all`, + nameField: "name", + argsField: "arguments", + wantToolCalls: nil, + wantErr: errInvalidToolCall, + prefix: "", + }, + { + name: "missing required fields", + input: `{"other": "field"}`, + nameField: "name", + argsField: "arguments", + wantToolCalls: nil, + wantErr: errInvalidToolCall, + prefix: "", + }, + { + name: "multiple tool calls in array", + input: `[ + {"name": "tool1", "arguments": {"arg1": 1}}, + {"name": "tool2", "arguments": {"arg2": "value"}} + ]`, + nameField: "name", + argsField: "arguments", + wantToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "tool1", + Arguments: map[string]any{ + "arg1": float64(1), + }, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "tool2", + Arguments: map[string]any{ + "arg2": "value", + }, + }, + }, + }, + wantErr: nil, + prefix: "", + }, + { + name: "multiple tool calls without array", + input: ` + {"name": "tool1", "arguments": {"arg1": 1}}, + {"name": "tool2", "arguments": {"arg2": "value"}} + `, + nameField: "name", + argsField: "arguments", + wantToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "tool1", + Arguments: map[string]any{ + "arg1": float64(1), + }, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "tool2", + Arguments: map[string]any{ + "arg2": "value", + }, + }, + }, + }, + wantErr: nil, + prefix: "", + }, + { + name: "multiple tool calls with text after", + input: ` + {"name": "tool1", "arguments": {"arg1": 1}} text + {"name": "tool2", "arguments": {"arg2": "value"}} text + `, + nameField: "name", + argsField: "arguments", + wantToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "tool1", + Arguments: map[string]any{ + "arg1": float64(1), + }, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "tool2", + Arguments: map[string]any{ + "arg2": "value", + }, + }, + }, + }, + wantErr: nil, + prefix: "", + }, + { + name: "second tool call in array", + input: ` + , {"name": "tool2", "arguments": {"arg2": "value"}} + `, + nameField: "name", + argsField: "arguments", + wantToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "tool2", + Arguments: map[string]any{ + "arg2": "value", + }, + }, + }, + }, + wantErr: nil, + prefix: "", + }, + // a bad JSON would not return any tool calls or content as it would always accumulate more + { + name: "unbalanced square brackets", + input: `[{"name": "tool1", "arguments": {"arg1": [1, 2}]`, + nameField: "name", + argsField: "arguments", + wantToolCalls: nil, + wantErr: errAccumulateMore, + prefix: "", + }, + { + name: "incomplete square brackets", + input: `[{"name": "tool1", "arguments": {"arg1": [1, 2, 3`, + nameField: "name", + argsField: "arguments", + wantToolCalls: nil, + wantErr: errAccumulateMore, + prefix: "", + }, + { + name: "nested arrays in arguments", + input: `{"name": "tool1", "arguments": {"arg1": [1, 2, ["nested", "array"]]}}`, + nameField: "name", + argsField: "arguments", + wantToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "tool1", + Arguments: map[string]any{ + "arg1": []any{float64(1), float64(2), []any{"nested", "array"}}, + }, + }, + }, + }, + wantErr: nil, + prefix: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotCalls, err := parseJSONToolCalls(tt.input, tt.nameField, tt.argsField, tt.prefix) + + if err != tt.wantErr { + t.Errorf("parseJSONToolCalls() error = %v, want %v", err, tt.wantErr) + } + + if len(gotCalls) != 0 && tt.wantErr != nil { + t.Errorf("parseJSONToolCalls() valid = %v, want %v", len(gotCalls) == 0, tt.wantErr == nil) + } + + if diff := cmp.Diff(gotCalls, tt.wantToolCalls); diff != "" { + t.Errorf("parseJSONToolCalls() tool calls mismatch (-got +want):\n%s", diff) + } + }) + } +} + +func TestParseToolCalls(t *testing.T) { + p := filepath.Join("testdata") + t1 := api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: api.ToolCallFunctionArguments{ + "format": "fahrenheit", + "location": "San Francisco, CA", + }, + }, + } + t2 := api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: api.ToolCallFunctionArguments{ + "format": "celsius", + "location": "Toronto, Canada", + }, + }, + } + + cases := []struct { + name string + model string + output string + expectedToolCall []api.ToolCall + expectedTokens string + }{ + { + name: "mistral malformed json with tool calls prefix", + model: "mistral", + output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: "", + }, + { + name: "mistral multiple tool calls without prefix", + model: "mistral", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}} ]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "mistral tool calls with text between no prefix", + model: "mistral", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] + model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: `model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + }, + { + name: "mistral valid json with tool calls prefix", + model: "mistral", + output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "mistral multiple tool calls with text between and prefix", + model: "mistral", + output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] + model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2, t1, t2}, + expectedTokens: "", + }, + { + name: "mistral incomplete json with tool calls prefix", + model: "mistral", + output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `, + expectedToolCall: []api.ToolCall{}, + expectedTokens: "", + }, + { + name: "mistral invalid tool call with explanatory text no prefix", + model: "mistral", + output: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: + + [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + }, + { + name: "mistral tool calls without prefix", + model: "mistral", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "command r plus tool calls with json block format", + model: "command-r-plus", + output: "Action: ```json" + ` + [ + { + "tool_name": "get_current_weather", + "parameters": { + "format": "fahrenheit", + "location": "San Francisco, CA" + } + }, + { + "tool_name": "get_current_weather", + "parameters": { + "format": "celsius", + "location": "Toronto, Canada" + } + } + ] + ` + "```", + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "firefunction tool calls with functools prefix", + model: "firefunction", + output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "llama3 groq single tool call with xml tags", + model: "llama3-groq-tool-use", + output: ` + {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} + `, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: "", + }, + { + name: "xlam tool calls with wrapper object", + model: "xlam", + output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "qwen2.5 single tool call with prefix", + model: "qwen2.5", + output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}`, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: "", + }, + { + name: "qwen2.5 multiple tool calls with and without prefix", + model: "qwen2.5", + output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}`, + expectedToolCall: []api.ToolCall{t1, t1, t2}, + expectedTokens: "", + }, + { + name: "qwen2.5 plain text response no tool calls", + model: "qwen2.5", + output: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", + expectedToolCall: []api.ToolCall{}, + expectedTokens: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", + }, + { + name: "qwen2.5 tool calls with trailing text", + model: "qwen2.5", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after call`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "some tokens after call", + }, + { + name: "qwen2.5 tool calls with initial text", + model: "qwen2.5", + output: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + }, + { + name: "qwen2.5 tool calls with prefix and trailing text", + model: "qwen2.5", + output: ` [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after call`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "qwen2.5 tool calls with prefix and initial text", + model: "qwen2.5", + output: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] `, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "some tokens before call", + }, + { + name: "qwen2.5 tool calls without and with prefix", + model: "qwen2.5", + output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "qwen2.5 tool calls without and with prefix and text between", + model: "qwen2.5", + output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} some tokens between {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}} some tokens after call`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "some tokens between", + }, + { + name: "qwen2.5 tool calls without prefix and invalid tool call with other tokens", + model: "qwen2.5", + output: `hi [{"options": "foo"}]`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: `hi [{"options": "foo"}]`, + }, + { + name: "qwen2.5 tool calls with prefix and invalid tool call", + model: "qwen2.5", + output: ` [{"options": "foo"}] `, + expectedToolCall: []api.ToolCall{}, + expectedTokens: ``, + }, + { + name: "qwen3 tool call with think prefix and tool prefix (sent as a single token)", + model: "qwen3", + output: `Okay, let me think what tool we should use...{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}`, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: "Okay, let me think what tool we should use...", + }, + { + name: "qwen3 tool call with think prefix, tool prefix, and whitespace (sent as separate tokens)", + model: "qwen3", + output: `Okay, let me think what tool we should use... { "name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: "Okay, let me think what tool we should use...", + }, + { + name: "qwen3 empty think prefix without tool prefix and invalid tool call", + model: "qwen3", + output: ` {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + expectedToolCall: []api.ToolCall{}, + expectedTokens: ` {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + }, + { + name: "qwen3 empty think prefix with tool prefix and valid tool call", + model: "qwen3", + output: `{ "name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: ``, + }, + { + name: "qwen3 invalid tool call with fake tool prefix (single rune suffix match)", + model: "qwen3", + output: `< fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + expectedToolCall: []api.ToolCall{}, + expectedTokens: `< fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + }, + { + name: "qwen3 invalid tool call with partial tool prefix (multiple rune suffix match)", + model: "qwen3", + output: ``, + expectedToolCall: []api.ToolCall{}, + expectedTokens: ``, + }, + { + name: "qwen3 invalid tool call with malformed tool prefix", + model: "qwen3", + output: ``, + expectedToolCall: []api.ToolCall{}, + expectedTokens: ``, + }, + { + name: "model with prefix in template, no prefix in output", + model: "qwen2.5", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "model with prefix in template, prefix in output", + model: "qwen2.5", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "model without prefix in template, no prefix in output", + model: "llama3.2", + output: `[{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "model without prefix in template, no prefix in output, single tool call", + model: "llama3.2", + output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: "", + }, + { + name: "model without prefix in template, prefix in output", + model: "llama3.2", + output: ` [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: ` [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, + }, + { + name: "model with prefix in template, no prefix in output, tokens before", + model: "qwen2.5", + output: `some tokens before [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: `some tokens before [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + }, + { + name: "model with prefix in template, prefix in output, tokens after", + model: "qwen2.5", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "model without prefix in template, no prefix in output, tokens after", + model: "llama3.2", + output: `[{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "model without prefix in template, no prefix in output, tokens before", + model: "llama3.2", + output: `some tokens before [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: `some tokens before [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, + }, + { + name: "model without prefix in template, prefix in output, tokens after", + model: "llama3.2", + output: ` [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: ` [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after`, + }, + } + + var tools []api.Tool + if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil { + t.Fatal(err) + } + + var messages []api.Message + if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil { + t.Fatal(err) + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String()) + if err != nil { + t.Fatal(err) + } + + t.Run("template", func(t *testing.T) { + actual := &bytes.Buffer{} // Create new buffer for each test + if err := tmpl.Execute(actual, template.Values{Tools: tools, Messages: messages}); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + + t.Run("parse", func(t *testing.T) { + tp, err := NewParser(tmpl.Template) + if err != nil { + t.Fatal(err) + } + got := []api.ToolCall{} + var gotTokens strings.Builder + + tokens := strings.Fields(tt.output) + for _, tok := range tokens { + s := " " + tok + + toolCalls, content := tp.Add(s) + if len(content) > 0 { + gotTokens.WriteString(content) + } else if len(toolCalls) > 0 { + got = append(got, toolCalls...) + } + } + + // Compare tool calls if we expect any + if diff := cmp.Diff(got, tt.expectedToolCall); diff != "" { + t.Errorf("tool calls mismatch (-got +want):\n%s", diff) + } + + // Compare tokens if we expect any + stripped := strings.TrimSpace(gotTokens.String()) + if diff := cmp.Diff(stripped, tt.expectedTokens); diff != "" { + t.Log("actualTokens", stripped, "expectedTokens", tt.expectedTokens) + t.Errorf("tokens mismatch (-got +want):\n%s", diff) + } + }) + }) + } +} diff --git a/tools/tools_utils.go b/tools/tools_utils.go new file mode 100644 index 00000000..48531b78 --- /dev/null +++ b/tools/tools_utils.go @@ -0,0 +1,227 @@ +package tools + +import ( + "bytes" + "encoding/json" + "errors" + "log/slog" + "slices" + "strings" + gotmpl "text/template" + "text/template/parse" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/template" +) + +// extractToolCallsFormat traverses a template AST to find text that follows a ".ToolCalls" condition. +// It walks the template nodes looking for if-statements containing ".ToolCalls" and extracts any +// immediate text nodes that follow. This is used to identify tool call prefixes and formatting. +// +// Returns: +// - string: The extracted text following the first ".ToolCalls" condition found +// - bool: Whether a ".ToolCalls" condition was found in the template +func extractToolCallsFormat(tmpl *gotmpl.Template) (string, bool) { + if tmpl == nil || tmpl.Tree == nil { + slog.Debug("template or tree is nil") + return "", false + } + + var result string + var found bool + + var walk func(nodes []parse.Node) + walk = func(nodes []parse.Node) { + for _, node := range nodes { + if found { + return + } + + switch n := node.(type) { + case *parse.IfNode: + if isToolCallsNode(n) { + // Collect immediate TextNode(s) at start of IfNode's list + var sb strings.Builder + for _, innerNode := range n.List.Nodes { + if tn, ok := innerNode.(*parse.TextNode); ok { + sb.Write(tn.Text) + } else { + // Stop at first non-text node + break + } + } + result = sb.String() + found = true + return + } + // Recurse into child nodes + walk(n.List.Nodes) + if n.ElseList != nil { + walk(n.ElseList.Nodes) + } + case *parse.ListNode: + walk(n.Nodes) + case *parse.RangeNode: + walk(n.List.Nodes) + if n.ElseList != nil { + walk(n.ElseList.Nodes) + } + case *parse.WithNode: + walk(n.List.Nodes) + if n.ElseList != nil { + walk(n.ElseList.Nodes) + } + default: + // Continue to next node + continue + } + } + } + + walk(tmpl.Tree.Root.Nodes) + return result, found +} + +// isToolCallsNode detects if a node's condition includes ".ToolCalls" +func isToolCallsNode(n *parse.IfNode) bool { + for _, cmd := range n.Pipe.Cmds { + for _, arg := range cmd.Args { + if field, ok := arg.(*parse.FieldNode); ok { + if slices.Contains(field.Ident, "ToolCalls") { + return true + } + } + } + } + return false +} + +func toolPrefix(tmpl *gotmpl.Template) string { + tokenText, ok := extractToolCallsFormat(tmpl) + if !ok { + return "" + } + tokenText = strings.TrimSpace(tokenText) + tokenText = strings.ReplaceAll(tokenText, "\r", "") + tokenText = strings.ReplaceAll(tokenText, "\n", " ") + + return tokenText +} + +// toolTemplate creates a subtree from the node that ranges over .ToolCalls +// +// Returns: +// - *gotmpl.Template: The subtree containing the .ToolCalls range +// - error: Error if parsing failed +func toolTemplate(t *template.Template) (*gotmpl.Template, error) { + tmpl := t.Subtree(func(n parse.Node) bool { + if t, ok := n.(*parse.RangeNode); ok { + return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls") + } + + return false + }) + + if tmpl == nil { + return nil, errors.New("failed to find tool template") + } + + return tmpl, nil +} + +// suffixOverlap returns the index in s where the longest suffix overlap with prefix begins +// +// Returns: +// - int: The starting index in s where the suffix overlap begins +func suffixOverlap(s, prefix string) int { + max := min(len(prefix), len(s)) + for i := max; i > 0; i-- { + if strings.HasSuffix(s, prefix[:i]) { + return len(s) - i + } + } + return -1 +} + +// extractToolArgs executes a template with a known tool call format to extract the name and arguments +// +// Returns: +// - string: The name of the tool call +// - string: The arguments of the tool call +// - error: Error if parsing failed +func extractToolArgs(tmpl *gotmpl.Template) (name, arguments string, err error) { + var b bytes.Buffer + if err := tmpl.Execute(&b, map[string][]api.ToolCall{ + "ToolCalls": { + { + Function: api.ToolCallFunction{ + Name: "@@name@@", + Arguments: api.ToolCallFunctionArguments{ + "@@argument@@": 1, + }, + }, + }, + }, + }); err != nil { + return "", "", err + } + + var obj any + err = json.Unmarshal(b.Bytes(), &obj) + if err != nil { + return "", "", err + } + + var objs []map[string]any + switch v := obj.(type) { + case map[string]any: + objs = []map[string]any{v} + case []map[string]any: + objs = v + case []any: + objs = collect(v) + } + if len(objs) == 0 { + return "", "", errors.New("no template objects found") + } + + // find the keys that correspond to the name and arguments fields + for k, v := range objs[0] { + switch v.(type) { + case string: + name = k + case map[string]any: + arguments = k + } + } + + if name == "" || arguments == "" { + slog.Debug("missing required fields in tool call template", "name", name, "arguments", arguments) + return "", "", errors.New("missing required fields in tool call template") + } + + return name, arguments, nil +} + +// collect recursively traverses an object to collect all nested maps +// +// Returns: +// - []map[string]any: A slice of all nested maps found in the object +func collect(obj any) []map[string]any { + var all []map[string]any + switch o := obj.(type) { + case map[string]any: + all = append(all, o) + for _, v := range o { + all = append(all, collect(v)...) + } + case []any: + for _, v := range o { + all = append(all, collect(v)...) + } + default: + return nil + } + + return all +} diff --git a/tools/tools_utils_test.go b/tools/tools_utils_test.go new file mode 100644 index 00000000..769183b7 --- /dev/null +++ b/tools/tools_utils_test.go @@ -0,0 +1,464 @@ +package tools + +import ( + "testing" + gotmpl "text/template" + + "github.com/ollama/ollama/template" +) + +func TestExtractToolCallsFormat(t *testing.T) { + cases := []struct { + name string + template string + want string + found bool + }{ + { + name: "nil template", + template: "", + want: "", + found: false, + }, + { + name: "basic tool call with text", + template: "{{if .ToolCalls}}Hello world{{end}}", + want: "Hello world", + found: true, + }, + { + name: "tool call with json format", + template: "{{if .ToolCalls}}```json\n{{end}}", + want: "```json\n", + found: true, + }, + { + name: "tool call in range", + template: "{{range .ToolCalls}}tool: {{.}}{{end}}", + want: "", + found: false, + }, + { + name: "tool call with multiple text nodes", + template: "{{if .ToolCalls}}First text{{if .Something}}inner{{end}}Second text{{end}}", + want: "First text", + found: true, + }, + { + name: "nested if without tool calls", + template: "{{if .Something}}{{if .OtherThing}}text{{end}}{{end}}", + want: "", + found: false, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + tmpl, err := gotmpl.New("test").Parse(tc.template) + if err != nil && tc.template != "" { + t.Fatalf("failed to parse template: %v", err) + } + + got, found := extractToolCallsFormat(tmpl) + if got != tc.want { + t.Errorf("got text %q, want %q", got, tc.want) + } + if found != tc.found { + t.Errorf("got found %v, want %v", found, tc.found) + } + }) + } +} + +func TestToolPrefix(t *testing.T) { + cases := []struct { + name string + template string + want string + }{ + { + name: "basic tool call with action prefix", + template: "{{if .ToolCalls}}Action: ```json{{end}}", + want: "Action: ```json", + }, + { + name: "incomplete functools bracket", + template: "{{if .ToolCalls}}functools[{{end}}", + want: "functools[", + }, + { + name: "tool call with angle brackets", + template: "{{if .ToolCalls}}Hello, world! {{end}}", + want: "Hello, world! ", + }, + { + name: "multiple tool call formats", + template: "{{if .ToolCalls}}[tool_call] {{end}}", + want: "[tool_call] ", + }, + { + name: "single angle bracket tool call", + template: "{{if .ToolCalls}}{{end}}", + want: "", + }, + { + name: "incomplete angle bracket after tool call", + template: "{{if .ToolCalls}}[tool_call] <{{end}}", + want: "[tool_call] <", + }, + { + name: "angle bracket prefix with tool call", + template: "{{if .ToolCalls}}> {{end}}", + want: "> ", + }, + { + name: "uppercase tool call with incomplete bracket", + template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}", + want: "[TOOL_CALL] [", + }, + { + name: "uppercase tool call with adjacent bracket", + template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}", + want: "[TOOL_CALL][", + }, + { + name: "tool call with pipe delimiters", + template: "{{if .ToolCalls}}<|tool_call|>{{end}}", + want: "<|tool_call|>", + }, + { + name: "tool with no prefix", + template: "{{if .ToolCalls}}{{end}}", + want: "", + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + tmpl, err := gotmpl.New("test").Parse(tt.template) + if err != nil { + t.Fatalf("failed to parse template: %v", err) + } + got := toolPrefix(tmpl) + if got != tt.want { + t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want) + } + }) + } +} + +func TestToolTemplate(t *testing.T) { + cases := []struct { + name string + template string + want bool + }{ + { + name: "basic tool call range", + template: "{{range .ToolCalls}}test{{end}}", + want: true, + }, + { + name: "no tool calls", + template: "{{range .Other}}test{{end}}", + want: false, + }, + { + name: "nested tool calls", + template: "{{range .Outer}}{{range .ToolCalls}}test{{end}}{{end}}", + want: true, + }, + { + name: "empty template", + template: "", + want: false, + }, + { + name: "tool calls in if statement", + template: "{{if .ToolCalls}}test{{end}}", + want: false, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + tmpl, err := gotmpl.New("test").Parse(tt.template) + if err != nil { + t.Fatalf("failed to parse template: %v", err) + } + + parsed, err := template.Parse(tmpl.Root.String()) + if err != nil { + t.Fatalf("failed to parse template: %v", err) + } + + _, err = toolTemplate(parsed) + if err != nil && tt.want { + t.Errorf("toolTemplate() = %v; want %v", err, tt.want) + } + }) + } +} + +func TestSuffixOverlap(t *testing.T) { + cases := []struct { + name string + s string + d string + want int + }{ + { + name: "no overlap", + s: "hello world", + d: "", + want: -1, + }, + { + name: "full overlap", + s: "", + d: "", + want: 0, + }, + { + name: "partial overlap", + s: "text ", + d: "", + want: 5, + }, + { + name: "delimiter longer than string", + s: "", + d: "", + want: -1, + }, + { + name: "empty string", + s: "", + d: "", + want: -1, + }, + { + name: "empty delimiter", + s: "", + d: "", + want: -1, + }, + { + name: "single char overlap", + s: "test<", + d: "", + want: 4, + }, + { + name: "partial tool call", + s: "hello ", + want: 6, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got := suffixOverlap(tt.s, tt.d) + if got != tt.want { + t.Errorf("suffixOverlap(%q, %q) = %d; want %d", tt.s, tt.d, got, tt.want) + } + }) + } +} + +func TestExtractToolArgs(t *testing.T) { + cases := []struct { + name string + template string + want string + ok bool + }{ + { + name: "basic tool call with text after", + template: `{{if .ToolCalls}}tool response{{end}}`, + want: "tool response", + ok: true, + }, + { + name: "tool call with mixed content after", + template: `{{if .ToolCalls}}{{.Something}}{{end}}`, + want: "", + ok: true, + }, + { + name: "tool call with no text after", + template: `{{if .ToolCalls}}{{.Something}}{{end}}`, + want: "", + ok: true, + }, + { + name: "nested tool call", + template: `{{if .Something}}{{if .ToolCalls}}[TOOL_CALL]{{end}}{{end}}`, + want: "[TOOL_CALL]", + ok: true, + }, + { + name: "no tool calls", + template: `{{if .Something}}no tools here{{end}}`, + want: "", + ok: false, + }, + { + name: "empty template", + template: ``, + want: "", + ok: false, + }, + { + name: "multiple tool calls sections", + template: `{{if .ToolCalls}}first{{end}}{{if .ToolCalls}}second{{end}}`, + want: "first", + ok: true, + }, + { + name: "range over tool calls", + template: `{{if .ToolCalls}}{{range .ToolCalls}}tool{{end}}{{end}}`, + want: "", + ok: true, + }, + { + name: "tool calls with pipe delimiters", + template: `{{if .ToolCalls}}<|tool|>{{end}}`, + want: "<|tool|>", + ok: true, + }, + { + name: "tool calls with nested template", + template: `{{if .ToolCalls}}{{template "tool" .}}{{end}}`, + want: "", + ok: true, + }, + { + name: "tool calls with whitespace variations", + template: `{{if .ToolCalls}} tool {{end}}`, + want: " tool ", + ok: true, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + tmpl, err := gotmpl.New("test").Parse(tt.template) + if err != nil { + t.Fatalf("failed to parse template: %v", err) + } + + got, ok := extractToolCallsFormat(tmpl) + if got != tt.want { + t.Errorf("TextAfterToolCalls() got = %q, want %q", got, tt.want) + } + if ok != tt.ok { + t.Errorf("TextAfterToolCalls() ok = %v, want %v", ok, tt.ok) + } + }) + } +} + +func TestCollect(t *testing.T) { + cases := []struct { + name string + obj any + want []map[string]any + }{ + { + name: "simple map", + obj: map[string]any{ + "key": "value", + }, + want: []map[string]any{ + {"key": "value"}, + }, + }, + { + name: "nested map", + obj: map[string]any{ + "outer": map[string]any{ + "inner": "value", + }, + }, + want: []map[string]any{ + {"outer": map[string]any{"inner": "value"}}, + {"inner": "value"}, + }, + }, + { + name: "array of maps", + obj: []any{ + map[string]any{"key1": "val1"}, + map[string]any{"key2": "val2"}, + }, + want: []map[string]any{ + {"key1": "val1"}, + {"key2": "val2"}, + }, + }, + { + name: "deeply nested", + obj: map[string]any{ + "l1": map[string]any{ + "l2": map[string]any{ + "l3": "value", + }, + }, + }, + want: []map[string]any{ + {"l1": map[string]any{"l2": map[string]any{"l3": "value"}}}, + {"l2": map[string]any{"l3": "value"}}, + {"l3": "value"}, + }, + }, + { + name: "non-map value", + obj: "string", + want: nil, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got := collect(tt.obj) + if len(got) != len(tt.want) { + t.Errorf("collect() got %d maps, want %d", len(got), len(tt.want)) + return + } + + // Compare each map in the result + for i := range tt.want { + if !mapsEqual(got[i], tt.want[i]) { + t.Errorf("collect() map[%d] = %v, want %v", i, got[i], tt.want[i]) + } + } + }) + } +} + +// mapsEqual compares two maps for deep equality +func mapsEqual(m1, m2 map[string]any) bool { + if len(m1) != len(m2) { + return false + } + for k, v1 := range m1 { + v2, ok := m2[k] + if !ok { + return false + } + switch val1 := v1.(type) { + case map[string]any: + val2, ok := v2.(map[string]any) + if !ok || !mapsEqual(val1, val2) { + return false + } + default: + if v1 != v2 { + return false + } + } + } + return true +}