tool_test.go

  1package fantasy
  2
  3import (
  4	"context"
  5	"fmt"
  6	"reflect"
  7	"testing"
  8
  9	"github.com/stretchr/testify/require"
 10)
 11
 12// Example of a simple typed tool using the function approach
 13type CalculatorInput struct {
 14	Expression string `json:"expression" description:"Mathematical expression to evaluate"`
 15}
 16
 17func TestTypedToolFuncExample(t *testing.T) {
 18	// Create a typed tool using the function API
 19	tool := NewAgentTool(
 20		"calculator",
 21		"Evaluates simple mathematical expressions",
 22		func(ctx context.Context, input CalculatorInput, _ ToolCall) (ToolResponse, error) {
 23			if input.Expression == "2+2" {
 24				return NewTextResponse("4"), nil
 25			}
 26			return NewTextErrorResponse("unsupported expression"), nil
 27		},
 28	)
 29
 30	// Check the tool info
 31	info := tool.Info()
 32	require.Equal(t, "calculator", info.Name)
 33	require.Len(t, info.Required, 1)
 34	require.Equal(t, "expression", info.Required[0])
 35
 36	// Test execution
 37	call := ToolCall{
 38		ID:    "test-1",
 39		Name:  "calculator",
 40		Input: `{"expression": "2+2"}`,
 41	}
 42
 43	result, err := tool.Run(context.Background(), call)
 44	require.NoError(t, err)
 45	require.Equal(t, "4", result.Content)
 46	require.False(t, result.IsError)
 47}
 48
 49func TestEnumToolExample(t *testing.T) {
 50	type WeatherInput struct {
 51		Location string `json:"location" description:"City name"`
 52		Units    string `json:"units" enum:"celsius,fahrenheit" description:"Temperature units"`
 53	}
 54
 55	// Create a weather tool with enum support
 56	tool := NewAgentTool(
 57		"weather",
 58		"Gets current weather for a location",
 59		func(ctx context.Context, input WeatherInput, _ ToolCall) (ToolResponse, error) {
 60			temp := "22°C"
 61			if input.Units == "fahrenheit" {
 62				temp = "72°F"
 63			}
 64			return NewTextResponse(fmt.Sprintf("Weather in %s: %s, sunny", input.Location, temp)), nil
 65		},
 66	)
 67
 68	// Check that the schema includes enum values
 69	info := tool.Info()
 70	unitsParam, ok := info.Parameters["units"].(map[string]any)
 71	require.True(t, ok, "Expected units parameter to exist")
 72	enumValues, ok := unitsParam["enum"].([]any)
 73	require.True(t, ok)
 74	require.Len(t, enumValues, 2)
 75
 76	// Test execution with enum value
 77	call := ToolCall{
 78		ID:    "test-2",
 79		Name:  "weather",
 80		Input: `{"location": "San Francisco", "units": "fahrenheit"}`,
 81	}
 82
 83	result, err := tool.Run(context.Background(), call)
 84	require.NoError(t, err)
 85	require.Contains(t, result.Content, "San Francisco")
 86	require.Contains(t, result.Content, "72°F")
 87}
 88
 89func TestEnumSupport(t *testing.T) {
 90	// Test enum via struct tags
 91	type WeatherInput struct {
 92		Location string `json:"location" description:"City name"`
 93		Units    string `json:"units" enum:"celsius,fahrenheit,kelvin" description:"Temperature units"`
 94		Format   string `json:"format,omitempty" enum:"json,xml,text"`
 95	}
 96
 97	schema := generateSchema(reflect.TypeOf(WeatherInput{}))
 98
 99	require.Equal(t, "object", schema.Type)
100
101	// Check units field has enum values
102	unitsSchema := schema.Properties["units"]
103	require.NotNil(t, unitsSchema, "Expected units property to exist")
104	require.Len(t, unitsSchema.Enum, 3)
105	expectedUnits := []string{"celsius", "fahrenheit", "kelvin"}
106	for i, expected := range expectedUnits {
107		require.Equal(t, expected, unitsSchema.Enum[i])
108	}
109
110	// Check required fields (format should not be required due to omitempty)
111	expectedRequired := []string{"location", "units"}
112	require.Len(t, schema.Required, len(expectedRequired))
113}
114
115func TestSchemaToParameters(t *testing.T) {
116	schema := Schema{
117		Type: "object",
118		Properties: map[string]*Schema{
119			"name": {
120				Type:        "string",
121				Description: "The name field",
122			},
123			"age": {
124				Type:    "integer",
125				Minimum: func() *float64 { v := 0.0; return &v }(),
126				Maximum: func() *float64 { v := 120.0; return &v }(),
127			},
128			"tags": {
129				Type: "array",
130				Items: &Schema{
131					Type: "string",
132				},
133			},
134			"priority": {
135				Type: "string",
136				Enum: []any{"low", "medium", "high"},
137			},
138		},
139		Required: []string{"name"},
140	}
141
142	params := schemaToParameters(schema)
143
144	// Check name parameter
145	nameParam, ok := params["name"].(map[string]any)
146	require.True(t, ok, "Expected name parameter to exist")
147	require.Equal(t, "string", nameParam["type"])
148	require.Equal(t, "The name field", nameParam["description"])
149
150	// Check age parameter with min/max
151	ageParam, ok := params["age"].(map[string]any)
152	require.True(t, ok, "Expected age parameter to exist")
153	require.Equal(t, "integer", ageParam["type"])
154	require.Equal(t, 0.0, ageParam["minimum"])
155	require.Equal(t, 120.0, ageParam["maximum"])
156
157	// Check priority parameter with enum
158	priorityParam, ok := params["priority"].(map[string]any)
159	require.True(t, ok, "Expected priority parameter to exist")
160	require.Equal(t, "string", priorityParam["type"])
161	enumValues, ok := priorityParam["enum"].([]any)
162	require.True(t, ok)
163	require.Len(t, enumValues, 3)
164}
165
166func TestGenerateSchemaBasicTypes(t *testing.T) {
167	t.Parallel()
168
169	tests := []struct {
170		name     string
171		input    any
172		expected Schema
173	}{
174		{
175			name:     "string type",
176			input:    "",
177			expected: Schema{Type: "string"},
178		},
179		{
180			name:     "int type",
181			input:    0,
182			expected: Schema{Type: "integer"},
183		},
184		{
185			name:     "int64 type",
186			input:    int64(0),
187			expected: Schema{Type: "integer"},
188		},
189		{
190			name:     "uint type",
191			input:    uint(0),
192			expected: Schema{Type: "integer"},
193		},
194		{
195			name:     "float64 type",
196			input:    0.0,
197			expected: Schema{Type: "number"},
198		},
199		{
200			name:     "float32 type",
201			input:    float32(0.0),
202			expected: Schema{Type: "number"},
203		},
204		{
205			name:     "bool type",
206			input:    false,
207			expected: Schema{Type: "boolean"},
208		},
209	}
210
211	for _, tt := range tests {
212		t.Run(tt.name, func(t *testing.T) {
213			t.Parallel()
214			schema := generateSchema(reflect.TypeOf(tt.input))
215			require.Equal(t, tt.expected.Type, schema.Type)
216		})
217	}
218}
219
220func TestGenerateSchemaArrayTypes(t *testing.T) {
221	t.Parallel()
222
223	tests := []struct {
224		name     string
225		input    any
226		expected Schema
227	}{
228		{
229			name:  "string slice",
230			input: []string{},
231			expected: Schema{
232				Type:  "array",
233				Items: &Schema{Type: "string"},
234			},
235		},
236		{
237			name:  "int slice",
238			input: []int{},
239			expected: Schema{
240				Type:  "array",
241				Items: &Schema{Type: "integer"},
242			},
243		},
244		{
245			name:  "string array",
246			input: [3]string{},
247			expected: Schema{
248				Type:  "array",
249				Items: &Schema{Type: "string"},
250			},
251		},
252	}
253
254	for _, tt := range tests {
255		t.Run(tt.name, func(t *testing.T) {
256			t.Parallel()
257			schema := generateSchema(reflect.TypeOf(tt.input))
258			require.Equal(t, tt.expected.Type, schema.Type)
259			require.NotNil(t, schema.Items, "Expected items schema to exist")
260			require.Equal(t, tt.expected.Items.Type, schema.Items.Type)
261		})
262	}
263}
264
265func TestGenerateSchemaMapTypes(t *testing.T) {
266	t.Parallel()
267
268	tests := []struct {
269		name     string
270		input    any
271		expected string
272	}{
273		{
274			name:     "string to string map",
275			input:    map[string]string{},
276			expected: "object",
277		},
278		{
279			name:     "string to int map",
280			input:    map[string]int{},
281			expected: "object",
282		},
283		{
284			name:     "int to string map",
285			input:    map[int]string{},
286			expected: "object",
287		},
288	}
289
290	for _, tt := range tests {
291		t.Run(tt.name, func(t *testing.T) {
292			t.Parallel()
293			schema := generateSchema(reflect.TypeOf(tt.input))
294			require.Equal(t, tt.expected, schema.Type)
295		})
296	}
297}
298
299func TestGenerateSchemaStructTypes(t *testing.T) {
300	t.Parallel()
301
302	type SimpleStruct struct {
303		Name string `json:"name" description:"The name field"`
304		Age  int    `json:"age"`
305	}
306
307	type StructWithOmitEmpty struct {
308		Required string `json:"required"`
309		Optional string `json:"optional,omitempty"`
310	}
311
312	type StructWithJSONIgnore struct {
313		Visible string `json:"visible"`
314		Hidden  string `json:"-"`
315	}
316
317	type StructWithoutJSONTags struct {
318		FirstName string
319		LastName  string
320	}
321
322	tests := []struct {
323		name     string
324		input    any
325		validate func(t *testing.T, schema Schema)
326	}{
327		{
328			name:  "simple struct",
329			input: SimpleStruct{},
330			validate: func(t *testing.T, schema Schema) {
331				require.Equal(t, "object", schema.Type)
332				require.Len(t, schema.Properties, 2)
333				require.NotNil(t, schema.Properties["name"], "Expected name property to exist")
334				require.Equal(t, "The name field", schema.Properties["name"].Description)
335				require.Len(t, schema.Required, 2)
336			},
337		},
338		{
339			name:  "struct with omitempty",
340			input: StructWithOmitEmpty{},
341			validate: func(t *testing.T, schema Schema) {
342				require.Len(t, schema.Required, 1)
343				require.Equal(t, "required", schema.Required[0])
344			},
345		},
346		{
347			name:  "struct with json ignore",
348			input: StructWithJSONIgnore{},
349			validate: func(t *testing.T, schema Schema) {
350				require.Len(t, schema.Properties, 1)
351				require.NotNil(t, schema.Properties["visible"], "Expected visible property to exist")
352				require.Nil(t, schema.Properties["hidden"], "Expected hidden property to not exist")
353			},
354		},
355		{
356			name:  "struct without json tags",
357			input: StructWithoutJSONTags{},
358			validate: func(t *testing.T, schema Schema) {
359				require.NotNil(t, schema.Properties["first_name"], "Expected first_name property to exist")
360				require.NotNil(t, schema.Properties["last_name"], "Expected last_name property to exist")
361			},
362		},
363	}
364
365	for _, tt := range tests {
366		t.Run(tt.name, func(t *testing.T) {
367			t.Parallel()
368			schema := generateSchema(reflect.TypeOf(tt.input))
369			tt.validate(t, schema)
370		})
371	}
372}
373
374func TestGenerateSchemaPointerTypes(t *testing.T) {
375	t.Parallel()
376
377	type StructWithPointers struct {
378		Name *string `json:"name"`
379		Age  *int    `json:"age"`
380	}
381
382	schema := generateSchema(reflect.TypeOf(StructWithPointers{}))
383
384	require.Equal(t, "object", schema.Type)
385
386	require.NotNil(t, schema.Properties["name"], "Expected name property to exist")
387	require.Equal(t, "string", schema.Properties["name"].Type)
388
389	require.NotNil(t, schema.Properties["age"], "Expected age property to exist")
390	require.Equal(t, "integer", schema.Properties["age"].Type)
391}
392
393func TestGenerateSchemaNestedStructs(t *testing.T) {
394	t.Parallel()
395
396	type Address struct {
397		Street string `json:"street"`
398		City   string `json:"city"`
399	}
400
401	type Person struct {
402		Name    string  `json:"name"`
403		Address Address `json:"address"`
404	}
405
406	schema := generateSchema(reflect.TypeOf(Person{}))
407
408	require.Equal(t, "object", schema.Type)
409
410	require.NotNil(t, schema.Properties["address"], "Expected address property to exist")
411
412	addressSchema := schema.Properties["address"]
413	require.Equal(t, "object", addressSchema.Type)
414
415	require.NotNil(t, addressSchema.Properties["street"], "Expected street property in address to exist")
416	require.NotNil(t, addressSchema.Properties["city"], "Expected city property in address to exist")
417}
418
419func TestGenerateSchemaRecursiveStructs(t *testing.T) {
420	t.Parallel()
421
422	type Node struct {
423		Value string `json:"value"`
424		Next  *Node  `json:"next,omitempty"`
425	}
426
427	schema := generateSchema(reflect.TypeOf(Node{}))
428
429	require.Equal(t, "object", schema.Type)
430
431	require.NotNil(t, schema.Properties["value"], "Expected value property to exist")
432
433	require.NotNil(t, schema.Properties["next"], "Expected next property to exist")
434
435	// The recursive reference should be handled gracefully
436	nextSchema := schema.Properties["next"]
437	require.Equal(t, "object", nextSchema.Type)
438}
439
440func TestGenerateSchemaWithEnumTags(t *testing.T) {
441	t.Parallel()
442
443	type ConfigInput struct {
444		Level    string `json:"level" enum:"debug,info,warn,error" description:"Log level"`
445		Format   string `json:"format" enum:"json,text"`
446		Optional string `json:"optional,omitempty" enum:"a,b,c"`
447	}
448
449	schema := generateSchema(reflect.TypeOf(ConfigInput{}))
450
451	// Check level field
452	levelSchema := schema.Properties["level"]
453	require.NotNil(t, levelSchema, "Expected level property to exist")
454	require.Len(t, levelSchema.Enum, 4)
455	expectedLevels := []string{"debug", "info", "warn", "error"}
456	for i, expected := range expectedLevels {
457		require.Equal(t, expected, levelSchema.Enum[i])
458	}
459
460	// Check format field
461	formatSchema := schema.Properties["format"]
462	require.NotNil(t, formatSchema, "Expected format property to exist")
463	require.Len(t, formatSchema.Enum, 2)
464
465	// Check required fields (optional should not be required due to omitempty)
466	expectedRequired := []string{"level", "format"}
467	require.Len(t, schema.Required, len(expectedRequired))
468}
469
470func TestGenerateSchemaComplexTypes(t *testing.T) {
471	t.Parallel()
472
473	type ComplexInput struct {
474		StringSlice []string            `json:"string_slice"`
475		IntMap      map[string]int      `json:"int_map"`
476		NestedSlice []map[string]string `json:"nested_slice"`
477		Interface   any                 `json:"interface"`
478	}
479
480	schema := generateSchema(reflect.TypeOf(ComplexInput{}))
481
482	// Check string slice
483	stringSliceSchema := schema.Properties["string_slice"]
484	require.NotNil(t, stringSliceSchema, "Expected string_slice property to exist")
485	require.Equal(t, "array", stringSliceSchema.Type)
486	require.Equal(t, "string", stringSliceSchema.Items.Type)
487
488	// Check int map
489	intMapSchema := schema.Properties["int_map"]
490	require.NotNil(t, intMapSchema, "Expected int_map property to exist")
491	require.Equal(t, "object", intMapSchema.Type)
492
493	// Check nested slice
494	nestedSliceSchema := schema.Properties["nested_slice"]
495	require.NotNil(t, nestedSliceSchema, "Expected nested_slice property to exist")
496	require.Equal(t, "array", nestedSliceSchema.Type)
497	require.Equal(t, "object", nestedSliceSchema.Items.Type)
498
499	// Check interface
500	interfaceSchema := schema.Properties["interface"]
501	require.NotNil(t, interfaceSchema, "Expected interface property to exist")
502	require.Equal(t, "object", interfaceSchema.Type)
503}
504
505func TestToSnakeCase(t *testing.T) {
506	t.Parallel()
507
508	tests := []struct {
509		input    string
510		expected string
511	}{
512		{"FirstName", "first_name"},
513		{"XMLHttpRequest", "x_m_l_http_request"},
514		{"ID", "i_d"},
515		{"HTTPSProxy", "h_t_t_p_s_proxy"},
516		{"simple", "simple"},
517		{"", ""},
518		{"A", "a"},
519		{"AB", "a_b"},
520		{"CamelCase", "camel_case"},
521	}
522
523	for _, tt := range tests {
524		t.Run(tt.input, func(t *testing.T) {
525			t.Parallel()
526			result := toSnakeCase(tt.input)
527			require.Equal(t, tt.expected, result, "toSnakeCase(%s)", tt.input)
528		})
529	}
530}
531
532func TestSchemaToParametersEdgeCases(t *testing.T) {
533	t.Parallel()
534
535	tests := []struct {
536		name     string
537		schema   Schema
538		expected map[string]any
539	}{
540		{
541			name: "non-object schema",
542			schema: Schema{
543				Type: "string",
544			},
545			expected: map[string]any{},
546		},
547		{
548			name: "object with no properties",
549			schema: Schema{
550				Type:       "object",
551				Properties: nil,
552			},
553			expected: map[string]any{},
554		},
555		{
556			name: "object with empty properties",
557			schema: Schema{
558				Type:       "object",
559				Properties: map[string]*Schema{},
560			},
561			expected: map[string]any{},
562		},
563		{
564			name: "schema with all constraint types",
565			schema: Schema{
566				Type: "object",
567				Properties: map[string]*Schema{
568					"text": {
569						Type:      "string",
570						Format:    "email",
571						MinLength: func() *int { v := 5; return &v }(),
572						MaxLength: func() *int { v := 100; return &v }(),
573					},
574					"number": {
575						Type:    "number",
576						Minimum: func() *float64 { v := 0.0; return &v }(),
577						Maximum: func() *float64 { v := 100.0; return &v }(),
578					},
579				},
580			},
581			expected: map[string]any{
582				"text": map[string]any{
583					"type":      "string",
584					"format":    "email",
585					"minLength": 5,
586					"maxLength": 100,
587				},
588				"number": map[string]any{
589					"type":    "number",
590					"minimum": 0.0,
591					"maximum": 100.0,
592				},
593			},
594		},
595	}
596
597	for _, tt := range tests {
598		t.Run(tt.name, func(t *testing.T) {
599			t.Parallel()
600			result := schemaToParameters(tt.schema)
601			require.Len(t, result, len(tt.expected))
602			for key, expectedValue := range tt.expected {
603				require.NotNil(t, result[key], "Expected parameter %s to exist", key)
604				// Deep comparison would be complex, so we'll check key properties
605				resultParam := result[key].(map[string]any)
606				expectedParam := expectedValue.(map[string]any)
607				for propKey, propValue := range expectedParam {
608					require.Equal(t, propValue, resultParam[propKey], "Expected %s.%s", key, propKey)
609				}
610			}
611		})
612	}
613}