fix(llama.cpp): fix invalid json schema error for llama.cpp (#50)

Andrey Nering created

Change summary

tool.go      | 36 +++++++++++++++++++++++++++++-------
tool_test.go |  2 +-
2 files changed, 30 insertions(+), 8 deletions(-)

Detailed changes

tool.go 🔗

@@ -11,7 +11,7 @@ import (
 
 // Schema represents a JSON schema for tool input validation.
 type Schema struct {
-	Type        string             `json:"type"`
+	Type        string             `json:"type,omitempty"`
 	Properties  map[string]*Schema `json:"properties,omitempty"`
 	Required    []string           `json:"required,omitempty"`
 	Items       *Schema            `json:"items,omitempty"`
@@ -193,10 +193,10 @@ func schemaToParameters(schema Schema) map[string]any {
 
 // generateSchema automatically generates a JSON schema from a Go type.
 func generateSchema(t reflect.Type) Schema {
-	return generateSchemaRecursive(t, make(map[reflect.Type]bool))
+	return generateSchemaRecursive(t, nil, make(map[reflect.Type]bool))
 }
 
-func generateSchemaRecursive(t reflect.Type, visited map[reflect.Type]bool) Schema {
+func generateSchemaRecursive(t, parent reflect.Type, visited map[reflect.Type]bool) Schema {
 	// Handle pointers
 	if t.Kind() == reflect.Pointer {
 		t = t.Elem()
@@ -220,20 +220,24 @@ func generateSchemaRecursive(t reflect.Type, visited map[reflect.Type]bool) Sche
 	case reflect.Bool:
 		return Schema{Type: "boolean"}
 	case reflect.Slice, reflect.Array:
-		itemSchema := generateSchemaRecursive(t.Elem(), visited)
+		itemSchema := generateSchemaRecursive(t.Elem(), t, visited)
 		return Schema{
 			Type:  "array",
 			Items: &itemSchema,
 		}
 	case reflect.Map:
 		if t.Key().Kind() == reflect.String {
-			valueSchema := generateSchemaRecursive(t.Elem(), visited)
-			return Schema{
+			valueSchema := generateSchemaRecursive(t.Elem(), t, visited)
+			schema := Schema{
 				Type: "object",
 				Properties: map[string]*Schema{
 					"*": &valueSchema,
 				},
 			}
+			if useBlankType(parent) {
+				schema.Type = ""
+			}
+			return schema
 		}
 		return Schema{Type: "object"}
 	case reflect.Struct:
@@ -241,6 +245,9 @@ func generateSchemaRecursive(t reflect.Type, visited map[reflect.Type]bool) Sche
 			Type:       "object",
 			Properties: make(map[string]*Schema),
 		}
+		if useBlankType(parent) {
+			schema.Type = ""
+		}
 
 		for i := range t.NumField() {
 			field := t.Field(i)
@@ -274,7 +281,7 @@ func generateSchemaRecursive(t reflect.Type, visited map[reflect.Type]bool) Sche
 				fieldName = toSnakeCase(fieldName)
 			}
 
-			fieldSchema := generateSchemaRecursive(field.Type, visited)
+			fieldSchema := generateSchemaRecursive(field.Type, t, visited)
 
 			// Add description from struct tag if available
 			if desc := field.Tag.Get("description"); desc != "" {
@@ -316,3 +323,18 @@ func toSnakeCase(s string) string {
 	}
 	return strings.ToLower(result.String())
 }
+
+// NOTE(@andreynering): This is a hacky workaround for llama.cpp.
+// Ideally, we should always output `type: object` for objects, but
+// llama.cpp complains if we do for arrays of objects.
+func useBlankType(parent reflect.Type) bool {
+	if parent == nil {
+		return false
+	}
+	switch parent.Kind() {
+	case reflect.Slice, reflect.Array:
+		return true
+	default:
+		return false
+	}
+}

tool_test.go 🔗

@@ -494,7 +494,7 @@ func TestGenerateSchemaComplexTypes(t *testing.T) {
 	nestedSliceSchema := schema.Properties["nested_slice"]
 	require.NotNil(t, nestedSliceSchema, "Expected nested_slice property to exist")
 	require.Equal(t, "array", nestedSliceSchema.Type)
-	require.Equal(t, "object", nestedSliceSchema.Items.Type)
+	require.Equal(t, "", nestedSliceSchema.Items.Type)
 
 	// Check interface
 	interfaceSchema := schema.Properties["interface"]