schema.go

  1// Package schema provides JSON schema generation and validation utilities.
  2// It supports automatic schema generation from Go types and validation of parsed objects.
  3package schema
  4
  5import (
  6	"context"
  7	"encoding/json"
  8	"fmt"
  9	"reflect"
 10	"slices"
 11	"strings"
 12
 13	jsonrepair "github.com/RealAlexandreAI/json-repair"
 14	"github.com/kaptinlin/jsonschema"
 15)
 16
// ObjectRepairFunc is a caller-supplied hook that attempts to repair invalid
// JSON output.
//
// It receives the raw text together with the error encountered while parsing
// or validating that text, and returns the repaired text, or an error if
// repair is not possible.
//
// Within this file, ParseAndValidateWithRepair invokes the hook at most once
// per call. The err argument can be nil there when the input text was empty,
// because ParsePartialJSON reports no error for empty input.
type ObjectRepairFunc func(ctx context.Context, text string, err error) (string, error)
 21
// ParseError is returned when object generation fails
// due to parsing errors, validation errors, or model failures.
//
// Its Error method reports ValidationError when set, otherwise ParseError,
// and falls back to a generic message when both are nil.
type ParseError struct {
	RawText         string // text that was being parsed or validated when the failure occurred
	ParseError      error  // error from JSON parsing/repair; nil if parsing was not the failure
	ValidationError error  // error from schema validation; nil if validation was not the failure
}
 29
// Schema represents a JSON schema for tool input validation.
//
// Each field is serialized under the JSON key named in its struct tag, and
// every field is dropped from the JSON output when empty (omitempty).
// Minimum, Maximum, MinLength and MaxLength are pointers so that an unset
// constraint (nil) can be told apart from an explicit zero value.
//
// Properties describes the members of an object schema, Required lists the
// property names that must be present, and Items describes the elements of an
// array schema.
type Schema struct {
	Type        string             `json:"type,omitempty"`
	Properties  map[string]*Schema `json:"properties,omitempty"`
	Required    []string           `json:"required,omitempty"`
	Items       *Schema            `json:"items,omitempty"`
	Description string             `json:"description,omitempty"`
	Enum        []any              `json:"enum,omitempty"`
	Format      string             `json:"format,omitempty"`
	Minimum     *float64           `json:"minimum,omitempty"`
	Maximum     *float64           `json:"maximum,omitempty"`
	MinLength   *int               `json:"minLength,omitempty"`
	MaxLength   *int               `json:"maxLength,omitempty"`
}
 44
// ParseState describes the outcome of a ParsePartialJSON call.
type ParseState string

