1package ai
  2
  3import (
  4	"context"
  5	"fmt"
  6	"reflect"
  7	"strings"
  8	"testing"
  9)
 10
 11// Example of a simple typed tool using the function approach
 12type CalculatorInput struct {
 13	Expression string `json:"expression" description:"Mathematical expression to evaluate"`
 14}
 15
 16func TestTypedToolFuncExample(t *testing.T) {
 17	// Create a typed tool using the function API
 18	tool := NewAgentTool(
 19		"calculator",
 20		"Evaluates simple mathematical expressions",
 21		func(ctx context.Context, input CalculatorInput, _ ToolCall) (ToolResponse, error) {
 22			if input.Expression == "2+2" {
 23				return NewTextResponse("4"), nil
 24			}
 25			return NewTextErrorResponse("unsupported expression"), nil
 26		},
 27	)
 28
 29	// Check the tool info
 30	info := tool.Info()
 31	if info.Name != "calculator" {
 32		t.Errorf("Expected tool name 'calculator', got %s", info.Name)
 33	}
 34	if len(info.Required) != 1 || info.Required[0] != "expression" {
 35		t.Errorf("Expected required field 'expression', got %v", info.Required)
 36	}
 37
 38	// Test execution
 39	call := ToolCall{
 40		ID:    "test-1",
 41		Name:  "calculator",
 42		Input: `{"expression": "2+2"}`,
 43	}
 44
 45	result, err := tool.Run(context.Background(), call)
 46	if err != nil {
 47		t.Errorf("Unexpected error: %v", err)
 48	}
 49	if result.Content != "4" {
 50		t.Errorf("Expected result '4', got %s", result.Content)
 51	}
 52	if result.IsError {
 53		t.Errorf("Expected successful result, got error")
 54	}
 55}
 56
 57func TestEnumToolExample(t *testing.T) {
 58	type WeatherInput struct {
 59		Location string `json:"location" description:"City name"`
 60		Units    string `json:"units" enum:"celsius,fahrenheit" description:"Temperature units"`
 61	}
 62
 63	// Create a weather tool with enum support
 64	tool := NewAgentTool(
 65		"weather",
 66		"Gets current weather for a location",
 67		func(ctx context.Context, input WeatherInput, _ ToolCall) (ToolResponse, error) {
 68			temp := "22°C"
 69			if input.Units == "fahrenheit" {
 70				temp = "72°F"
 71			}
 72			return NewTextResponse(fmt.Sprintf("Weather in %s: %s, sunny", input.Location, temp)), nil
 73		},
 74	)
 75
 76	// Check that the schema includes enum values
 77	info := tool.Info()
 78	unitsParam, ok := info.Parameters["units"].(map[string]any)
 79	if !ok {
 80		t.Fatal("Expected units parameter to exist")
 81	}
 82	enumValues, ok := unitsParam["enum"].([]any)
 83	if !ok || len(enumValues) != 2 {
 84		t.Errorf("Expected 2 enum values, got %v", enumValues)
 85	}
 86
 87	// Test execution with enum value
 88	call := ToolCall{
 89		ID:    "test-2",
 90		Name:  "weather",
 91		Input: `{"location": "San Francisco", "units": "fahrenheit"}`,
 92	}
 93
 94	result, err := tool.Run(context.Background(), call)
 95	if err != nil {
 96		t.Errorf("Unexpected error: %v", err)
 97	}
 98	if !strings.Contains(result.Content, "San Francisco") {
 99		t.Errorf("Expected result to contain 'San Francisco', got %s", result.Content)
100	}
101	if !strings.Contains(result.Content, "72°F") {
102		t.Errorf("Expected result to contain '72°F', got %s", result.Content)
103	}
104}
105
106func TestEnumSupport(t *testing.T) {
107	// Test enum via struct tags
108	type WeatherInput struct {
109		Location string `json:"location" description:"City name"`
110		Units    string `json:"units" enum:"celsius,fahrenheit,kelvin" description:"Temperature units"`
111		Format   string `json:"format,omitempty" enum:"json,xml,text"`
112	}
113
114	schema := generateSchema(reflect.TypeOf(WeatherInput{}))
115
116	if schema.Type != "object" {
117		t.Errorf("Expected schema type 'object', got %s", schema.Type)
118	}
119
120	// Check units field has enum values
121	unitsSchema := schema.Properties["units"]
122	if unitsSchema == nil {
123		t.Fatal("Expected units property to exist")
124	}
125	if len(unitsSchema.Enum) != 3 {
126		t.Errorf("Expected 3 enum values for units, got %d", len(unitsSchema.Enum))
127	}
128	expectedUnits := []string{"celsius", "fahrenheit", "kelvin"}
129	for i, expected := range expectedUnits {
130		if unitsSchema.Enum[i] != expected {
131			t.Errorf("Expected enum value %s, got %v", expected, unitsSchema.Enum[i])
132		}
133	}
134
135	// Check required fields (format should not be required due to omitempty)
136	expectedRequired := []string{"location", "units"}
137	if len(schema.Required) != len(expectedRequired) {
138		t.Errorf("Expected %d required fields, got %d", len(expectedRequired), len(schema.Required))
139	}
140}
141
142func TestSchemaToParameters(t *testing.T) {
143	schema := Schema{
144		Type: "object",
145		Properties: map[string]*Schema{
146			"name": {
147				Type:        "string",
148				Description: "The name field",
149			},
150			"age": {
151				Type:    "integer",
152				Minimum: func() *float64 { v := 0.0; return &v }(),
153				Maximum: func() *float64 { v := 120.0; return &v }(),
154			},
155			"tags": {
156				Type: "array",
157				Items: &Schema{
158					Type: "string",
159				},
160			},
161			"priority": {
162				Type: "string",
163				Enum: []any{"low", "medium", "high"},
164			},
165		},
166		Required: []string{"name"},
167	}
168
169	params := schemaToParameters(schema)
170
171	// Check name parameter
172	nameParam, ok := params["name"].(map[string]any)
173	if !ok {
174		t.Fatal("Expected name parameter to exist")
175	}
176	if nameParam["type"] != "string" {
177		t.Errorf("Expected name type 'string', got %v", nameParam["type"])
178	}
179	if nameParam["description"] != "The name field" {
180		t.Errorf("Expected name description 'The name field', got %v", nameParam["description"])
181	}
182
183	// Check age parameter with min/max
184	ageParam, ok := params["age"].(map[string]any)
185	if !ok {
186		t.Fatal("Expected age parameter to exist")
187	}
188	if ageParam["type"] != "integer" {
189		t.Errorf("Expected age type 'integer', got %v", ageParam["type"])
190	}
191	if ageParam["minimum"] != 0.0 {
192		t.Errorf("Expected age minimum 0, got %v", ageParam["minimum"])
193	}
194	if ageParam["maximum"] != 120.0 {
195		t.Errorf("Expected age maximum 120, got %v", ageParam["maximum"])
196	}
197
198	// Check priority parameter with enum
199	priorityParam, ok := params["priority"].(map[string]any)
200	if !ok {
201		t.Fatal("Expected priority parameter to exist")
202	}
203	if priorityParam["type"] != "string" {
204		t.Errorf("Expected priority type 'string', got %v", priorityParam["type"])
205	}
206	enumValues, ok := priorityParam["enum"].([]any)
207	if !ok || len(enumValues) != 3 {
208		t.Errorf("Expected 3 enum values, got %v", enumValues)
209	}
210}
211
212func TestGenerateSchemaBasicTypes(t *testing.T) {
213	t.Parallel()
214
215	tests := []struct {
216		name     string
217		input    any
218		expected Schema
219	}{
220		{
221			name:     "string type",
222			input:    "",
223			expected: Schema{Type: "string"},
224		},
225		{
226			name:     "int type",
227			input:    0,
228			expected: Schema{Type: "integer"},
229		},
230		{
231			name:     "int64 type",
232			input:    int64(0),
233			expected: Schema{Type: "integer"},
234		},
235		{
236			name:     "uint type",
237			input:    uint(0),
238			expected: Schema{Type: "integer"},
239		},
240		{
241			name:     "float64 type",
242			input:    0.0,
243			expected: Schema{Type: "number"},
244		},
245		{
246			name:     "float32 type",
247			input:    float32(0.0),
248			expected: Schema{Type: "number"},
249		},
250		{
251			name:     "bool type",
252			input:    false,
253			expected: Schema{Type: "boolean"},
254		},
255	}
256
257	for _, tt := range tests {
258		t.Run(tt.name, func(t *testing.T) {
259			t.Parallel()
260			schema := generateSchema(reflect.TypeOf(tt.input))
261			if schema.Type != tt.expected.Type {
262				t.Errorf("Expected type %s, got %s", tt.expected.Type, schema.Type)
263			}
264		})
265	}
266}
267
268func TestGenerateSchemaArrayTypes(t *testing.T) {
269	t.Parallel()
270
271	tests := []struct {
272		name     string
273		input    any
274		expected Schema
275	}{
276		{
277			name:  "string slice",
278			input: []string{},
279			expected: Schema{
280				Type:  "array",
281				Items: &Schema{Type: "string"},
282			},
283		},
284		{
285			name:  "int slice",
286			input: []int{},
287			expected: Schema{
288				Type:  "array",
289				Items: &Schema{Type: "integer"},
290			},
291		},
292		{
293			name:  "string array",
294			input: [3]string{},
295			expected: Schema{
296				Type:  "array",
297				Items: &Schema{Type: "string"},
298			},
299		},
300	}
301
302	for _, tt := range tests {
303		t.Run(tt.name, func(t *testing.T) {
304			t.Parallel()
305			schema := generateSchema(reflect.TypeOf(tt.input))
306			if schema.Type != tt.expected.Type {
307				t.Errorf("Expected type %s, got %s", tt.expected.Type, schema.Type)
308			}
309			if schema.Items == nil {
310				t.Fatal("Expected items schema to exist")
311			}
312			if schema.Items.Type != tt.expected.Items.Type {
313				t.Errorf("Expected items type %s, got %s", tt.expected.Items.Type, schema.Items.Type)
314			}
315		})
316	}
317}
318
319func TestGenerateSchemaMapTypes(t *testing.T) {
320	t.Parallel()
321
322	tests := []struct {
323		name     string
324		input    any
325		expected string
326	}{
327		{
328			name:     "string to string map",
329			input:    map[string]string{},
330			expected: "object",
331		},
332		{
333			name:     "string to int map",
334			input:    map[string]int{},
335			expected: "object",
336		},
337		{
338			name:     "int to string map",
339			input:    map[int]string{},
340			expected: "object",
341		},
342	}
343
344	for _, tt := range tests {
345		t.Run(tt.name, func(t *testing.T) {
346			t.Parallel()
347			schema := generateSchema(reflect.TypeOf(tt.input))
348			if schema.Type != tt.expected {
349				t.Errorf("Expected type %s, got %s", tt.expected, schema.Type)
350			}
351		})
352	}
353}
354
355func TestGenerateSchemaStructTypes(t *testing.T) {
356	t.Parallel()
357
358	type SimpleStruct struct {
359		Name string `json:"name" description:"The name field"`
360		Age  int    `json:"age"`
361	}
362
363	type StructWithOmitEmpty struct {
364		Required string `json:"required"`
365		Optional string `json:"optional,omitempty"`
366	}
367
368	type StructWithJSONIgnore struct {
369		Visible string `json:"visible"`
370		Hidden  string `json:"-"`
371	}
372
373	type StructWithoutJSONTags struct {
374		FirstName string
375		LastName  string
376	}
377
378	tests := []struct {
379		name     string
380		input    any
381		validate func(t *testing.T, schema Schema)
382	}{
383		{
384			name:  "simple struct",
385			input: SimpleStruct{},
386			validate: func(t *testing.T, schema Schema) {
387				if schema.Type != "object" {
388					t.Errorf("Expected type object, got %s", schema.Type)
389				}
390				if len(schema.Properties) != 2 {
391					t.Errorf("Expected 2 properties, got %d", len(schema.Properties))
392				}
393				if schema.Properties["name"] == nil {
394					t.Error("Expected name property to exist")
395				}
396				if schema.Properties["name"].Description != "The name field" {
397					t.Errorf("Expected description 'The name field', got %s", schema.Properties["name"].Description)
398				}
399				if len(schema.Required) != 2 {
400					t.Errorf("Expected 2 required fields, got %d", len(schema.Required))
401				}
402			},
403		},
404		{
405			name:  "struct with omitempty",
406			input: StructWithOmitEmpty{},
407			validate: func(t *testing.T, schema Schema) {
408				if len(schema.Required) != 1 {
409					t.Errorf("Expected 1 required field, got %d", len(schema.Required))
410				}
411				if schema.Required[0] != "required" {
412					t.Errorf("Expected required field 'required', got %s", schema.Required[0])
413				}
414			},
415		},
416		{
417			name:  "struct with json ignore",
418			input: StructWithJSONIgnore{},
419			validate: func(t *testing.T, schema Schema) {
420				if len(schema.Properties) != 1 {
421					t.Errorf("Expected 1 property, got %d", len(schema.Properties))
422				}
423				if schema.Properties["visible"] == nil {
424					t.Error("Expected visible property to exist")
425				}
426				if schema.Properties["hidden"] != nil {
427					t.Error("Expected hidden property to not exist")
428				}
429			},
430		},
431		{
432			name:  "struct without json tags",
433			input: StructWithoutJSONTags{},
434			validate: func(t *testing.T, schema Schema) {
435				if schema.Properties["first_name"] == nil {
436					t.Error("Expected first_name property to exist")
437				}
438				if schema.Properties["last_name"] == nil {
439					t.Error("Expected last_name property to exist")
440				}
441			},
442		},
443	}
444
445	for _, tt := range tests {
446		t.Run(tt.name, func(t *testing.T) {
447			t.Parallel()
448			schema := generateSchema(reflect.TypeOf(tt.input))
449			tt.validate(t, schema)
450		})
451	}
452}
453
454func TestGenerateSchemaPointerTypes(t *testing.T) {
455	t.Parallel()
456
457	type StructWithPointers struct {
458		Name *string `json:"name"`
459		Age  *int    `json:"age"`
460	}
461
462	schema := generateSchema(reflect.TypeOf(StructWithPointers{}))
463
464	if schema.Type != "object" {
465		t.Errorf("Expected type object, got %s", schema.Type)
466	}
467
468	if schema.Properties["name"] == nil {
469		t.Fatal("Expected name property to exist")
470	}
471	if schema.Properties["name"].Type != "string" {
472		t.Errorf("Expected name type string, got %s", schema.Properties["name"].Type)
473	}
474
475	if schema.Properties["age"] == nil {
476		t.Fatal("Expected age property to exist")
477	}
478	if schema.Properties["age"].Type != "integer" {
479		t.Errorf("Expected age type integer, got %s", schema.Properties["age"].Type)
480	}
481}
482
483func TestGenerateSchemaNestedStructs(t *testing.T) {
484	t.Parallel()
485
486	type Address struct {
487		Street string `json:"street"`
488		City   string `json:"city"`
489	}
490
491	type Person struct {
492		Name    string  `json:"name"`
493		Address Address `json:"address"`
494	}
495
496	schema := generateSchema(reflect.TypeOf(Person{}))
497
498	if schema.Type != "object" {
499		t.Errorf("Expected type object, got %s", schema.Type)
500	}
501
502	if schema.Properties["address"] == nil {
503		t.Fatal("Expected address property to exist")
504	}
505
506	addressSchema := schema.Properties["address"]
507	if addressSchema.Type != "object" {
508		t.Errorf("Expected address type object, got %s", addressSchema.Type)
509	}
510
511	if addressSchema.Properties["street"] == nil {
512		t.Error("Expected street property in address to exist")
513	}
514	if addressSchema.Properties["city"] == nil {
515		t.Error("Expected city property in address to exist")
516	}
517}
518
519func TestGenerateSchemaRecursiveStructs(t *testing.T) {
520	t.Parallel()
521
522	type Node struct {
523		Value string `json:"value"`
524		Next  *Node  `json:"next,omitempty"`
525	}
526
527	schema := generateSchema(reflect.TypeOf(Node{}))
528
529	if schema.Type != "object" {
530		t.Errorf("Expected type object, got %s", schema.Type)
531	}
532
533	if schema.Properties["value"] == nil {
534		t.Error("Expected value property to exist")
535	}
536
537	if schema.Properties["next"] == nil {
538		t.Error("Expected next property to exist")
539	}
540
541	// The recursive reference should be handled gracefully
542	nextSchema := schema.Properties["next"]
543	if nextSchema.Type != "object" {
544		t.Errorf("Expected next type object, got %s", nextSchema.Type)
545	}
546}
547
548func TestGenerateSchemaWithEnumTags(t *testing.T) {
549	t.Parallel()
550
551	type ConfigInput struct {
552		Level    string `json:"level" enum:"debug,info,warn,error" description:"Log level"`
553		Format   string `json:"format" enum:"json,text"`
554		Optional string `json:"optional,omitempty" enum:"a,b,c"`
555	}
556
557	schema := generateSchema(reflect.TypeOf(ConfigInput{}))
558
559	// Check level field
560	levelSchema := schema.Properties["level"]
561	if levelSchema == nil {
562		t.Fatal("Expected level property to exist")
563	}
564	if len(levelSchema.Enum) != 4 {
565		t.Errorf("Expected 4 enum values for level, got %d", len(levelSchema.Enum))
566	}
567	expectedLevels := []string{"debug", "info", "warn", "error"}
568	for i, expected := range expectedLevels {
569		if levelSchema.Enum[i] != expected {
570			t.Errorf("Expected enum value %s, got %v", expected, levelSchema.Enum[i])
571		}
572	}
573
574	// Check format field
575	formatSchema := schema.Properties["format"]
576	if formatSchema == nil {
577		t.Fatal("Expected format property to exist")
578	}
579	if len(formatSchema.Enum) != 2 {
580		t.Errorf("Expected 2 enum values for format, got %d", len(formatSchema.Enum))
581	}
582
583	// Check required fields (optional should not be required due to omitempty)
584	expectedRequired := []string{"level", "format"}
585	if len(schema.Required) != len(expectedRequired) {
586		t.Errorf("Expected %d required fields, got %d", len(expectedRequired), len(schema.Required))
587	}
588}
589
590func TestGenerateSchemaComplexTypes(t *testing.T) {
591	t.Parallel()
592
593	type ComplexInput struct {
594		StringSlice []string            `json:"string_slice"`
595		IntMap      map[string]int      `json:"int_map"`
596		NestedSlice []map[string]string `json:"nested_slice"`
597		Interface   any                 `json:"interface"`
598	}
599
600	schema := generateSchema(reflect.TypeOf(ComplexInput{}))
601
602	// Check string slice
603	stringSliceSchema := schema.Properties["string_slice"]
604	if stringSliceSchema == nil {
605		t.Fatal("Expected string_slice property to exist")
606	}
607	if stringSliceSchema.Type != "array" {
608		t.Errorf("Expected string_slice type array, got %s", stringSliceSchema.Type)
609	}
610	if stringSliceSchema.Items.Type != "string" {
611		t.Errorf("Expected string_slice items type string, got %s", stringSliceSchema.Items.Type)
612	}
613
614	// Check int map
615	intMapSchema := schema.Properties["int_map"]
616	if intMapSchema == nil {
617		t.Fatal("Expected int_map property to exist")
618	}
619	if intMapSchema.Type != "object" {
620		t.Errorf("Expected int_map type object, got %s", intMapSchema.Type)
621	}
622
623	// Check nested slice
624	nestedSliceSchema := schema.Properties["nested_slice"]
625	if nestedSliceSchema == nil {
626		t.Fatal("Expected nested_slice property to exist")
627	}
628	if nestedSliceSchema.Type != "array" {
629		t.Errorf("Expected nested_slice type array, got %s", nestedSliceSchema.Type)
630	}
631	if nestedSliceSchema.Items.Type != "object" {
632		t.Errorf("Expected nested_slice items type object, got %s", nestedSliceSchema.Items.Type)
633	}
634
635	// Check interface
636	interfaceSchema := schema.Properties["interface"]
637	if interfaceSchema == nil {
638		t.Fatal("Expected interface property to exist")
639	}
640	if interfaceSchema.Type != "object" {
641		t.Errorf("Expected interface type object, got %s", interfaceSchema.Type)
642	}
643}
644
645func TestToSnakeCase(t *testing.T) {
646	t.Parallel()
647
648	tests := []struct {
649		input    string
650		expected string
651	}{
652		{"FirstName", "first_name"},
653		{"XMLHttpRequest", "x_m_l_http_request"},
654		{"ID", "i_d"},
655		{"HTTPSProxy", "h_t_t_p_s_proxy"},
656		{"simple", "simple"},
657		{"", ""},
658		{"A", "a"},
659		{"AB", "a_b"},
660		{"CamelCase", "camel_case"},
661	}
662
663	for _, tt := range tests {
664		t.Run(tt.input, func(t *testing.T) {
665			t.Parallel()
666			result := toSnakeCase(tt.input)
667			if result != tt.expected {
668				t.Errorf("toSnakeCase(%s) = %s, expected %s", tt.input, result, tt.expected)
669			}
670		})
671	}
672}
673
674func TestSchemaToParametersEdgeCases(t *testing.T) {
675	t.Parallel()
676
677	tests := []struct {
678		name     string
679		schema   Schema
680		expected map[string]any
681	}{
682		{
683			name: "non-object schema",
684			schema: Schema{
685				Type: "string",
686			},
687			expected: map[string]any{},
688		},
689		{
690			name: "object with no properties",
691			schema: Schema{
692				Type:       "object",
693				Properties: nil,
694			},
695			expected: map[string]any{},
696		},
697		{
698			name: "object with empty properties",
699			schema: Schema{
700				Type:       "object",
701				Properties: map[string]*Schema{},
702			},
703			expected: map[string]any{},
704		},
705		{
706			name: "schema with all constraint types",
707			schema: Schema{
708				Type: "object",
709				Properties: map[string]*Schema{
710					"text": {
711						Type:      "string",
712						Format:    "email",
713						MinLength: func() *int { v := 5; return &v }(),
714						MaxLength: func() *int { v := 100; return &v }(),
715					},
716					"number": {
717						Type:    "number",
718						Minimum: func() *float64 { v := 0.0; return &v }(),
719						Maximum: func() *float64 { v := 100.0; return &v }(),
720					},
721				},
722			},
723			expected: map[string]any{
724				"text": map[string]any{
725					"type":      "string",
726					"format":    "email",
727					"minLength": 5,
728					"maxLength": 100,
729				},
730				"number": map[string]any{
731					"type":    "number",
732					"minimum": 0.0,
733					"maximum": 100.0,
734				},
735			},
736		},
737	}
738
739	for _, tt := range tests {
740		t.Run(tt.name, func(t *testing.T) {
741			t.Parallel()
742			result := schemaToParameters(tt.schema)
743			if len(result) != len(tt.expected) {
744				t.Errorf("Expected %d parameters, got %d", len(tt.expected), len(result))
745			}
746			for key, expectedValue := range tt.expected {
747				if result[key] == nil {
748					t.Errorf("Expected parameter %s to exist", key)
749					continue
750				}
751				// Deep comparison would be complex, so we'll check key properties
752				resultParam := result[key].(map[string]any)
753				expectedParam := expectedValue.(map[string]any)
754				for propKey, propValue := range expectedParam {
755					if resultParam[propKey] != propValue {
756						t.Errorf("Expected %s.%s = %v, got %v", key, propKey, propValue, resultParam[propKey])
757					}
758				}
759			}
760		})
761	}
762}