From 9508dfdece6e81946110bb0ea3fccfce5ce35edf Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Sat, 23 Aug 2025 19:56:33 +0200 Subject: [PATCH] chore: test schema generation --- internal/ai/tool_test.go | 552 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 552 insertions(+) diff --git a/internal/ai/tool_test.go b/internal/ai/tool_test.go index 8c9caf6389a138358402999bcfa636a5863c389f..3413e7040fa26d55f8505ca8136a38998ee17cfb 100644 --- a/internal/ai/tool_test.go +++ b/internal/ai/tool_test.go @@ -208,3 +208,555 @@ func TestSchemaToParameters(t *testing.T) { t.Errorf("Expected 3 enum values, got %v", enumValues) } } + +func TestGenerateSchemaBasicTypes(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input any + expected Schema + }{ + { + name: "string type", + input: "", + expected: Schema{Type: "string"}, + }, + { + name: "int type", + input: 0, + expected: Schema{Type: "integer"}, + }, + { + name: "int64 type", + input: int64(0), + expected: Schema{Type: "integer"}, + }, + { + name: "uint type", + input: uint(0), + expected: Schema{Type: "integer"}, + }, + { + name: "float64 type", + input: 0.0, + expected: Schema{Type: "number"}, + }, + { + name: "float32 type", + input: float32(0.0), + expected: Schema{Type: "number"}, + }, + { + name: "bool type", + input: false, + expected: Schema{Type: "boolean"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + schema := generateSchema(reflect.TypeOf(tt.input)) + if schema.Type != tt.expected.Type { + t.Errorf("Expected type %s, got %s", tt.expected.Type, schema.Type) + } + }) + } +} + +func TestGenerateSchemaArrayTypes(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input any + expected Schema + }{ + { + name: "string slice", + input: []string{}, + expected: Schema{ + Type: "array", + Items: &Schema{Type: "string"}, + }, + }, + { + name: "int slice", + input: []int{}, + expected: Schema{ + Type: "array", + Items: &Schema{Type: "integer"}, + }, + }, + { + name: "string array", + input: [3]string{}, + expected: Schema{ + Type: "array", + Items: &Schema{Type: "string"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + schema := generateSchema(reflect.TypeOf(tt.input)) + if schema.Type != tt.expected.Type { + t.Errorf("Expected type %s, got %s", tt.expected.Type, schema.Type) + } + if schema.Items == nil { + t.Fatal("Expected items schema to exist") + } + if schema.Items.Type != tt.expected.Items.Type { + t.Errorf("Expected items type %s, got %s", tt.expected.Items.Type, schema.Items.Type) + } + }) + } +} + +func TestGenerateSchemaMapTypes(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input any + expected string + }{ + { + name: "string to string map", + input: map[string]string{}, + expected: "object", + }, + { + name: "string to int map", + input: map[string]int{}, + expected: "object", + }, + { + name: "int to string map", + input: map[int]string{}, + expected: "object", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + schema := generateSchema(reflect.TypeOf(tt.input)) + if schema.Type != tt.expected { + t.Errorf("Expected type %s, got %s", tt.expected, schema.Type) + } + }) + } +} + +func TestGenerateSchemaStructTypes(t *testing.T) { + t.Parallel() + + type SimpleStruct struct { + Name string `json:"name" description:"The name field"` + Age int `json:"age"` + } + + type StructWithOmitEmpty struct { + Required string `json:"required"` + Optional string `json:"optional,omitempty"` + } + + type StructWithJSONIgnore struct { + Visible string `json:"visible"` + Hidden string `json:"-"` + } + + type StructWithoutJSONTags struct { + FirstName string + LastName string + } + + tests := []struct { + name string + input any + validate func(t *testing.T, schema Schema) + }{ + { + name: "simple struct", + input: SimpleStruct{}, + validate: func(t *testing.T, schema Schema) { + if schema.Type != "object" { + t.Errorf("Expected type object, got %s", schema.Type) + } + if len(schema.Properties) != 2 { + t.Errorf("Expected 2 properties, got %d", len(schema.Properties)) + } + if schema.Properties["name"] == nil { + t.Error("Expected name property to exist") + } + if schema.Properties["name"].Description != "The name field" { + t.Errorf("Expected description 'The name field', got %s", schema.Properties["name"].Description) + } + if len(schema.Required) != 2 { + t.Errorf("Expected 2 required fields, got %d", len(schema.Required)) + } + }, + }, + { + name: "struct with omitempty", + input: StructWithOmitEmpty{}, + validate: func(t *testing.T, schema Schema) { + if len(schema.Required) != 1 { + t.Errorf("Expected 1 required field, got %d", len(schema.Required)) + } + if schema.Required[0] != "required" { + t.Errorf("Expected required field 'required', got %s", schema.Required[0]) + } + }, + }, + { + name: "struct with json ignore", + input: StructWithJSONIgnore{}, + validate: func(t *testing.T, schema Schema) { + if len(schema.Properties) != 1 { + t.Errorf("Expected 1 property, got %d", len(schema.Properties)) + } + if schema.Properties["visible"] == nil { + t.Error("Expected visible property to exist") + } + if schema.Properties["hidden"] != nil { + t.Error("Expected hidden property to not exist") + } + }, + }, + { + name: "struct without json tags", + input: StructWithoutJSONTags{}, + validate: func(t *testing.T, schema Schema) { + if schema.Properties["first_name"] == nil { + t.Error("Expected first_name property to exist") + } + if schema.Properties["last_name"] == nil { + t.Error("Expected last_name property to exist") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + schema := generateSchema(reflect.TypeOf(tt.input)) + tt.validate(t, schema) + }) + } +} + +func TestGenerateSchemaPointerTypes(t *testing.T) { + t.Parallel() + + type StructWithPointers struct { + Name *string `json:"name"` + Age *int `json:"age"` + } + + schema := generateSchema(reflect.TypeOf(StructWithPointers{})) + + if schema.Type != "object" { + t.Errorf("Expected type object, got %s", schema.Type) + } + + if schema.Properties["name"] == nil { + t.Fatal("Expected name property to exist") + } + if schema.Properties["name"].Type != "string" { + t.Errorf("Expected name type string, got %s", schema.Properties["name"].Type) + } + + if schema.Properties["age"] == nil { + t.Fatal("Expected age property to exist") + } + if schema.Properties["age"].Type != "integer" { + t.Errorf("Expected age type integer, got %s", schema.Properties["age"].Type) + } +} + +func TestGenerateSchemaNestedStructs(t *testing.T) { + t.Parallel() + + type Address struct { + Street string `json:"street"` + City string `json:"city"` + } + + type Person struct { + Name string `json:"name"` + Address Address `json:"address"` + } + + schema := generateSchema(reflect.TypeOf(Person{})) + + if schema.Type != "object" { + t.Errorf("Expected type object, got %s", schema.Type) + } + + if schema.Properties["address"] == nil { + t.Fatal("Expected address property to exist") + } + + addressSchema := schema.Properties["address"] + if addressSchema.Type != "object" { + t.Errorf("Expected address type object, got %s", addressSchema.Type) + } + + if addressSchema.Properties["street"] == nil { + t.Error("Expected street property in address to exist") + } + if addressSchema.Properties["city"] == nil { + t.Error("Expected city property in address to exist") + } +} + +func TestGenerateSchemaRecursiveStructs(t *testing.T) { + t.Parallel() + + type Node struct { + Value string `json:"value"` + Next *Node `json:"next,omitempty"` + } + + schema := generateSchema(reflect.TypeOf(Node{})) + + if schema.Type != "object" { + t.Errorf("Expected type object, got %s", schema.Type) + } + + if schema.Properties["value"] == nil { + t.Error("Expected value property to exist") + } + + if schema.Properties["next"] == nil { + t.Error("Expected next property to exist") + } + + // The recursive reference should be handled gracefully + nextSchema := schema.Properties["next"] + if nextSchema.Type != "object" { + t.Errorf("Expected next type object, got %s", nextSchema.Type) + } +} + +func TestGenerateSchemaWithEnumTags(t *testing.T) { + t.Parallel() + + type ConfigInput struct { + Level string `json:"level" enum:"debug,info,warn,error" description:"Log level"` + Format string `json:"format" enum:"json,text"` + Optional string `json:"optional,omitempty" enum:"a,b,c"` + } + + schema := generateSchema(reflect.TypeOf(ConfigInput{})) + + // Check level field + levelSchema := schema.Properties["level"] + if levelSchema == nil { + t.Fatal("Expected level property to exist") + } + if len(levelSchema.Enum) != 4 { + t.Errorf("Expected 4 enum values for level, got %d", len(levelSchema.Enum)) + } + expectedLevels := []string{"debug", "info", "warn", "error"} + for i, expected := range expectedLevels { + if levelSchema.Enum[i] != expected { + t.Errorf("Expected enum value %s, got %v", expected, levelSchema.Enum[i]) + } + } + + // Check format field + formatSchema := schema.Properties["format"] + if formatSchema == nil { + t.Fatal("Expected format property to exist") + } + if len(formatSchema.Enum) != 2 { + t.Errorf("Expected 2 enum values for format, got %d", len(formatSchema.Enum)) + } + + // Check required fields (optional should not be required due to omitempty) + expectedRequired := []string{"level", "format"} + if len(schema.Required) != len(expectedRequired) { + t.Errorf("Expected %d required fields, got %d", len(expectedRequired), len(schema.Required)) + } +} + +func TestGenerateSchemaComplexTypes(t *testing.T) { + t.Parallel() + + type ComplexInput struct { + StringSlice []string `json:"string_slice"` + IntMap map[string]int `json:"int_map"` + NestedSlice []map[string]string `json:"nested_slice"` + Interface any `json:"interface"` + } + + schema := generateSchema(reflect.TypeOf(ComplexInput{})) + + // Check string slice + stringSliceSchema := schema.Properties["string_slice"] + if stringSliceSchema == nil { + t.Fatal("Expected string_slice property to exist") + } + if stringSliceSchema.Type != "array" { + t.Errorf("Expected string_slice type array, got %s", stringSliceSchema.Type) + } + if stringSliceSchema.Items.Type != "string" { + t.Errorf("Expected string_slice items type string, got %s", stringSliceSchema.Items.Type) + } + + // Check int map + intMapSchema := schema.Properties["int_map"] + if intMapSchema == nil { + t.Fatal("Expected int_map property to exist") + } + if intMapSchema.Type != "object" { + t.Errorf("Expected int_map type object, got %s", intMapSchema.Type) + } + + // Check nested slice + nestedSliceSchema := schema.Properties["nested_slice"] + if nestedSliceSchema == nil { + t.Fatal("Expected nested_slice property to exist") + } + if nestedSliceSchema.Type != "array" { + t.Errorf("Expected nested_slice type array, got %s", nestedSliceSchema.Type) + } + if nestedSliceSchema.Items.Type != "object" { + t.Errorf("Expected nested_slice items type object, got %s", nestedSliceSchema.Items.Type) + } + + // Check interface + interfaceSchema := schema.Properties["interface"] + if interfaceSchema == nil { + t.Fatal("Expected interface property to exist") + } + if interfaceSchema.Type != "object" { + t.Errorf("Expected interface type object, got %s", interfaceSchema.Type) + } +} + +func TestToSnakeCase(t *testing.T) { + t.Parallel() + + tests := []struct { + input string + expected string + }{ + {"FirstName", "first_name"}, + {"XMLHttpRequest", "x_m_l_http_request"}, + {"ID", "i_d"}, + {"HTTPSProxy", "h_t_t_p_s_proxy"}, + {"simple", "simple"}, + {"", ""}, + {"A", "a"}, + {"AB", "a_b"}, + {"CamelCase", "camel_case"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + t.Parallel() + result := toSnakeCase(tt.input) + if result != tt.expected { + t.Errorf("toSnakeCase(%s) = %s, expected %s", tt.input, result, tt.expected) + } + }) + } +} + +func TestSchemaToParametersEdgeCases(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + schema Schema + expected map[string]any + }{ + { + name: "non-object schema", + schema: Schema{ + Type: "string", + }, + expected: map[string]any{}, + }, + { + name: "object with no properties", + schema: Schema{ + Type: "object", + Properties: nil, + }, + expected: map[string]any{}, + }, + { + name: "object with empty properties", + schema: Schema{ + Type: "object", + Properties: map[string]*Schema{}, + }, + expected: map[string]any{}, + }, + { + name: "schema with all constraint types", + schema: Schema{ + Type: "object", + Properties: map[string]*Schema{ + "text": { + Type: "string", + Format: "email", + MinLength: func() *int { v := 5; return &v }(), + MaxLength: func() *int { v := 100; return &v }(), + }, + "number": { + Type: "number", + Minimum: func() *float64 { v := 0.0; return &v }(), + Maximum: func() *float64 { v := 100.0; return &v }(), + }, + }, + }, + expected: map[string]any{ + "text": map[string]any{ + "type": "string", + "format": "email", + "minLength": 5, + "maxLength": 100, + }, + "number": map[string]any{ + "type": "number", + "minimum": 0.0, + "maximum": 100.0, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := schemaToParameters(tt.schema) + if len(result) != len(tt.expected) { + t.Errorf("Expected %d parameters, got %d", len(tt.expected), len(result)) + } + for key, expectedValue := range tt.expected { + if result[key] == nil { + t.Errorf("Expected parameter %s to exist", key) + continue + } + // Deep comparison would be complex, so we'll check key properties + resultParam := result[key].(map[string]any) + expectedParam := expectedValue.(map[string]any) + for propKey, propValue := range expectedParam { + if resultParam[propKey] != propValue { + t.Errorf("Expected %s.%s = %v, got %v", key, propKey, propValue, resultParam[propKey]) + } + } + } + }) + } +}