@@ -11,7 +11,7 @@ import (
// Schema represents a JSON schema for tool input validation.
type Schema struct {
- Type string `json:"type"`
+ Type string `json:"type,omitempty"`
Properties map[string]*Schema `json:"properties,omitempty"`
Required []string `json:"required,omitempty"`
Items *Schema `json:"items,omitempty"`
@@ -193,10 +193,10 @@ func schemaToParameters(schema Schema) map[string]any {
// generateSchema automatically generates a JSON schema from a Go type.
func generateSchema(t reflect.Type) Schema {
- return generateSchemaRecursive(t, make(map[reflect.Type]bool))
+ return generateSchemaRecursive(t, nil, make(map[reflect.Type]bool))
}
-func generateSchemaRecursive(t reflect.Type, visited map[reflect.Type]bool) Schema {
+func generateSchemaRecursive(t, parent reflect.Type, visited map[reflect.Type]bool) Schema {
// Handle pointers
if t.Kind() == reflect.Pointer {
t = t.Elem()
@@ -220,20 +220,24 @@ func generateSchemaRecursive(t reflect.Type, visited map[reflect.Type]bool) Sche
case reflect.Bool:
return Schema{Type: "boolean"}
case reflect.Slice, reflect.Array:
- itemSchema := generateSchemaRecursive(t.Elem(), visited)
+ itemSchema := generateSchemaRecursive(t.Elem(), t, visited)
return Schema{
Type: "array",
Items: &itemSchema,
}
case reflect.Map:
if t.Key().Kind() == reflect.String {
- valueSchema := generateSchemaRecursive(t.Elem(), visited)
- return Schema{
+ valueSchema := generateSchemaRecursive(t.Elem(), t, visited)
+ schema := Schema{
Type: "object",
Properties: map[string]*Schema{
"*": &valueSchema,
},
}
+ if useBlankType(parent) {
+ schema.Type = ""
+ }
+ return schema
}
return Schema{Type: "object"}
case reflect.Struct:
@@ -241,6 +245,9 @@ func generateSchemaRecursive(t reflect.Type, visited map[reflect.Type]bool) Sche
Type: "object",
Properties: make(map[string]*Schema),
}
+ if useBlankType(parent) {
+ schema.Type = ""
+ }
for i := range t.NumField() {
field := t.Field(i)
@@ -274,7 +281,7 @@ func generateSchemaRecursive(t reflect.Type, visited map[reflect.Type]bool) Sche
fieldName = toSnakeCase(fieldName)
}
- fieldSchema := generateSchemaRecursive(field.Type, visited)
+ fieldSchema := generateSchemaRecursive(field.Type, t, visited)
// Add description from struct tag if available
if desc := field.Tag.Get("description"); desc != "" {
@@ -316,3 +323,18 @@ func toSnakeCase(s string) string {
}
return strings.ToLower(result.String())
}
+
+// NOTE(@andreynering): This is a hacky workaround for llama.cpp.
+// Ideally, we should always output `type: object` for objects, but
+// llama.cpp complains if we do for arrays of objects.
+func useBlankType(parent reflect.Type) bool {
+ if parent == nil {
+ return false
+ }
+ switch parent.Kind() {
+ case reflect.Slice, reflect.Array:
+ return true
+ default:
+ return false
+ }
+}
@@ -494,7 +494,7 @@ func TestGenerateSchemaComplexTypes(t *testing.T) {
nestedSliceSchema := schema.Properties["nested_slice"]
require.NotNil(t, nestedSliceSchema, "Expected nested_slice property to exist")
require.Equal(t, "array", nestedSliceSchema.Type)
- require.Equal(t, "object", nestedSliceSchema.Items.Type)
+ require.Equal(t, "", nestedSliceSchema.Items.Type)
// Check interface
interfaceSchema := schema.Properties["interface"]