1// Package schema provides JSON schema generation and validation utilities.
2// It supports automatic schema generation from Go types and validation of parsed objects.
3package schema
4
5import (
6 "context"
7 "encoding/json"
8 "fmt"
9 "reflect"
10 "slices"
11 "strings"
12
13 jsonrepair "github.com/RealAlexandreAI/json-repair"
14 "github.com/kaptinlin/jsonschema"
15)
16
// ObjectRepairFunc is a caller-supplied hook that ParseAndValidateWithRepair
// invokes to salvage text that could not be parsed or did not validate.
//
// It is given the rejected text together with the error that caused the
// rejection (which may be nil when the rejected text was empty). It should
// return replacement text to try again, or a non-nil error to give up.
type ObjectRepairFunc func(ctx context.Context, text string, err error) (string, error)
21
// ParseError explains why text could not be turned into a schema-conforming
// object. ParseAndValidate and ParseAndValidateWithRepair return it.
type ParseError struct {
	// RawText is the input text associated with the failure.
	RawText string
	// ParseError is the JSON parsing failure, if parsing failed.
	ParseError error
	// ValidationError is the schema validation failure, if validation failed.
	ValidationError error
}
29
// Schema is a subset of JSON Schema used to describe and validate tool input.
// Field order and json tags define the document produced by json.Marshal.
type Schema struct {
	// Type is the JSON type keyword ("object", "string", "array", ...).
	Type string `json:"type,omitempty"`
	// Properties describes the members of an object, keyed by member name.
	Properties map[string]*Schema `json:"properties,omitempty"`
	// Required lists the property names that must be present.
	Required []string `json:"required,omitempty"`
	// Items describes the elements of an array.
	Items *Schema `json:"items,omitempty"`
	// Description is human-readable documentation for the value.
	Description string `json:"description,omitempty"`
	// Enum restricts the value to one of the listed constants.
	Enum []any `json:"enum,omitempty"`
	// Format is the JSON Schema format hint.
	Format string `json:"format,omitempty"`
	// Minimum and Maximum bound numeric values; nil means unbounded.
	Minimum *float64 `json:"minimum,omitempty"`
	Maximum *float64 `json:"maximum,omitempty"`
	// MinLength and MaxLength bound string lengths; nil means unbounded.
	MinLength *int `json:"minLength,omitempty"`
	MaxLength *int `json:"maxLength,omitempty"`
}
44
// ParseState reports how ParsePartialJSON arrived at its result.
type ParseState string

