1package ai
2
3import (
4 "context"
5 "fmt"
6 "reflect"
7 "strings"
8 "testing"
9)
10
// CalculatorInput is the typed input for the "calculator" tool built in
// TestTypedToolFuncExample. Its json tag supplies the schema property name
// ("expression", which that test expects to be required); the description tag
// presumably becomes the property's description — confirm in generateSchema.
type CalculatorInput struct {
	Expression string `json:"expression" description:"Mathematical expression to evaluate"`
}
15
16func TestTypedToolFuncExample(t *testing.T) {
17 // Create a typed tool using the function API
18 tool := NewTypedToolFunc(
19 "calculator",
20 "Evaluates simple mathematical expressions",
21 func(ctx context.Context, input CalculatorInput, _ ToolCall) (ToolResponse, error) {
22 if input.Expression == "2+2" {
23 return NewTextResponse("4"), nil
24 }
25 return NewTextErrorResponse("unsupported expression"), nil
26 },
27 )
28
29 // Check the tool info
30 info := tool.Info()
31 if info.Name != "calculator" {
32 t.Errorf("Expected tool name 'calculator', got %s", info.Name)
33 }
34 if len(info.Required) != 1 || info.Required[0] != "expression" {
35 t.Errorf("Expected required field 'expression', got %v", info.Required)
36 }
37
38 // Test execution
39 call := ToolCall{
40 ID: "test-1",
41 Name: "calculator",
42 Input: `{"expression": "2+2"}`,
43 }
44
45 result, err := tool.Run(context.Background(), call)
46 if err != nil {
47 t.Errorf("Unexpected error: %v", err)
48 }
49 if result.Content != "4" {
50 t.Errorf("Expected result '4', got %s", result.Content)
51 }
52 if result.IsError {
53 t.Errorf("Expected successful result, got error")
54 }
55}
56
57func TestEnumToolExample(t *testing.T) {
58 type WeatherInput struct {
59 Location string `json:"location" description:"City name"`
60 Units string `json:"units" enum:"celsius,fahrenheit" description:"Temperature units"`
61 }
62
63 // Create a weather tool with enum support
64 tool := NewTypedToolFunc(
65 "weather",
66 "Gets current weather for a location",
67 func(ctx context.Context, input WeatherInput, _ ToolCall) (ToolResponse, error) {
68 temp := "22°C"
69 if input.Units == "fahrenheit" {
70 temp = "72°F"
71 }
72 return NewTextResponse(fmt.Sprintf("Weather in %s: %s, sunny", input.Location, temp)), nil
73 },
74 )
75
76 // Check that the schema includes enum values
77 info := tool.Info()
78 unitsParam, ok := info.Parameters["units"].(map[string]any)
79 if !ok {
80 t.Fatal("Expected units parameter to exist")
81 }
82 enumValues, ok := unitsParam["enum"].([]any)
83 if !ok || len(enumValues) != 2 {
84 t.Errorf("Expected 2 enum values, got %v", enumValues)
85 }
86
87 // Test execution with enum value
88 call := ToolCall{
89 ID: "test-2",
90 Name: "weather",
91 Input: `{"location": "San Francisco", "units": "fahrenheit"}`,
92 }
93
94 result, err := tool.Run(context.Background(), call)
95 if err != nil {
96 t.Errorf("Unexpected error: %v", err)
97 }
98 if !strings.Contains(result.Content, "San Francisco") {
99 t.Errorf("Expected result to contain 'San Francisco', got %s", result.Content)
100 }
101 if !strings.Contains(result.Content, "72°F") {
102 t.Errorf("Expected result to contain '72°F', got %s", result.Content)
103 }
104}
105
106func TestEnumSupport(t *testing.T) {
107 // Test enum via struct tags
108 type WeatherInput struct {
109 Location string `json:"location" description:"City name"`
110 Units string `json:"units" enum:"celsius,fahrenheit,kelvin" description:"Temperature units"`
111 Format string `json:"format,omitempty" enum:"json,xml,text"`
112 }
113
114 schema := generateSchema(reflect.TypeOf(WeatherInput{}))
115
116 if schema.Type != "object" {
117 t.Errorf("Expected schema type 'object', got %s", schema.Type)
118 }
119
120 // Check units field has enum values
121 unitsSchema := schema.Properties["units"]
122 if unitsSchema == nil {
123 t.Fatal("Expected units property to exist")
124 }
125 if len(unitsSchema.Enum) != 3 {
126 t.Errorf("Expected 3 enum values for units, got %d", len(unitsSchema.Enum))
127 }
128 expectedUnits := []string{"celsius", "fahrenheit", "kelvin"}
129 for i, expected := range expectedUnits {
130 if unitsSchema.Enum[i] != expected {
131 t.Errorf("Expected enum value %s, got %v", expected, unitsSchema.Enum[i])
132 }
133 }
134
135 // Check required fields (format should not be required due to omitempty)
136 expectedRequired := []string{"location", "units"}
137 if len(schema.Required) != len(expectedRequired) {
138 t.Errorf("Expected %d required fields, got %d", len(expectedRequired), len(schema.Required))
139 }
140}
141
142func TestSchemaToParameters(t *testing.T) {
143 schema := Schema{
144 Type: "object",
145 Properties: map[string]*Schema{
146 "name": {
147 Type: "string",
148 Description: "The name field",
149 },
150 "age": {
151 Type: "integer",
152 Minimum: func() *float64 { v := 0.0; return &v }(),
153 Maximum: func() *float64 { v := 120.0; return &v }(),
154 },
155 "tags": {
156 Type: "array",
157 Items: &Schema{
158 Type: "string",
159 },
160 },
161 "priority": {
162 Type: "string",
163 Enum: []any{"low", "medium", "high"},
164 },
165 },
166 Required: []string{"name"},
167 }
168
169 params := schemaToParameters(schema)
170
171 // Check name parameter
172 nameParam, ok := params["name"].(map[string]any)
173 if !ok {
174 t.Fatal("Expected name parameter to exist")
175 }
176 if nameParam["type"] != "string" {
177 t.Errorf("Expected name type 'string', got %v", nameParam["type"])
178 }
179 if nameParam["description"] != "The name field" {
180 t.Errorf("Expected name description 'The name field', got %v", nameParam["description"])
181 }
182
183 // Check age parameter with min/max
184 ageParam, ok := params["age"].(map[string]any)
185 if !ok {
186 t.Fatal("Expected age parameter to exist")
187 }
188 if ageParam["type"] != "integer" {
189 t.Errorf("Expected age type 'integer', got %v", ageParam["type"])
190 }
191 if ageParam["minimum"] != 0.0 {
192 t.Errorf("Expected age minimum 0, got %v", ageParam["minimum"])
193 }
194 if ageParam["maximum"] != 120.0 {
195 t.Errorf("Expected age maximum 120, got %v", ageParam["maximum"])
196 }
197
198 // Check priority parameter with enum
199 priorityParam, ok := params["priority"].(map[string]any)
200 if !ok {
201 t.Fatal("Expected priority parameter to exist")
202 }
203 if priorityParam["type"] != "string" {
204 t.Errorf("Expected priority type 'string', got %v", priorityParam["type"])
205 }
206 enumValues, ok := priorityParam["enum"].([]any)
207 if !ok || len(enumValues) != 3 {
208 t.Errorf("Expected 3 enum values, got %v", enumValues)
209 }
210}
211