1package fantasy
2
3import (
4 "context"
5 "fmt"
6 "reflect"
7 "testing"
8
9 "github.com/stretchr/testify/require"
10)
11
// CalculatorInput is the typed input of the example calculator tool built
// with the function-based tool API. Its struct tags supply the JSON field
// name and the human-readable description of the single argument.
type CalculatorInput struct {
	Expression string `json:"expression" description:"Mathematical expression to evaluate"`
}
16
17func TestTypedToolFuncExample(t *testing.T) {
18 // Create a typed tool using the function API
19 tool := NewAgentTool(
20 "calculator",
21 "Evaluates simple mathematical expressions",
22 func(ctx context.Context, input CalculatorInput, _ ToolCall) (ToolResponse, error) {
23 if input.Expression == "2+2" {
24 return NewTextResponse("4"), nil
25 }
26 return NewTextErrorResponse("unsupported expression"), nil
27 },
28 )
29
30 // Check the tool info
31 info := tool.Info()
32 require.Equal(t, "calculator", info.Name)
33 require.Len(t, info.Required, 1)
34 require.Equal(t, "expression", info.Required[0])
35
36 // Test execution
37 call := ToolCall{
38 ID: "test-1",
39 Name: "calculator",
40 Input: `{"expression": "2+2"}`,
41 }
42
43 result, err := tool.Run(context.Background(), call)
44 require.NoError(t, err)
45 require.Equal(t, "4", result.Content)
46 require.False(t, result.IsError)
47}
48
49func TestEnumToolExample(t *testing.T) {
50 type WeatherInput struct {
51 Location string `json:"location" description:"City name"`
52 Units string `json:"units" enum:"celsius,fahrenheit" description:"Temperature units"`
53 }
54
55 // Create a weather tool with enum support
56 tool := NewAgentTool(
57 "weather",
58 "Gets current weather for a location",
59 func(ctx context.Context, input WeatherInput, _ ToolCall) (ToolResponse, error) {
60 temp := "22°C"
61 if input.Units == "fahrenheit" {
62 temp = "72°F"
63 }
64 return NewTextResponse(fmt.Sprintf("Weather in %s: %s, sunny", input.Location, temp)), nil
65 },
66 )
67
68 // Check that the schema includes enum values
69 info := tool.Info()
70 unitsParam, ok := info.Parameters["units"].(map[string]any)
71 require.True(t, ok, "Expected units parameter to exist")
72 enumValues, ok := unitsParam["enum"].([]any)
73 require.True(t, ok)
74 require.Len(t, enumValues, 2)
75
76 // Test execution with enum value
77 call := ToolCall{
78 ID: "test-2",
79 Name: "weather",
80 Input: `{"location": "San Francisco", "units": "fahrenheit"}`,
81 }
82
83 result, err := tool.Run(context.Background(), call)
84 require.NoError(t, err)
85 require.Contains(t, result.Content, "San Francisco")
86 require.Contains(t, result.Content, "72°F")
87}
88
89func TestEnumSupport(t *testing.T) {
90 // Test enum via struct tags
91 type WeatherInput struct {
92 Location string `json:"location" description:"City name"`
93 Units string `json:"units" enum:"celsius,fahrenheit,kelvin" description:"Temperature units"`
94 Format string `json:"format,omitempty" enum:"json,xml,text"`
95 }
96
97 schema := generateSchema(reflect.TypeOf(WeatherInput{}))
98
99 require.Equal(t, "object", schema.Type)
100
101 // Check units field has enum values
102 unitsSchema := schema.Properties["units"]
103 require.NotNil(t, unitsSchema, "Expected units property to exist")
104 require.Len(t, unitsSchema.Enum, 3)
105 expectedUnits := []string{"celsius", "fahrenheit", "kelvin"}
106 for i, expected := range expectedUnits {
107 require.Equal(t, expected, unitsSchema.Enum[i])
108 }
109
110 // Check required fields (format should not be required due to omitempty)
111 expectedRequired := []string{"location", "units"}
112 require.Len(t, schema.Required, len(expectedRequired))
113}
114
115func TestSchemaToParameters(t *testing.T) {
116 schema := Schema{
117 Type: "object",
118 Properties: map[string]*Schema{
119 "name": {
120 Type: "string",
121 Description: "The name field",
122 },
123 "age": {
124 Type: "integer",
125 Minimum: func() *float64 { v := 0.0; return &v }(),
126 Maximum: func() *float64 { v := 120.0; return &v }(),
127 },
128 "tags": {
129 Type: "array",
130 Items: &Schema{
131 Type: "string",
132 },
133 },
134 "priority": {
135 Type: "string",
136 Enum: []any{"low", "medium", "high"},
137 },
138 },
139 Required: []string{"name"},
140 }
141
142 params := schemaToParameters(schema)
143
144 // Check name parameter
145 nameParam, ok := params["name"].(map[string]any)
146 require.True(t, ok, "Expected name parameter to exist")
147 require.Equal(t, "string", nameParam["type"])
148 require.Equal(t, "The name field", nameParam["description"])
149
150 // Check age parameter with min/max
151 ageParam, ok := params["age"].(map[string]any)
152 require.True(t, ok, "Expected age parameter to exist")
153 require.Equal(t, "integer", ageParam["type"])
154 require.Equal(t, 0.0, ageParam["minimum"])
155 require.Equal(t, 120.0, ageParam["maximum"])
156
157 // Check priority parameter with enum
158 priorityParam, ok := params["priority"].(map[string]any)
159 require.True(t, ok, "Expected priority parameter to exist")
160 require.Equal(t, "string", priorityParam["type"])
161 enumValues, ok := priorityParam["enum"].([]any)
162 require.True(t, ok)
163 require.Len(t, enumValues, 3)
164}
165
166func TestGenerateSchemaBasicTypes(t *testing.T) {
167 t.Parallel()
168
169 tests := []struct {
170 name string
171 input any
172 expected Schema
173 }{
174 {
175 name: "string type",
176 input: "",
177 expected: Schema{Type: "string"},
178 },
179 {
180 name: "int type",
181 input: 0,
182 expected: Schema{Type: "integer"},
183 },
184 {
185 name: "int64 type",
186 input: int64(0),
187 expected: Schema{Type: "integer"},
188 },
189 {
190 name: "uint type",
191 input: uint(0),
192 expected: Schema{Type: "integer"},
193 },
194 {
195 name: "float64 type",
196 input: 0.0,
197 expected: Schema{Type: "number"},
198 },
199 {
200 name: "float32 type",
201 input: float32(0.0),
202 expected: Schema{Type: "number"},
203 },
204 {
205 name: "bool type",
206 input: false,
207 expected: Schema{Type: "boolean"},
208 },
209 }
210
211 for _, tt := range tests {
212 t.Run(tt.name, func(t *testing.T) {
213 t.Parallel()
214 schema := generateSchema(reflect.TypeOf(tt.input))
215 require.Equal(t, tt.expected.Type, schema.Type)
216 })
217 }
218}
219
220func TestGenerateSchemaArrayTypes(t *testing.T) {
221 t.Parallel()
222
223 tests := []struct {
224 name string
225 input any
226 expected Schema
227 }{
228 {
229 name: "string slice",
230 input: []string{},
231 expected: Schema{
232 Type: "array",
233 Items: &Schema{Type: "string"},
234 },
235 },
236 {
237 name: "int slice",
238 input: []int{},
239 expected: Schema{
240 Type: "array",
241 Items: &Schema{Type: "integer"},
242 },
243 },
244 {
245 name: "string array",
246 input: [3]string{},
247 expected: Schema{
248 Type: "array",
249 Items: &Schema{Type: "string"},
250 },
251 },
252 }
253
254 for _, tt := range tests {
255 t.Run(tt.name, func(t *testing.T) {
256 t.Parallel()
257 schema := generateSchema(reflect.TypeOf(tt.input))
258 require.Equal(t, tt.expected.Type, schema.Type)
259 require.NotNil(t, schema.Items, "Expected items schema to exist")
260 require.Equal(t, tt.expected.Items.Type, schema.Items.Type)
261 })
262 }
263}
264
265func TestGenerateSchemaMapTypes(t *testing.T) {
266 t.Parallel()
267
268 tests := []struct {
269 name string
270 input any
271 expected string
272 }{
273 {
274 name: "string to string map",
275 input: map[string]string{},
276 expected: "object",
277 },
278 {
279 name: "string to int map",
280 input: map[string]int{},
281 expected: "object",
282 },
283 {
284 name: "int to string map",
285 input: map[int]string{},
286 expected: "object",
287 },
288 }
289
290 for _, tt := range tests {
291 t.Run(tt.name, func(t *testing.T) {
292 t.Parallel()
293 schema := generateSchema(reflect.TypeOf(tt.input))
294 require.Equal(t, tt.expected, schema.Type)
295 })
296 }
297}
298
299func TestGenerateSchemaStructTypes(t *testing.T) {
300 t.Parallel()
301
302 type SimpleStruct struct {
303 Name string `json:"name" description:"The name field"`
304 Age int `json:"age"`
305 }
306
307 type StructWithOmitEmpty struct {
308 Required string `json:"required"`
309 Optional string `json:"optional,omitempty"`
310 }
311
312 type StructWithJSONIgnore struct {
313 Visible string `json:"visible"`
314 Hidden string `json:"-"`
315 }
316
317 type StructWithoutJSONTags struct {
318 FirstName string
319 LastName string
320 }
321
322 tests := []struct {
323 name string
324 input any
325 validate func(t *testing.T, schema Schema)
326 }{
327 {
328 name: "simple struct",
329 input: SimpleStruct{},
330 validate: func(t *testing.T, schema Schema) {
331 require.Equal(t, "object", schema.Type)
332 require.Len(t, schema.Properties, 2)
333 require.NotNil(t, schema.Properties["name"], "Expected name property to exist")
334 require.Equal(t, "The name field", schema.Properties["name"].Description)
335 require.Len(t, schema.Required, 2)
336 },
337 },
338 {
339 name: "struct with omitempty",
340 input: StructWithOmitEmpty{},
341 validate: func(t *testing.T, schema Schema) {
342 require.Len(t, schema.Required, 1)
343 require.Equal(t, "required", schema.Required[0])
344 },
345 },
346 {
347 name: "struct with json ignore",
348 input: StructWithJSONIgnore{},
349 validate: func(t *testing.T, schema Schema) {
350 require.Len(t, schema.Properties, 1)
351 require.NotNil(t, schema.Properties["visible"], "Expected visible property to exist")
352 require.Nil(t, schema.Properties["hidden"], "Expected hidden property to not exist")
353 },
354 },
355 {
356 name: "struct without json tags",
357 input: StructWithoutJSONTags{},
358 validate: func(t *testing.T, schema Schema) {
359 require.NotNil(t, schema.Properties["first_name"], "Expected first_name property to exist")
360 require.NotNil(t, schema.Properties["last_name"], "Expected last_name property to exist")
361 },
362 },
363 }
364
365 for _, tt := range tests {
366 t.Run(tt.name, func(t *testing.T) {
367 t.Parallel()
368 schema := generateSchema(reflect.TypeOf(tt.input))
369 tt.validate(t, schema)
370 })
371 }
372}
373
374func TestGenerateSchemaPointerTypes(t *testing.T) {
375 t.Parallel()
376
377 type StructWithPointers struct {
378 Name *string `json:"name"`
379 Age *int `json:"age"`
380 }
381
382 schema := generateSchema(reflect.TypeOf(StructWithPointers{}))
383
384 require.Equal(t, "object", schema.Type)
385
386 require.NotNil(t, schema.Properties["name"], "Expected name property to exist")
387 require.Equal(t, "string", schema.Properties["name"].Type)
388
389 require.NotNil(t, schema.Properties["age"], "Expected age property to exist")
390 require.Equal(t, "integer", schema.Properties["age"].Type)
391}
392
393func TestGenerateSchemaNestedStructs(t *testing.T) {
394 t.Parallel()
395
396 type Address struct {
397 Street string `json:"street"`
398 City string `json:"city"`
399 }
400
401 type Person struct {
402 Name string `json:"name"`
403 Address Address `json:"address"`
404 }
405
406 schema := generateSchema(reflect.TypeOf(Person{}))
407
408 require.Equal(t, "object", schema.Type)
409
410 require.NotNil(t, schema.Properties["address"], "Expected address property to exist")
411
412 addressSchema := schema.Properties["address"]
413 require.Equal(t, "object", addressSchema.Type)
414
415 require.NotNil(t, addressSchema.Properties["street"], "Expected street property in address to exist")
416 require.NotNil(t, addressSchema.Properties["city"], "Expected city property in address to exist")
417}
418
419func TestGenerateSchemaRecursiveStructs(t *testing.T) {
420 t.Parallel()
421
422 type Node struct {
423 Value string `json:"value"`
424 Next *Node `json:"next,omitempty"`
425 }
426
427 schema := generateSchema(reflect.TypeOf(Node{}))
428
429 require.Equal(t, "object", schema.Type)
430
431 require.NotNil(t, schema.Properties["value"], "Expected value property to exist")
432
433 require.NotNil(t, schema.Properties["next"], "Expected next property to exist")
434
435 // The recursive reference should be handled gracefully
436 nextSchema := schema.Properties["next"]
437 require.Equal(t, "object", nextSchema.Type)
438}
439
440func TestGenerateSchemaWithEnumTags(t *testing.T) {
441 t.Parallel()
442
443 type ConfigInput struct {
444 Level string `json:"level" enum:"debug,info,warn,error" description:"Log level"`
445 Format string `json:"format" enum:"json,text"`
446 Optional string `json:"optional,omitempty" enum:"a,b,c"`
447 }
448
449 schema := generateSchema(reflect.TypeOf(ConfigInput{}))
450
451 // Check level field
452 levelSchema := schema.Properties["level"]
453 require.NotNil(t, levelSchema, "Expected level property to exist")
454 require.Len(t, levelSchema.Enum, 4)
455 expectedLevels := []string{"debug", "info", "warn", "error"}
456 for i, expected := range expectedLevels {
457 require.Equal(t, expected, levelSchema.Enum[i])
458 }
459
460 // Check format field
461 formatSchema := schema.Properties["format"]
462 require.NotNil(t, formatSchema, "Expected format property to exist")
463 require.Len(t, formatSchema.Enum, 2)
464
465 // Check required fields (optional should not be required due to omitempty)
466 expectedRequired := []string{"level", "format"}
467 require.Len(t, schema.Required, len(expectedRequired))
468}
469
470func TestGenerateSchemaComplexTypes(t *testing.T) {
471 t.Parallel()
472
473 type ComplexInput struct {
474 StringSlice []string `json:"string_slice"`
475 IntMap map[string]int `json:"int_map"`
476 NestedSlice []map[string]string `json:"nested_slice"`
477 Interface any `json:"interface"`
478 }
479
480 schema := generateSchema(reflect.TypeOf(ComplexInput{}))
481
482 // Check string slice
483 stringSliceSchema := schema.Properties["string_slice"]
484 require.NotNil(t, stringSliceSchema, "Expected string_slice property to exist")
485 require.Equal(t, "array", stringSliceSchema.Type)
486 require.Equal(t, "string", stringSliceSchema.Items.Type)
487
488 // Check int map
489 intMapSchema := schema.Properties["int_map"]
490 require.NotNil(t, intMapSchema, "Expected int_map property to exist")
491 require.Equal(t, "object", intMapSchema.Type)
492
493 // Check nested slice
494 nestedSliceSchema := schema.Properties["nested_slice"]
495 require.NotNil(t, nestedSliceSchema, "Expected nested_slice property to exist")
496 require.Equal(t, "array", nestedSliceSchema.Type)
497 require.Equal(t, "", nestedSliceSchema.Items.Type)
498
499 // Check interface
500 interfaceSchema := schema.Properties["interface"]
501 require.NotNil(t, interfaceSchema, "Expected interface property to exist")
502 require.Equal(t, "object", interfaceSchema.Type)
503}
504
505func TestToSnakeCase(t *testing.T) {
506 t.Parallel()
507
508 tests := []struct {
509 input string
510 expected string
511 }{
512 {"FirstName", "first_name"},
513 {"XMLHttpRequest", "x_m_l_http_request"},
514 {"ID", "i_d"},
515 {"HTTPSProxy", "h_t_t_p_s_proxy"},
516 {"simple", "simple"},
517 {"", ""},
518 {"A", "a"},
519 {"AB", "a_b"},
520 {"CamelCase", "camel_case"},
521 }
522
523 for _, tt := range tests {
524 t.Run(tt.input, func(t *testing.T) {
525 t.Parallel()
526 result := toSnakeCase(tt.input)
527 require.Equal(t, tt.expected, result, "toSnakeCase(%s)", tt.input)
528 })
529 }
530}
531
532func TestSchemaToParametersEdgeCases(t *testing.T) {
533 t.Parallel()
534
535 tests := []struct {
536 name string
537 schema Schema
538 expected map[string]any
539 }{
540 {
541 name: "non-object schema",
542 schema: Schema{
543 Type: "string",
544 },
545 expected: map[string]any{},
546 },
547 {
548 name: "object with no properties",
549 schema: Schema{
550 Type: "object",
551 Properties: nil,
552 },
553 expected: map[string]any{},
554 },
555 {
556 name: "object with empty properties",
557 schema: Schema{
558 Type: "object",
559 Properties: map[string]*Schema{},
560 },
561 expected: map[string]any{},
562 },
563 {
564 name: "schema with all constraint types",
565 schema: Schema{
566 Type: "object",
567 Properties: map[string]*Schema{
568 "text": {
569 Type: "string",
570 Format: "email",
571 MinLength: func() *int { v := 5; return &v }(),
572 MaxLength: func() *int { v := 100; return &v }(),
573 },
574 "number": {
575 Type: "number",
576 Minimum: func() *float64 { v := 0.0; return &v }(),
577 Maximum: func() *float64 { v := 100.0; return &v }(),
578 },
579 },
580 },
581 expected: map[string]any{
582 "text": map[string]any{
583 "type": "string",
584 "format": "email",
585 "minLength": 5,
586 "maxLength": 100,
587 },
588 "number": map[string]any{
589 "type": "number",
590 "minimum": 0.0,
591 "maximum": 100.0,
592 },
593 },
594 },
595 }
596
597 for _, tt := range tests {
598 t.Run(tt.name, func(t *testing.T) {
599 t.Parallel()
600 result := schemaToParameters(tt.schema)
601 require.Len(t, result, len(tt.expected))
602 for key, expectedValue := range tt.expected {
603 require.NotNil(t, result[key], "Expected parameter %s to exist", key)
604 // Deep comparison would be complex, so we'll check key properties
605 resultParam := result[key].(map[string]any)
606 expectedParam := expectedValue.(map[string]any)
607 for propKey, propValue := range expectedParam {
608 require.Equal(t, propValue, resultParam[propKey], "Expected %s.%s", key, propKey)
609 }
610 }
611 })
612 }
613}