// Possible ParseState values.
const (
	ParseStateUndefined  ParseState = "undefined"  // the input text was empty
	ParseStateSuccessful ParseState = "successful" // parsed as-is, no repair needed
	ParseStateRepaired   ParseState = "repaired"   // parsed only after JSON repair
	ParseStateFailed     ParseState = "failed"     // unparseable even after repair
)
61
62// ToParameters converts a Schema to the parameters map format expected by ToolInfo.
63func ToParameters(s Schema) map[string]any {
64 if s.Properties == nil {
65 return make(map[string]any)
66 }
67
68 result := make(map[string]any)
69 for name, propSchema := range s.Properties {
70 result[name] = ToMap(*propSchema)
71 }
72 return result
73}
74
75// Generate generates a JSON schema from a reflect.Type.
76// It recursively processes struct fields, arrays, maps, and primitive types.
77func Generate(t reflect.Type) Schema {
78 return generateSchemaRecursive(t, make(map[reflect.Type]bool))
79}
80
81func generateSchemaRecursive(t reflect.Type, visited map[reflect.Type]bool) Schema {
82 if t.Kind() == reflect.Pointer {
83 t = t.Elem()
84 }
85
86 if visited[t] {
87 return Schema{Type: "object"}
88 }
89 visited[t] = true
90 defer delete(visited, t)
91
92 switch t.Kind() {
93 case reflect.String:
94 return Schema{Type: "string"}
95 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
96 reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
97 return Schema{Type: "integer"}
98 case reflect.Float32, reflect.Float64:
99 return Schema{Type: "number"}
100 case reflect.Bool:
101 return Schema{Type: "boolean"}
102 case reflect.Slice, reflect.Array:
103 itemSchema := generateSchemaRecursive(t.Elem(), visited)
104 return Schema{
105 Type: "array",
106 Items: &itemSchema,
107 }
108 case reflect.Map:
109 if t.Key().Kind() == reflect.String {
110 valueSchema := generateSchemaRecursive(t.Elem(), visited)
111 schema := Schema{
112 Type: "object",
113 Properties: map[string]*Schema{
114 "*": &valueSchema,
115 },
116 }
117 return schema
118 }
119 return Schema{Type: "object"}
120 case reflect.Struct:
121 schema := Schema{
122 Type: "object",
123 Properties: make(map[string]*Schema),
124 }
125 for i := range t.NumField() {
126 field := t.Field(i)
127
128 if !field.IsExported() {
129 continue
130 }
131
132 jsonTag := field.Tag.Get("json")
133 if jsonTag == "-" {
134 continue
135 }
136
137 fieldName := field.Name
138 required := true
139
140 if jsonTag != "" {
141 parts := strings.Split(jsonTag, ",")
142 if parts[0] != "" {
143 fieldName = parts[0]
144 }
145
146 if slices.Contains(parts[1:], "omitempty") {
147 required = false
148 }
149 } else {
150 fieldName = toSnakeCase(fieldName)
151 }
152
153 fieldSchema := generateSchemaRecursive(field.Type, visited)
154
155 if desc := field.Tag.Get("description"); desc != "" {
156 fieldSchema.Description = desc
157 }
158
159 if enumTag := field.Tag.Get("enum"); enumTag != "" {
160 enumValues := strings.Split(enumTag, ",")
161 fieldSchema.Enum = make([]any, len(enumValues))
162 for i, v := range enumValues {
163 fieldSchema.Enum[i] = strings.TrimSpace(v)
164 }
165 }
166
167 schema.Properties[fieldName] = &fieldSchema
168
169 if required {
170 schema.Required = append(schema.Required, fieldName)
171 }
172 }
173
174 return schema
175 case reflect.Interface:
176 return Schema{Type: "object"}
177 default:
178 return Schema{Type: "object"}
179 }
180}
181
182// ToMap converts a Schema to a map representation suitable for JSON Schema.
183func ToMap(schema Schema) map[string]any {
184 result := make(map[string]any)
185
186 if schema.Type != "" {
187 result["type"] = schema.Type
188 }
189
190 if schema.Description != "" {
191 result["description"] = schema.Description
192 }
193
194 if len(schema.Enum) > 0 {
195 result["enum"] = schema.Enum
196 }
197
198 if schema.Format != "" {
199 result["format"] = schema.Format
200 }
201
202 if schema.Minimum != nil {
203 result["minimum"] = *schema.Minimum
204 }
205
206 if schema.Maximum != nil {
207 result["maximum"] = *schema.Maximum
208 }
209
210 if schema.MinLength != nil {
211 result["minLength"] = *schema.MinLength
212 }
213
214 if schema.MaxLength != nil {
215 result["maxLength"] = *schema.MaxLength
216 }
217
218 if schema.Properties != nil {
219 props := make(map[string]any)
220 for name, propSchema := range schema.Properties {
221 props[name] = ToMap(*propSchema)
222 }
223 result["properties"] = props
224 }
225
226 if len(schema.Required) > 0 {
227 result["required"] = schema.Required
228 }
229
230 if schema.Items != nil {
231 itemsMap := ToMap(*schema.Items)
232 // Ensure type is always set for items, even if it was blank for llama.cpp compatibility
233 if _, hasType := itemsMap["type"]; !hasType && schema.Items.Type == "" {
234 if len(schema.Items.Properties) > 0 {
235 itemsMap["type"] = "object"
236 }
237 }
238 result["items"] = itemsMap
239 }
240
241 return result
242}
243
244// ParsePartialJSON attempts to parse potentially incomplete JSON.
245// It first tries standard JSON parsing, then attempts repair if that fails.
246//
247// Returns:
248// - result: The parsed JSON value (map, slice, or primitive)
249// - state: Indicates whether parsing succeeded, needed repair, or failed
250// - err: The error if parsing failed completely
251//
252// Example:
253//
254// obj, state, err := ParsePartialJSON(`{"name": "John", "age": 25`)
255// // Result: map[string]any{"name": "John", "age": 25}, ParseStateRepaired, nil
256func ParsePartialJSON(text string) (any, ParseState, error) {
257 if text == "" {
258 return nil, ParseStateUndefined, nil
259 }
260
261 var result any
262 if err := json.Unmarshal([]byte(text), &result); err == nil {
263 return result, ParseStateSuccessful, nil
264 }
265
266 repaired, err := jsonrepair.RepairJSON(text)
267 if err != nil {
268 return nil, ParseStateFailed, fmt.Errorf("json repair failed: %w", err)
269 }
270
271 if err := json.Unmarshal([]byte(repaired), &result); err != nil {
272 return nil, ParseStateFailed, fmt.Errorf("failed to parse repaired json: %w", err)
273 }
274
275 return result, ParseStateRepaired, nil
276}
277
278// Error implements the error interface.
279func (e *ParseError) Error() string {
280 if e.ValidationError != nil {
281 return fmt.Sprintf("object validation failed: %v", e.ValidationError)
282 }
283 if e.ParseError != nil {
284 return fmt.Sprintf("failed to parse object: %v", e.ParseError)
285 }
286 return "failed to generate object"
287}
288
289// ParseAndValidate combines JSON parsing and validation.
290// Returns the parsed object if both parsing and validation succeed.
291func ParseAndValidate(text string, schema Schema) (any, error) {
292 obj, state, err := ParsePartialJSON(text)
293 if state == ParseStateFailed {
294 return nil, &ParseError{
295 RawText: text,
296 ParseError: err,
297 }
298 }
299
300 if err := validateAgainstSchema(obj, schema); err != nil {
301 return nil, &ParseError{
302 RawText: text,
303 ValidationError: err,
304 }
305 }
306
307 return obj, nil
308}
309
310// ValidateAgainstSchema validates a parsed object against a Schema.
311func ValidateAgainstSchema(obj any, schema Schema) error {
312 return validateAgainstSchema(obj, schema)
313}
314
315func validateAgainstSchema(obj any, schema Schema) error {
316 jsonSchemaBytes, err := json.Marshal(schema)
317 if err != nil {
318 return fmt.Errorf("failed to marshal schema: %w", err)
319 }
320
321 compiler := jsonschema.NewCompiler()
322 validator, err := compiler.Compile(jsonSchemaBytes)
323 if err != nil {
324 return fmt.Errorf("invalid schema: %w", err)
325 }
326
327 result := validator.Validate(obj)
328 if !result.IsValid() {
329 var errMsgs []string
330 for field, validationErr := range result.Errors {
331 errMsgs = append(errMsgs, fmt.Sprintf("%s: %s", field, validationErr.Message))
332 }
333 return fmt.Errorf("validation failed: %s", strings.Join(errMsgs, "; "))
334 }
335
336 return nil
337}
338
339// ParseAndValidateWithRepair attempts parsing, validation, and custom repair.
340func ParseAndValidateWithRepair(
341 ctx context.Context,
342 text string,
343 schema Schema,
344 repair ObjectRepairFunc,
345) (any, error) {
346 obj, state, parseErr := ParsePartialJSON(text)
347
348 if state == ParseStateSuccessful || state == ParseStateRepaired {
349 validationErr := validateAgainstSchema(obj, schema)
350 if validationErr == nil {
351 return obj, nil
352 }
353
354 if repair != nil {
355 repairedText, repairErr := repair(ctx, text, validationErr)
356 if repairErr == nil {
357 obj2, state2, _ := ParsePartialJSON(repairedText)
358 if state2 == ParseStateSuccessful || state2 == ParseStateRepaired {
359 if err := validateAgainstSchema(obj2, schema); err == nil {
360 return obj2, nil
361 }
362 }
363 }
364 }
365
366 return nil, &ParseError{
367 RawText: text,
368 ValidationError: validationErr,
369 }
370 }
371
372 if repair != nil {
373 repairedText, repairErr := repair(ctx, text, parseErr)
374 if repairErr == nil {
375 obj2, state2, parseErr2 := ParsePartialJSON(repairedText)
376 if state2 == ParseStateSuccessful || state2 == ParseStateRepaired {
377 if err := validateAgainstSchema(obj2, schema); err == nil {
378 return obj2, nil
379 }
380 }
381 return nil, &ParseError{
382 RawText: repairedText,
383 ParseError: parseErr2,
384 }
385 }
386 }
387
388 return nil, &ParseError{
389 RawText: text,
390 ParseError: parseErr,
391 }
392}
393
// toSnakeCase converts a Go identifier such as "UserName" to snake_case
// ("user_name"). Runs of capitals are kept together as one word, so Go
// initialisms survive: "UserID" -> "user_id", "HTTPServer" -> "http_server".
// An existing underscore is never doubled ("Foo_Bar" -> "foo_bar").
// Only ASCII capitals start a new word; the result is lower-cased with
// strings.ToLower.
func toSnakeCase(s string) string {
	isUpper := func(r rune) bool { return 'A' <= r && r <= 'Z' }
	isLower := func(r rune) bool { return 'a' <= r && r <= 'z' }

	runes := []rune(s)
	var b strings.Builder
	b.Grow(len(s) + 4)
	for i, r := range runes {
		if i > 0 && isUpper(r) {
			prev := runes[i-1]
			nextIsLower := i+1 < len(runes) && isLower(runes[i+1])
			// A capital starts a new word after a non-capital (other than an
			// existing underscore), or when it ends an acronym and begins a
			// new capitalized word (the "S" in "HTTPServer").
			if (!isUpper(prev) && prev != '_') || (isUpper(prev) && nextIsLower) {
				b.WriteByte('_')
			}
		}
		b.WriteRune(r)
	}
	return strings.ToLower(b.String())
}