const (
	// ParseStateUndefined means the input text was empty, so nothing was parsed.
	ParseStateUndefined ParseState = "undefined"

	// ParseStateSuccessful means the input parsed as JSON without any repair.
	ParseStateSuccessful ParseState = "successful"

	// ParseStateRepaired means the input parsed as JSON only after repair.
	ParseStateRepaired ParseState = "repaired"

	// ParseStateFailed means the input could not be parsed even after repair.
	ParseStateFailed ParseState = "failed"
)
 61
 62// ToParameters converts a Schema to the parameters map format expected by ToolInfo.
 63func ToParameters(s Schema) map[string]any {
 64	if s.Properties == nil {
 65		return make(map[string]any)
 66	}
 67
 68	result := make(map[string]any)
 69	for name, propSchema := range s.Properties {
 70		result[name] = ToMap(*propSchema)
 71	}
 72	return result
 73}
 74
 75// Generate generates a JSON schema from a reflect.Type.
 76// It recursively processes struct fields, arrays, maps, and primitive types.
 77func Generate(t reflect.Type) Schema {
 78	return generateSchemaRecursive(t, make(map[reflect.Type]bool))
 79}
 80
 81func generateSchemaRecursive(t reflect.Type, visited map[reflect.Type]bool) Schema {
 82	if t.Kind() == reflect.Pointer {
 83		t = t.Elem()
 84	}
 85
 86	if visited[t] {
 87		return Schema{Type: "object"}
 88	}
 89	visited[t] = true
 90	defer delete(visited, t)
 91
 92	switch t.Kind() {
 93	case reflect.String:
 94		return Schema{Type: "string"}
 95	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
 96		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
 97		return Schema{Type: "integer"}
 98	case reflect.Float32, reflect.Float64:
 99		return Schema{Type: "number"}
100	case reflect.Bool:
101		return Schema{Type: "boolean"}
102	case reflect.Slice, reflect.Array:
103		itemSchema := generateSchemaRecursive(t.Elem(), visited)
104		return Schema{
105			Type:  "array",
106			Items: &itemSchema,
107		}
108	case reflect.Map:
109		if t.Key().Kind() == reflect.String {
110			valueSchema := generateSchemaRecursive(t.Elem(), visited)
111			schema := Schema{
112				Type: "object",
113				Properties: map[string]*Schema{
114					"*": &valueSchema,
115				},
116			}
117			return schema
118		}
119		return Schema{Type: "object"}
120	case reflect.Struct:
121		schema := Schema{
122			Type:       "object",
123			Properties: make(map[string]*Schema),
124		}
125		for i := range t.NumField() {
126			field := t.Field(i)
127
128			if !field.IsExported() {
129				continue
130			}
131
132			jsonTag := field.Tag.Get("json")
133			if jsonTag == "-" {
134				continue
135			}
136
137			fieldName := field.Name
138			required := true
139
140			if jsonTag != "" {
141				parts := strings.Split(jsonTag, ",")
142				if parts[0] != "" {
143					fieldName = parts[0]
144				}
145
146				if slices.Contains(parts[1:], "omitempty") {
147					required = false
148				}
149			} else {
150				fieldName = toSnakeCase(fieldName)
151			}
152
153			fieldSchema := generateSchemaRecursive(field.Type, visited)
154
155			if desc := field.Tag.Get("description"); desc != "" {
156				fieldSchema.Description = desc
157			}
158
159			if enumTag := field.Tag.Get("enum"); enumTag != "" {
160				enumValues := strings.Split(enumTag, ",")
161				fieldSchema.Enum = make([]any, len(enumValues))
162				for i, v := range enumValues {
163					fieldSchema.Enum[i] = strings.TrimSpace(v)
164				}
165			}
166
167			schema.Properties[fieldName] = &fieldSchema
168
169			if required {
170				schema.Required = append(schema.Required, fieldName)
171			}
172		}
173
174		return schema
175	case reflect.Interface:
176		return Schema{Type: "object"}
177	default:
178		return Schema{Type: "object"}
179	}
180}
181
182// ToMap converts a Schema to a map representation suitable for JSON Schema.
183func ToMap(schema Schema) map[string]any {
184	result := make(map[string]any)
185
186	if schema.Type != "" {
187		result["type"] = schema.Type
188	}
189
190	if schema.Description != "" {
191		result["description"] = schema.Description
192	}
193
194	if len(schema.Enum) > 0 {
195		result["enum"] = schema.Enum
196	}
197
198	if schema.Format != "" {
199		result["format"] = schema.Format
200	}
201
202	if schema.Minimum != nil {
203		result["minimum"] = *schema.Minimum
204	}
205
206	if schema.Maximum != nil {
207		result["maximum"] = *schema.Maximum
208	}
209
210	if schema.MinLength != nil {
211		result["minLength"] = *schema.MinLength
212	}
213
214	if schema.MaxLength != nil {
215		result["maxLength"] = *schema.MaxLength
216	}
217
218	if schema.Properties != nil {
219		props := make(map[string]any)
220		for name, propSchema := range schema.Properties {
221			props[name] = ToMap(*propSchema)
222		}
223		result["properties"] = props
224	}
225
226	if len(schema.Required) > 0 {
227		result["required"] = schema.Required
228	}
229
230	if schema.Items != nil {
231		itemsMap := ToMap(*schema.Items)
232		// Ensure type is always set for items, even if it was blank for llama.cpp compatibility
233		if _, hasType := itemsMap["type"]; !hasType && schema.Items.Type == "" {
234			if len(schema.Items.Properties) > 0 {
235				itemsMap["type"] = "object"
236			}
237		}
238		result["items"] = itemsMap
239	}
240
241	return result
242}
243
244// ParsePartialJSON attempts to parse potentially incomplete JSON.
245// It first tries standard JSON parsing, then attempts repair if that fails.
246//
247// Returns:
248//   - result: The parsed JSON value (map, slice, or primitive)
249//   - state: Indicates whether parsing succeeded, needed repair, or failed
250//   - err: The error if parsing failed completely
251//
252// Example:
253//
254//	obj, state, err := ParsePartialJSON(`{"name": "John", "age": 25`)
255//	// Result: map[string]any{"name": "John", "age": 25}, ParseStateRepaired, nil
256func ParsePartialJSON(text string) (any, ParseState, error) {
257	if text == "" {
258		return nil, ParseStateUndefined, nil
259	}
260
261	var result any
262	if err := json.Unmarshal([]byte(text), &result); err == nil {
263		return result, ParseStateSuccessful, nil
264	}
265
266	repaired, err := jsonrepair.RepairJSON(text)
267	if err != nil {
268		return nil, ParseStateFailed, fmt.Errorf("json repair failed: %w", err)
269	}
270
271	if err := json.Unmarshal([]byte(repaired), &result); err != nil {
272		return nil, ParseStateFailed, fmt.Errorf("failed to parse repaired json: %w", err)
273	}
274
275	return result, ParseStateRepaired, nil
276}
277
278// Error implements the error interface.
279func (e *ParseError) Error() string {
280	if e.ValidationError != nil {
281		return fmt.Sprintf("object validation failed: %v", e.ValidationError)
282	}
283	if e.ParseError != nil {
284		return fmt.Sprintf("failed to parse object: %v", e.ParseError)
285	}
286	return "failed to generate object"
287}
288
289// ParseAndValidate combines JSON parsing and validation.
290// Returns the parsed object if both parsing and validation succeed.
291func ParseAndValidate(text string, schema Schema) (any, error) {
292	obj, state, err := ParsePartialJSON(text)
293	if state == ParseStateFailed {
294		return nil, &ParseError{
295			RawText:    text,
296			ParseError: err,
297		}
298	}
299
300	if err := validateAgainstSchema(obj, schema); err != nil {
301		return nil, &ParseError{
302			RawText:         text,
303			ValidationError: err,
304		}
305	}
306
307	return obj, nil
308}
309
310// ValidateAgainstSchema validates a parsed object against a Schema.
311func ValidateAgainstSchema(obj any, schema Schema) error {
312	return validateAgainstSchema(obj, schema)
313}
314
315func validateAgainstSchema(obj any, schema Schema) error {
316	jsonSchemaBytes, err := json.Marshal(schema)
317	if err != nil {
318		return fmt.Errorf("failed to marshal schema: %w", err)
319	}
320
321	compiler := jsonschema.NewCompiler()
322	validator, err := compiler.Compile(jsonSchemaBytes)
323	if err != nil {
324		return fmt.Errorf("invalid schema: %w", err)
325	}
326
327	result := validator.Validate(obj)
328	if !result.IsValid() {
329		var errMsgs []string
330		for field, validationErr := range result.Errors {
331			errMsgs = append(errMsgs, fmt.Sprintf("%s: %s", field, validationErr.Message))
332		}
333		return fmt.Errorf("validation failed: %s", strings.Join(errMsgs, "; "))
334	}
335
336	return nil
337}
338
339// ParseAndValidateWithRepair attempts parsing, validation, and custom repair.
340func ParseAndValidateWithRepair(
341	ctx context.Context,
342	text string,
343	schema Schema,
344	repair ObjectRepairFunc,
345) (any, error) {
346	obj, state, parseErr := ParsePartialJSON(text)
347
348	if state == ParseStateSuccessful || state == ParseStateRepaired {
349		validationErr := validateAgainstSchema(obj, schema)
350		if validationErr == nil {
351			return obj, nil
352		}
353
354		if repair != nil {
355			repairedText, repairErr := repair(ctx, text, validationErr)
356			if repairErr == nil {
357				obj2, state2, _ := ParsePartialJSON(repairedText)
358				if state2 == ParseStateSuccessful || state2 == ParseStateRepaired {
359					if err := validateAgainstSchema(obj2, schema); err == nil {
360						return obj2, nil
361					}
362				}
363			}
364		}
365
366		return nil, &ParseError{
367			RawText:         text,
368			ValidationError: validationErr,
369		}
370	}
371
372	if repair != nil {
373		repairedText, repairErr := repair(ctx, text, parseErr)
374		if repairErr == nil {
375			obj2, state2, parseErr2 := ParsePartialJSON(repairedText)
376			if state2 == ParseStateSuccessful || state2 == ParseStateRepaired {
377				if err := validateAgainstSchema(obj2, schema); err == nil {
378					return obj2, nil
379				}
380			}
381			return nil, &ParseError{
382				RawText:    repairedText,
383				ParseError: parseErr2,
384			}
385		}
386	}
387
388	return nil, &ParseError{
389		RawText:    text,
390		ParseError: parseErr,
391	}
392}
393
// toSnakeCase converts a Go identifier such as "FirstName" to snake_case
// ("first_name").
//
// An underscore is inserted before an ASCII upper-case letter when it starts
// a new word: either the previous character is neither upper-case nor an
// underscore ("fooBar" -> "foo_bar"), or the letter ends a run of capitals and
// is followed by a lower-case letter ("HTTPServer" -> "http_server"). Runs of
// capitals are therefore kept together ("UserID" -> "user_id") instead of
// being split letter by letter. The result is lower-cased as a whole.
//
// Known limitation: a plural acronym such as "IDs" is still split before its
// last capital ("i_ds").
func toSnakeCase(s string) string {
	isUpper := func(r rune) bool { return r >= 'A' && r <= 'Z' }
	isLower := func(r rune) bool { return r >= 'a' && r <= 'z' }

	runes := []rune(s)
	var b strings.Builder
	b.Grow(len(s) + 4) // room for a few separators without regrowing

	for i, r := range runes {
		if i > 0 && isUpper(r) {
			prev := runes[i-1]
			startsWord := !isUpper(prev) && prev != '_'
			endsAcronym := isUpper(prev) && i+1 < len(runes) && isLower(runes[i+1])
			if startsWord || endsAcronym {
				b.WriteByte('_')
			}
		}
		b.WriteRune(r)
	}
	return strings.ToLower(b.String())
}