tool_test.go

  1package ai
  2
  3import (
  4	"context"
  5	"fmt"
  6	"reflect"
  7	"strings"
  8	"testing"
  9)
 10
 11// Example of a simple typed tool using the function approach
 12type CalculatorInput struct {
 13	Expression string `json:"expression" description:"Mathematical expression to evaluate"`
 14}
 15
 16func TestTypedToolFuncExample(t *testing.T) {
 17	// Create a typed tool using the function API
 18	tool := NewTypedToolFunc(
 19		"calculator",
 20		"Evaluates simple mathematical expressions",
 21		func(ctx context.Context, input CalculatorInput, _ ToolCall) (ToolResponse, error) {
 22			if input.Expression == "2+2" {
 23				return NewTextResponse("4"), nil
 24			}
 25			return NewTextErrorResponse("unsupported expression"), nil
 26		},
 27	)
 28
 29	// Check the tool info
 30	info := tool.Info()
 31	if info.Name != "calculator" {
 32		t.Errorf("Expected tool name 'calculator', got %s", info.Name)
 33	}
 34	if len(info.Required) != 1 || info.Required[0] != "expression" {
 35		t.Errorf("Expected required field 'expression', got %v", info.Required)
 36	}
 37
 38	// Test execution
 39	call := ToolCall{
 40		ID:    "test-1",
 41		Name:  "calculator",
 42		Input: `{"expression": "2+2"}`,
 43	}
 44
 45	result, err := tool.Run(context.Background(), call)
 46	if err != nil {
 47		t.Errorf("Unexpected error: %v", err)
 48	}
 49	if result.Content != "4" {
 50		t.Errorf("Expected result '4', got %s", result.Content)
 51	}
 52	if result.IsError {
 53		t.Errorf("Expected successful result, got error")
 54	}
 55}
 56
 57func TestEnumToolExample(t *testing.T) {
 58	type WeatherInput struct {
 59		Location string `json:"location" description:"City name"`
 60		Units    string `json:"units" enum:"celsius,fahrenheit" description:"Temperature units"`
 61	}
 62
 63	// Create a weather tool with enum support
 64	tool := NewTypedToolFunc(
 65		"weather",
 66		"Gets current weather for a location",
 67		func(ctx context.Context, input WeatherInput, _ ToolCall) (ToolResponse, error) {
 68			temp := "22°C"
 69			if input.Units == "fahrenheit" {
 70				temp = "72°F"
 71			}
 72			return NewTextResponse(fmt.Sprintf("Weather in %s: %s, sunny", input.Location, temp)), nil
 73		},
 74	)
 75
 76	// Check that the schema includes enum values
 77	info := tool.Info()
 78	unitsParam, ok := info.Parameters["units"].(map[string]any)
 79	if !ok {
 80		t.Fatal("Expected units parameter to exist")
 81	}
 82	enumValues, ok := unitsParam["enum"].([]any)
 83	if !ok || len(enumValues) != 2 {
 84		t.Errorf("Expected 2 enum values, got %v", enumValues)
 85	}
 86
 87	// Test execution with enum value
 88	call := ToolCall{
 89		ID:    "test-2",
 90		Name:  "weather",
 91		Input: `{"location": "San Francisco", "units": "fahrenheit"}`,
 92	}
 93
 94	result, err := tool.Run(context.Background(), call)
 95	if err != nil {
 96		t.Errorf("Unexpected error: %v", err)
 97	}
 98	if !strings.Contains(result.Content, "San Francisco") {
 99		t.Errorf("Expected result to contain 'San Francisco', got %s", result.Content)
100	}
101	if !strings.Contains(result.Content, "72°F") {
102		t.Errorf("Expected result to contain '72°F', got %s", result.Content)
103	}
104}
105
106func TestEnumSupport(t *testing.T) {
107	// Test enum via struct tags
108	type WeatherInput struct {
109		Location string `json:"location" description:"City name"`
110		Units    string `json:"units" enum:"celsius,fahrenheit,kelvin" description:"Temperature units"`
111		Format   string `json:"format,omitempty" enum:"json,xml,text"`
112	}
113
114	schema := generateSchema(reflect.TypeOf(WeatherInput{}))
115
116	if schema.Type != "object" {
117		t.Errorf("Expected schema type 'object', got %s", schema.Type)
118	}
119
120	// Check units field has enum values
121	unitsSchema := schema.Properties["units"]
122	if unitsSchema == nil {
123		t.Fatal("Expected units property to exist")
124	}
125	if len(unitsSchema.Enum) != 3 {
126		t.Errorf("Expected 3 enum values for units, got %d", len(unitsSchema.Enum))
127	}
128	expectedUnits := []string{"celsius", "fahrenheit", "kelvin"}
129	for i, expected := range expectedUnits {
130		if unitsSchema.Enum[i] != expected {
131			t.Errorf("Expected enum value %s, got %v", expected, unitsSchema.Enum[i])
132		}
133	}
134
135	// Check required fields (format should not be required due to omitempty)
136	expectedRequired := []string{"location", "units"}
137	if len(schema.Required) != len(expectedRequired) {
138		t.Errorf("Expected %d required fields, got %d", len(expectedRequired), len(schema.Required))
139	}
140}
141
142func TestSchemaToParameters(t *testing.T) {
143	schema := Schema{
144		Type: "object",
145		Properties: map[string]*Schema{
146			"name": {
147				Type:        "string",
148				Description: "The name field",
149			},
150			"age": {
151				Type:    "integer",
152				Minimum: func() *float64 { v := 0.0; return &v }(),
153				Maximum: func() *float64 { v := 120.0; return &v }(),
154			},
155			"tags": {
156				Type: "array",
157				Items: &Schema{
158					Type: "string",
159				},
160			},
161			"priority": {
162				Type: "string",
163				Enum: []any{"low", "medium", "high"},
164			},
165		},
166		Required: []string{"name"},
167	}
168
169	params := schemaToParameters(schema)
170
171	// Check name parameter
172	nameParam, ok := params["name"].(map[string]any)
173	if !ok {
174		t.Fatal("Expected name parameter to exist")
175	}
176	if nameParam["type"] != "string" {
177		t.Errorf("Expected name type 'string', got %v", nameParam["type"])
178	}
179	if nameParam["description"] != "The name field" {
180		t.Errorf("Expected name description 'The name field', got %v", nameParam["description"])
181	}
182
183	// Check age parameter with min/max
184	ageParam, ok := params["age"].(map[string]any)
185	if !ok {
186		t.Fatal("Expected age parameter to exist")
187	}
188	if ageParam["type"] != "integer" {
189		t.Errorf("Expected age type 'integer', got %v", ageParam["type"])
190	}
191	if ageParam["minimum"] != 0.0 {
192		t.Errorf("Expected age minimum 0, got %v", ageParam["minimum"])
193	}
194	if ageParam["maximum"] != 120.0 {
195		t.Errorf("Expected age maximum 120, got %v", ageParam["maximum"])
196	}
197
198	// Check priority parameter with enum
199	priorityParam, ok := params["priority"].(map[string]any)
200	if !ok {
201		t.Fatal("Expected priority parameter to exist")
202	}
203	if priorityParam["type"] != "string" {
204		t.Errorf("Expected priority type 'string', got %v", priorityParam["type"])
205	}
206	enumValues, ok := priorityParam["enum"].([]any)
207	if !ok || len(enumValues) != 3 {
208		t.Errorf("Expected 3 enum values, got %v", enumValues)
209	}
210}
211