1package ai
2
3import (
4 "context"
5 "fmt"
6 "reflect"
7 "strings"
8 "testing"
9)
10
// CalculatorInput is the typed input for the example "calculator" tool built in
// TestTypedToolFuncExample. Its struct tags feed schema generation: the `json`
// tag supplies the parameter name ("expression", which the test expects to be
// the tool's only required field) and the `description` tag supplies the
// parameter description.
type CalculatorInput struct {
	Expression string `json:"expression" description:"Mathematical expression to evaluate"`
}
15
16func TestTypedToolFuncExample(t *testing.T) {
17 // Create a typed tool using the function API
18 tool := NewAgentTool(
19 "calculator",
20 "Evaluates simple mathematical expressions",
21 func(ctx context.Context, input CalculatorInput, _ ToolCall) (ToolResponse, error) {
22 if input.Expression == "2+2" {
23 return NewTextResponse("4"), nil
24 }
25 return NewTextErrorResponse("unsupported expression"), nil
26 },
27 )
28
29 // Check the tool info
30 info := tool.Info()
31 if info.Name != "calculator" {
32 t.Errorf("Expected tool name 'calculator', got %s", info.Name)
33 }
34 if len(info.Required) != 1 || info.Required[0] != "expression" {
35 t.Errorf("Expected required field 'expression', got %v", info.Required)
36 }
37
38 // Test execution
39 call := ToolCall{
40 ID: "test-1",
41 Name: "calculator",
42 Input: `{"expression": "2+2"}`,
43 }
44
45 result, err := tool.Run(context.Background(), call)
46 if err != nil {
47 t.Errorf("Unexpected error: %v", err)
48 }
49 if result.Content != "4" {
50 t.Errorf("Expected result '4', got %s", result.Content)
51 }
52 if result.IsError {
53 t.Errorf("Expected successful result, got error")
54 }
55}
56
57func TestEnumToolExample(t *testing.T) {
58 type WeatherInput struct {
59 Location string `json:"location" description:"City name"`
60 Units string `json:"units" enum:"celsius,fahrenheit" description:"Temperature units"`
61 }
62
63 // Create a weather tool with enum support
64 tool := NewAgentTool(
65 "weather",
66 "Gets current weather for a location",
67 func(ctx context.Context, input WeatherInput, _ ToolCall) (ToolResponse, error) {
68 temp := "22°C"
69 if input.Units == "fahrenheit" {
70 temp = "72°F"
71 }
72 return NewTextResponse(fmt.Sprintf("Weather in %s: %s, sunny", input.Location, temp)), nil
73 },
74 )
75
76 // Check that the schema includes enum values
77 info := tool.Info()
78 unitsParam, ok := info.Parameters["units"].(map[string]any)
79 if !ok {
80 t.Fatal("Expected units parameter to exist")
81 }
82 enumValues, ok := unitsParam["enum"].([]any)
83 if !ok || len(enumValues) != 2 {
84 t.Errorf("Expected 2 enum values, got %v", enumValues)
85 }
86
87 // Test execution with enum value
88 call := ToolCall{
89 ID: "test-2",
90 Name: "weather",
91 Input: `{"location": "San Francisco", "units": "fahrenheit"}`,
92 }
93
94 result, err := tool.Run(context.Background(), call)
95 if err != nil {
96 t.Errorf("Unexpected error: %v", err)
97 }
98 if !strings.Contains(result.Content, "San Francisco") {
99 t.Errorf("Expected result to contain 'San Francisco', got %s", result.Content)
100 }
101 if !strings.Contains(result.Content, "72°F") {
102 t.Errorf("Expected result to contain '72°F', got %s", result.Content)
103 }
104}
105
106func TestEnumSupport(t *testing.T) {
107 // Test enum via struct tags
108 type WeatherInput struct {
109 Location string `json:"location" description:"City name"`
110 Units string `json:"units" enum:"celsius,fahrenheit,kelvin" description:"Temperature units"`
111 Format string `json:"format,omitempty" enum:"json,xml,text"`
112 }
113
114 schema := generateSchema(reflect.TypeOf(WeatherInput{}))
115
116 if schema.Type != "object" {
117 t.Errorf("Expected schema type 'object', got %s", schema.Type)
118 }
119
120 // Check units field has enum values
121 unitsSchema := schema.Properties["units"]
122 if unitsSchema == nil {
123 t.Fatal("Expected units property to exist")
124 }
125 if len(unitsSchema.Enum) != 3 {
126 t.Errorf("Expected 3 enum values for units, got %d", len(unitsSchema.Enum))
127 }
128 expectedUnits := []string{"celsius", "fahrenheit", "kelvin"}
129 for i, expected := range expectedUnits {
130 if unitsSchema.Enum[i] != expected {
131 t.Errorf("Expected enum value %s, got %v", expected, unitsSchema.Enum[i])
132 }
133 }
134
135 // Check required fields (format should not be required due to omitempty)
136 expectedRequired := []string{"location", "units"}
137 if len(schema.Required) != len(expectedRequired) {
138 t.Errorf("Expected %d required fields, got %d", len(expectedRequired), len(schema.Required))
139 }
140}
141
142func TestSchemaToParameters(t *testing.T) {
143 schema := Schema{
144 Type: "object",
145 Properties: map[string]*Schema{
146 "name": {
147 Type: "string",
148 Description: "The name field",
149 },
150 "age": {
151 Type: "integer",
152 Minimum: func() *float64 { v := 0.0; return &v }(),
153 Maximum: func() *float64 { v := 120.0; return &v }(),
154 },
155 "tags": {
156 Type: "array",
157 Items: &Schema{
158 Type: "string",
159 },
160 },
161 "priority": {
162 Type: "string",
163 Enum: []any{"low", "medium", "high"},
164 },
165 },
166 Required: []string{"name"},
167 }
168
169 params := schemaToParameters(schema)
170
171 // Check name parameter
172 nameParam, ok := params["name"].(map[string]any)
173 if !ok {
174 t.Fatal("Expected name parameter to exist")
175 }
176 if nameParam["type"] != "string" {
177 t.Errorf("Expected name type 'string', got %v", nameParam["type"])
178 }
179 if nameParam["description"] != "The name field" {
180 t.Errorf("Expected name description 'The name field', got %v", nameParam["description"])
181 }
182
183 // Check age parameter with min/max
184 ageParam, ok := params["age"].(map[string]any)
185 if !ok {
186 t.Fatal("Expected age parameter to exist")
187 }
188 if ageParam["type"] != "integer" {
189 t.Errorf("Expected age type 'integer', got %v", ageParam["type"])
190 }
191 if ageParam["minimum"] != 0.0 {
192 t.Errorf("Expected age minimum 0, got %v", ageParam["minimum"])
193 }
194 if ageParam["maximum"] != 120.0 {
195 t.Errorf("Expected age maximum 120, got %v", ageParam["maximum"])
196 }
197
198 // Check priority parameter with enum
199 priorityParam, ok := params["priority"].(map[string]any)
200 if !ok {
201 t.Fatal("Expected priority parameter to exist")
202 }
203 if priorityParam["type"] != "string" {
204 t.Errorf("Expected priority type 'string', got %v", priorityParam["type"])
205 }
206 enumValues, ok := priorityParam["enum"].([]any)
207 if !ok || len(enumValues) != 3 {
208 t.Errorf("Expected 3 enum values, got %v", enumValues)
209 }
210}
211
212func TestGenerateSchemaBasicTypes(t *testing.T) {
213 t.Parallel()
214
215 tests := []struct {
216 name string
217 input any
218 expected Schema
219 }{
220 {
221 name: "string type",
222 input: "",
223 expected: Schema{Type: "string"},
224 },
225 {
226 name: "int type",
227 input: 0,
228 expected: Schema{Type: "integer"},
229 },
230 {
231 name: "int64 type",
232 input: int64(0),
233 expected: Schema{Type: "integer"},
234 },
235 {
236 name: "uint type",
237 input: uint(0),
238 expected: Schema{Type: "integer"},
239 },
240 {
241 name: "float64 type",
242 input: 0.0,
243 expected: Schema{Type: "number"},
244 },
245 {
246 name: "float32 type",
247 input: float32(0.0),
248 expected: Schema{Type: "number"},
249 },
250 {
251 name: "bool type",
252 input: false,
253 expected: Schema{Type: "boolean"},
254 },
255 }
256
257 for _, tt := range tests {
258 t.Run(tt.name, func(t *testing.T) {
259 t.Parallel()
260 schema := generateSchema(reflect.TypeOf(tt.input))
261 if schema.Type != tt.expected.Type {
262 t.Errorf("Expected type %s, got %s", tt.expected.Type, schema.Type)
263 }
264 })
265 }
266}
267
268func TestGenerateSchemaArrayTypes(t *testing.T) {
269 t.Parallel()
270
271 tests := []struct {
272 name string
273 input any
274 expected Schema
275 }{
276 {
277 name: "string slice",
278 input: []string{},
279 expected: Schema{
280 Type: "array",
281 Items: &Schema{Type: "string"},
282 },
283 },
284 {
285 name: "int slice",
286 input: []int{},
287 expected: Schema{
288 Type: "array",
289 Items: &Schema{Type: "integer"},
290 },
291 },
292 {
293 name: "string array",
294 input: [3]string{},
295 expected: Schema{
296 Type: "array",
297 Items: &Schema{Type: "string"},
298 },
299 },
300 }
301
302 for _, tt := range tests {
303 t.Run(tt.name, func(t *testing.T) {
304 t.Parallel()
305 schema := generateSchema(reflect.TypeOf(tt.input))
306 if schema.Type != tt.expected.Type {
307 t.Errorf("Expected type %s, got %s", tt.expected.Type, schema.Type)
308 }
309 if schema.Items == nil {
310 t.Fatal("Expected items schema to exist")
311 }
312 if schema.Items.Type != tt.expected.Items.Type {
313 t.Errorf("Expected items type %s, got %s", tt.expected.Items.Type, schema.Items.Type)
314 }
315 })
316 }
317}
318
319func TestGenerateSchemaMapTypes(t *testing.T) {
320 t.Parallel()
321
322 tests := []struct {
323 name string
324 input any
325 expected string
326 }{
327 {
328 name: "string to string map",
329 input: map[string]string{},
330 expected: "object",
331 },
332 {
333 name: "string to int map",
334 input: map[string]int{},
335 expected: "object",
336 },
337 {
338 name: "int to string map",
339 input: map[int]string{},
340 expected: "object",
341 },
342 }
343
344 for _, tt := range tests {
345 t.Run(tt.name, func(t *testing.T) {
346 t.Parallel()
347 schema := generateSchema(reflect.TypeOf(tt.input))
348 if schema.Type != tt.expected {
349 t.Errorf("Expected type %s, got %s", tt.expected, schema.Type)
350 }
351 })
352 }
353}
354
355func TestGenerateSchemaStructTypes(t *testing.T) {
356 t.Parallel()
357
358 type SimpleStruct struct {
359 Name string `json:"name" description:"The name field"`
360 Age int `json:"age"`
361 }
362
363 type StructWithOmitEmpty struct {
364 Required string `json:"required"`
365 Optional string `json:"optional,omitempty"`
366 }
367
368 type StructWithJSONIgnore struct {
369 Visible string `json:"visible"`
370 Hidden string `json:"-"`
371 }
372
373 type StructWithoutJSONTags struct {
374 FirstName string
375 LastName string
376 }
377
378 tests := []struct {
379 name string
380 input any
381 validate func(t *testing.T, schema Schema)
382 }{
383 {
384 name: "simple struct",
385 input: SimpleStruct{},
386 validate: func(t *testing.T, schema Schema) {
387 if schema.Type != "object" {
388 t.Errorf("Expected type object, got %s", schema.Type)
389 }
390 if len(schema.Properties) != 2 {
391 t.Errorf("Expected 2 properties, got %d", len(schema.Properties))
392 }
393 if schema.Properties["name"] == nil {
394 t.Error("Expected name property to exist")
395 }
396 if schema.Properties["name"].Description != "The name field" {
397 t.Errorf("Expected description 'The name field', got %s", schema.Properties["name"].Description)
398 }
399 if len(schema.Required) != 2 {
400 t.Errorf("Expected 2 required fields, got %d", len(schema.Required))
401 }
402 },
403 },
404 {
405 name: "struct with omitempty",
406 input: StructWithOmitEmpty{},
407 validate: func(t *testing.T, schema Schema) {
408 if len(schema.Required) != 1 {
409 t.Errorf("Expected 1 required field, got %d", len(schema.Required))
410 }
411 if schema.Required[0] != "required" {
412 t.Errorf("Expected required field 'required', got %s", schema.Required[0])
413 }
414 },
415 },
416 {
417 name: "struct with json ignore",
418 input: StructWithJSONIgnore{},
419 validate: func(t *testing.T, schema Schema) {
420 if len(schema.Properties) != 1 {
421 t.Errorf("Expected 1 property, got %d", len(schema.Properties))
422 }
423 if schema.Properties["visible"] == nil {
424 t.Error("Expected visible property to exist")
425 }
426 if schema.Properties["hidden"] != nil {
427 t.Error("Expected hidden property to not exist")
428 }
429 },
430 },
431 {
432 name: "struct without json tags",
433 input: StructWithoutJSONTags{},
434 validate: func(t *testing.T, schema Schema) {
435 if schema.Properties["first_name"] == nil {
436 t.Error("Expected first_name property to exist")
437 }
438 if schema.Properties["last_name"] == nil {
439 t.Error("Expected last_name property to exist")
440 }
441 },
442 },
443 }
444
445 for _, tt := range tests {
446 t.Run(tt.name, func(t *testing.T) {
447 t.Parallel()
448 schema := generateSchema(reflect.TypeOf(tt.input))
449 tt.validate(t, schema)
450 })
451 }
452}
453
454func TestGenerateSchemaPointerTypes(t *testing.T) {
455 t.Parallel()
456
457 type StructWithPointers struct {
458 Name *string `json:"name"`
459 Age *int `json:"age"`
460 }
461
462 schema := generateSchema(reflect.TypeOf(StructWithPointers{}))
463
464 if schema.Type != "object" {
465 t.Errorf("Expected type object, got %s", schema.Type)
466 }
467
468 if schema.Properties["name"] == nil {
469 t.Fatal("Expected name property to exist")
470 }
471 if schema.Properties["name"].Type != "string" {
472 t.Errorf("Expected name type string, got %s", schema.Properties["name"].Type)
473 }
474
475 if schema.Properties["age"] == nil {
476 t.Fatal("Expected age property to exist")
477 }
478 if schema.Properties["age"].Type != "integer" {
479 t.Errorf("Expected age type integer, got %s", schema.Properties["age"].Type)
480 }
481}
482
483func TestGenerateSchemaNestedStructs(t *testing.T) {
484 t.Parallel()
485
486 type Address struct {
487 Street string `json:"street"`
488 City string `json:"city"`
489 }
490
491 type Person struct {
492 Name string `json:"name"`
493 Address Address `json:"address"`
494 }
495
496 schema := generateSchema(reflect.TypeOf(Person{}))
497
498 if schema.Type != "object" {
499 t.Errorf("Expected type object, got %s", schema.Type)
500 }
501
502 if schema.Properties["address"] == nil {
503 t.Fatal("Expected address property to exist")
504 }
505
506 addressSchema := schema.Properties["address"]
507 if addressSchema.Type != "object" {
508 t.Errorf("Expected address type object, got %s", addressSchema.Type)
509 }
510
511 if addressSchema.Properties["street"] == nil {
512 t.Error("Expected street property in address to exist")
513 }
514 if addressSchema.Properties["city"] == nil {
515 t.Error("Expected city property in address to exist")
516 }
517}
518
519func TestGenerateSchemaRecursiveStructs(t *testing.T) {
520 t.Parallel()
521
522 type Node struct {
523 Value string `json:"value"`
524 Next *Node `json:"next,omitempty"`
525 }
526
527 schema := generateSchema(reflect.TypeOf(Node{}))
528
529 if schema.Type != "object" {
530 t.Errorf("Expected type object, got %s", schema.Type)
531 }
532
533 if schema.Properties["value"] == nil {
534 t.Error("Expected value property to exist")
535 }
536
537 if schema.Properties["next"] == nil {
538 t.Error("Expected next property to exist")
539 }
540
541 // The recursive reference should be handled gracefully
542 nextSchema := schema.Properties["next"]
543 if nextSchema.Type != "object" {
544 t.Errorf("Expected next type object, got %s", nextSchema.Type)
545 }
546}
547
548func TestGenerateSchemaWithEnumTags(t *testing.T) {
549 t.Parallel()
550
551 type ConfigInput struct {
552 Level string `json:"level" enum:"debug,info,warn,error" description:"Log level"`
553 Format string `json:"format" enum:"json,text"`
554 Optional string `json:"optional,omitempty" enum:"a,b,c"`
555 }
556
557 schema := generateSchema(reflect.TypeOf(ConfigInput{}))
558
559 // Check level field
560 levelSchema := schema.Properties["level"]
561 if levelSchema == nil {
562 t.Fatal("Expected level property to exist")
563 }
564 if len(levelSchema.Enum) != 4 {
565 t.Errorf("Expected 4 enum values for level, got %d", len(levelSchema.Enum))
566 }
567 expectedLevels := []string{"debug", "info", "warn", "error"}
568 for i, expected := range expectedLevels {
569 if levelSchema.Enum[i] != expected {
570 t.Errorf("Expected enum value %s, got %v", expected, levelSchema.Enum[i])
571 }
572 }
573
574 // Check format field
575 formatSchema := schema.Properties["format"]
576 if formatSchema == nil {
577 t.Fatal("Expected format property to exist")
578 }
579 if len(formatSchema.Enum) != 2 {
580 t.Errorf("Expected 2 enum values for format, got %d", len(formatSchema.Enum))
581 }
582
583 // Check required fields (optional should not be required due to omitempty)
584 expectedRequired := []string{"level", "format"}
585 if len(schema.Required) != len(expectedRequired) {
586 t.Errorf("Expected %d required fields, got %d", len(expectedRequired), len(schema.Required))
587 }
588}
589
590func TestGenerateSchemaComplexTypes(t *testing.T) {
591 t.Parallel()
592
593 type ComplexInput struct {
594 StringSlice []string `json:"string_slice"`
595 IntMap map[string]int `json:"int_map"`
596 NestedSlice []map[string]string `json:"nested_slice"`
597 Interface any `json:"interface"`
598 }
599
600 schema := generateSchema(reflect.TypeOf(ComplexInput{}))
601
602 // Check string slice
603 stringSliceSchema := schema.Properties["string_slice"]
604 if stringSliceSchema == nil {
605 t.Fatal("Expected string_slice property to exist")
606 }
607 if stringSliceSchema.Type != "array" {
608 t.Errorf("Expected string_slice type array, got %s", stringSliceSchema.Type)
609 }
610 if stringSliceSchema.Items.Type != "string" {
611 t.Errorf("Expected string_slice items type string, got %s", stringSliceSchema.Items.Type)
612 }
613
614 // Check int map
615 intMapSchema := schema.Properties["int_map"]
616 if intMapSchema == nil {
617 t.Fatal("Expected int_map property to exist")
618 }
619 if intMapSchema.Type != "object" {
620 t.Errorf("Expected int_map type object, got %s", intMapSchema.Type)
621 }
622
623 // Check nested slice
624 nestedSliceSchema := schema.Properties["nested_slice"]
625 if nestedSliceSchema == nil {
626 t.Fatal("Expected nested_slice property to exist")
627 }
628 if nestedSliceSchema.Type != "array" {
629 t.Errorf("Expected nested_slice type array, got %s", nestedSliceSchema.Type)
630 }
631 if nestedSliceSchema.Items.Type != "object" {
632 t.Errorf("Expected nested_slice items type object, got %s", nestedSliceSchema.Items.Type)
633 }
634
635 // Check interface
636 interfaceSchema := schema.Properties["interface"]
637 if interfaceSchema == nil {
638 t.Fatal("Expected interface property to exist")
639 }
640 if interfaceSchema.Type != "object" {
641 t.Errorf("Expected interface type object, got %s", interfaceSchema.Type)
642 }
643}
644
645func TestToSnakeCase(t *testing.T) {
646 t.Parallel()
647
648 tests := []struct {
649 input string
650 expected string
651 }{
652 {"FirstName", "first_name"},
653 {"XMLHttpRequest", "x_m_l_http_request"},
654 {"ID", "i_d"},
655 {"HTTPSProxy", "h_t_t_p_s_proxy"},
656 {"simple", "simple"},
657 {"", ""},
658 {"A", "a"},
659 {"AB", "a_b"},
660 {"CamelCase", "camel_case"},
661 }
662
663 for _, tt := range tests {
664 t.Run(tt.input, func(t *testing.T) {
665 t.Parallel()
666 result := toSnakeCase(tt.input)
667 if result != tt.expected {
668 t.Errorf("toSnakeCase(%s) = %s, expected %s", tt.input, result, tt.expected)
669 }
670 })
671 }
672}
673
674func TestSchemaToParametersEdgeCases(t *testing.T) {
675 t.Parallel()
676
677 tests := []struct {
678 name string
679 schema Schema
680 expected map[string]any
681 }{
682 {
683 name: "non-object schema",
684 schema: Schema{
685 Type: "string",
686 },
687 expected: map[string]any{},
688 },
689 {
690 name: "object with no properties",
691 schema: Schema{
692 Type: "object",
693 Properties: nil,
694 },
695 expected: map[string]any{},
696 },
697 {
698 name: "object with empty properties",
699 schema: Schema{
700 Type: "object",
701 Properties: map[string]*Schema{},
702 },
703 expected: map[string]any{},
704 },
705 {
706 name: "schema with all constraint types",
707 schema: Schema{
708 Type: "object",
709 Properties: map[string]*Schema{
710 "text": {
711 Type: "string",
712 Format: "email",
713 MinLength: func() *int { v := 5; return &v }(),
714 MaxLength: func() *int { v := 100; return &v }(),
715 },
716 "number": {
717 Type: "number",
718 Minimum: func() *float64 { v := 0.0; return &v }(),
719 Maximum: func() *float64 { v := 100.0; return &v }(),
720 },
721 },
722 },
723 expected: map[string]any{
724 "text": map[string]any{
725 "type": "string",
726 "format": "email",
727 "minLength": 5,
728 "maxLength": 100,
729 },
730 "number": map[string]any{
731 "type": "number",
732 "minimum": 0.0,
733 "maximum": 100.0,
734 },
735 },
736 },
737 }
738
739 for _, tt := range tests {
740 t.Run(tt.name, func(t *testing.T) {
741 t.Parallel()
742 result := schemaToParameters(tt.schema)
743 if len(result) != len(tt.expected) {
744 t.Errorf("Expected %d parameters, got %d", len(tt.expected), len(result))
745 }
746 for key, expectedValue := range tt.expected {
747 if result[key] == nil {
748 t.Errorf("Expected parameter %s to exist", key)
749 continue
750 }
751 // Deep comparison would be complex, so we'll check key properties
752 resultParam := result[key].(map[string]any)
753 expectedParam := expectedValue.(map[string]any)
754 for propKey, propValue := range expectedParam {
755 if resultParam[propKey] != propValue {
756 t.Errorf("Expected %s.%s = %v, got %v", key, propKey, propValue, resultParam[propKey])
757 }
758 }
759 }
760 })
761 }
762}