mirror of
https://github.com/zebrajr/ollama.git
synced 2025-12-06 00:19:51 +01:00
tools: resiliency upgrade to name and arg extraction from template (#10917)
This commit is contained in:
parent
aaa7818000
commit
65f10c2823
|
|
@ -166,31 +166,26 @@ func extractToolArgs(tmpl *gotmpl.Template) (name, arguments string, err error)
|
|||
return "", "", err
|
||||
}
|
||||
|
||||
var obj any
|
||||
err = json.Unmarshal(b.Bytes(), &obj)
|
||||
if err != nil {
|
||||
// Extract JSON object between curly braces
|
||||
// JSON arrays are also valid as they will not be repeated in the template
|
||||
output := b.String()
|
||||
start := strings.Index(output, "{")
|
||||
end := strings.LastIndex(output, "}")
|
||||
if start == -1 || end == -1 || start > end {
|
||||
return "", "", errors.New("no valid JSON object found in template output")
|
||||
}
|
||||
jsonStr := output[start : end+1]
|
||||
|
||||
var obj map[string]any
|
||||
if err := json.Unmarshal([]byte(jsonStr), &obj); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
var objs []map[string]any
|
||||
switch v := obj.(type) {
|
||||
case map[string]any:
|
||||
objs = []map[string]any{v}
|
||||
case []map[string]any:
|
||||
objs = v
|
||||
case []any:
|
||||
objs = collect(v)
|
||||
}
|
||||
if len(objs) == 0 {
|
||||
return "", "", errors.New("no template objects found")
|
||||
}
|
||||
|
||||
// find the keys that correspond to the name and arguments fields
|
||||
for k, v := range objs[0] {
|
||||
switch v.(type) {
|
||||
case string:
|
||||
// Find name and arguments fields
|
||||
for k, v := range obj {
|
||||
if str, ok := v.(string); ok && str == "@@name@@" {
|
||||
name = k
|
||||
case map[string]any:
|
||||
} else if _, ok := v.(map[string]any); ok {
|
||||
arguments = k
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -271,74 +271,99 @@ func TestExtractToolArgs(t *testing.T) {
|
|||
cases := []struct {
|
||||
name string
|
||||
template string
|
||||
want string
|
||||
ok bool
|
||||
wantName string
|
||||
wantArgs string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "basic tool call with text after",
|
||||
template: `{{if .ToolCalls}}tool response{{end}}`,
|
||||
want: "tool response",
|
||||
ok: true,
|
||||
name: "basic tool call",
|
||||
template: `{{ range .ToolCalls }}
|
||||
{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }}`,
|
||||
wantName: "name",
|
||||
wantArgs: "parameters",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "tool call with mixed content after",
|
||||
template: `{{if .ToolCalls}}<tool_call>{{.Something}}{{end}}`,
|
||||
want: "<tool_call>",
|
||||
ok: true,
|
||||
name: "tool call with whitespace",
|
||||
template: `{{range .ToolCalls}}
|
||||
{"name": "{{.Function.Name}}", "parameters": {{.Function.Arguments}}}
|
||||
{{end}}`,
|
||||
wantName: "name",
|
||||
wantArgs: "parameters",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "tool call with no text after",
|
||||
template: `{{if .ToolCalls}}{{.Something}}{{end}}`,
|
||||
want: "",
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "nested tool call",
|
||||
template: `{{if .Something}}{{if .ToolCalls}}[TOOL_CALL]{{end}}{{end}}`,
|
||||
want: "[TOOL_CALL]",
|
||||
ok: true,
|
||||
name: "tool call with extra content",
|
||||
template: `Before {{range .ToolCalls}}
|
||||
{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}} After`,
|
||||
wantName: "name",
|
||||
wantArgs: "arguments",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "no tool calls",
|
||||
template: `{{if .Something}}no tools here{{end}}`,
|
||||
want: "",
|
||||
ok: false,
|
||||
wantName: "",
|
||||
wantArgs: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty template",
|
||||
template: ``,
|
||||
want: "",
|
||||
ok: false,
|
||||
wantName: "",
|
||||
wantArgs: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "multiple tool calls sections",
|
||||
template: `{{if .ToolCalls}}first{{end}}{{if .ToolCalls}}second{{end}}`,
|
||||
want: "first",
|
||||
ok: true,
|
||||
name: "prefix within tool call",
|
||||
template: `{{- if .ToolCalls }}
|
||||
{{ range .ToolCalls }}
|
||||
<tool_call>
|
||||
{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||
</tool_call>{{ end }}{{- end }}`,
|
||||
wantName: "name",
|
||||
wantArgs: "arguments",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "range over tool calls",
|
||||
template: `{{if .ToolCalls}}{{range .ToolCalls}}tool{{end}}{{end}}`,
|
||||
want: "",
|
||||
ok: true,
|
||||
name: "JSON array",
|
||||
template: `{{ range .ToolCalls }}
|
||||
[{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}]{{ end }}`,
|
||||
wantName: "name",
|
||||
wantArgs: "arguments",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "tool calls with pipe delimiters",
|
||||
template: `{{if .ToolCalls}}<|tool|>{{end}}`,
|
||||
want: "<|tool|>",
|
||||
ok: true,
|
||||
name: "invalid JSON",
|
||||
template: `{{ range .ToolCalls }}
|
||||
{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}, invalid}{{ end }}`,
|
||||
wantName: "",
|
||||
wantArgs: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "tool calls with nested template",
|
||||
template: `{{if .ToolCalls}}{{template "tool" .}}{{end}}`,
|
||||
want: "",
|
||||
ok: true,
|
||||
name: "missing name field",
|
||||
template: `{{ range .ToolCalls }}
|
||||
{"parameters": {{ .Function.Arguments }}}{{ end }}`,
|
||||
wantName: "",
|
||||
wantArgs: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "tool calls with whitespace variations",
|
||||
template: `{{if .ToolCalls}} tool {{end}}`,
|
||||
want: " tool ",
|
||||
ok: true,
|
||||
name: "missing arguments field",
|
||||
template: `{{ range .ToolCalls }}
|
||||
{"name": "{{ .Function.Name }}"}{{ end }}`,
|
||||
wantName: "",
|
||||
wantArgs: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "malformed JSON",
|
||||
template: `{{ range .ToolCalls }}
|
||||
{"name": {{ .Function.Name }}, "arguments": {{ .Function.Arguments }}{{ end }}`,
|
||||
wantName: "",
|
||||
wantArgs: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -349,12 +374,20 @@ func TestExtractToolArgs(t *testing.T) {
|
|||
t.Fatalf("failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
got, ok := extractToolCallsFormat(tmpl)
|
||||
if got != tt.want {
|
||||
t.Errorf("TextAfterToolCalls() got = %q, want %q", got, tt.want)
|
||||
gotName, gotArgs, err := extractToolArgs(tmpl)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("extractToolArgs() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if ok != tt.ok {
|
||||
t.Errorf("TextAfterToolCalls() ok = %v, want %v", ok, tt.ok)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if gotName != tt.wantName {
|
||||
t.Errorf("extractToolArgs() gotName = %q, want %q", gotName, tt.wantName)
|
||||
}
|
||||
if gotArgs != tt.wantArgs {
|
||||
t.Errorf("extractToolArgs() gotArgs = %q, want %q", gotArgs, tt.wantArgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user