1package ai
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"reflect"
  8	"strings"
  9)
 10
// Schema represents a JSON schema for tool input validation.
//
// In this file, Schema values are produced from Go types by generateSchema and
// flattened into ToolInfo.Parameters by schemaToParameters.
//
// Fields:
//   - Type: the JSON type name. The generator here emits "object", "array",
//     "string", "integer", "number" or "boolean".
//   - Properties: for objects, the schema of each property keyed by its JSON
//     name. For Go maps with string keys, the generator stores the value
//     schema under the single key "*".
//   - Required: names of properties that must be present. The generator
//     marks struct fields without omitempty as required.
//   - Items: for arrays, the schema of each element.
//   - Description: human-readable text, taken from a `description` struct tag.
//   - Enum: allowed values, taken from a comma-separated `enum` struct tag.
//   - Format, Minimum, Maximum, MinLength, MaxLength: optional constraints;
//     empty/nil means unset. The generator in this file does not set them.
type Schema struct {
	Type        string             `json:"type"`
	Properties  map[string]*Schema `json:"properties,omitempty"`
	Required    []string           `json:"required,omitempty"`
	Items       *Schema            `json:"items,omitempty"`
	Description string             `json:"description,omitempty"`
	Enum        []any              `json:"enum,omitempty"`
	Format      string             `json:"format,omitempty"`
	Minimum     *float64           `json:"minimum,omitempty"`
	Maximum     *float64           `json:"maximum,omitempty"`
	MinLength   *int               `json:"minLength,omitempty"`
	MaxLength   *int               `json:"maxLength,omitempty"`
}
 25
// ToolInfo represents tool metadata, matching the existing pattern.
//
// Name and Description are the tool's name and human-readable description.
// Parameters maps each parameter name to a JSON-schema fragment for that
// parameter. For tools built with NewTypedToolFunc it is the output of
// schemaToParameters. Required lists the parameter names that must be supplied.
type ToolInfo struct {
	Name        string
	Description string
	Parameters  map[string]any
	Required    []string
}
 33
// ToolCall represents a tool invocation, matching the existing pattern.
//
// Input carries the call's arguments as a JSON document; funcToolWrapper.Run
// decodes it into the tool's typed input. Name is the name of the tool being
// invoked. ID presumably identifies this particular call
// (NOTE(review): confirm against the code that constructs ToolCall values).
type ToolCall struct {
	ID    string `json:"id"`
	Name  string `json:"name"`
	Input string `json:"input"`
}
 40
// ToolResponse represents the response from a tool execution, matching the existing pattern.
//
// The constructors in this file always set Type to "text" and put the payload
// in Content. Metadata optionally holds a JSON-encoded value (see
// WithResponseMetadata). IsError marks a response that reports a failure
// (see NewTextErrorResponse).
type ToolResponse struct {
	Type     string `json:"type"`
	Content  string `json:"content"`
	Metadata string `json:"metadata,omitempty"`
	IsError  bool   `json:"is_error"`
}
 48
 49// NewTextResponse creates a text response.
 50func NewTextResponse(content string) ToolResponse {
 51	return ToolResponse{
 52		Type:    "text",
 53		Content: content,
 54	}
 55}
 56
 57// NewTextErrorResponse creates an error response.
 58func NewTextErrorResponse(content string) ToolResponse {
 59	return ToolResponse{
 60		Type:    "text",
 61		Content: content,
 62		IsError: true,
 63	}
 64}
 65
 66// WithResponseMetadata adds metadata to a response.
 67func WithResponseMetadata(response ToolResponse, metadata any) ToolResponse {
 68	if metadata != nil {
 69		metadataBytes, err := json.Marshal(metadata)
 70		if err != nil {
 71			return response
 72		}
 73		response.Metadata = string(metadataBytes)
 74	}
 75	return response
 76}
 77
 78// AgentTool represents a tool that can be called by a language model.
 79// This matches the existing BaseTool interface pattern.
 80type AgentTool interface {
 81	Info() ToolInfo
 82	Run(ctx context.Context, params ToolCall) (ToolResponse, error)
 83}
 84
 85// NewTypedToolFunc creates a typed tool from a function with automatic schema generation.
 86// This is the recommended way to create tools.
 87func NewTypedToolFunc[TInput any](
 88	name string,
 89	description string,
 90	fn func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error),
 91) AgentTool {
 92	var input TInput
 93	schema := generateSchema(reflect.TypeOf(input))
 94
 95	return &funcToolWrapper[TInput]{
 96		name:        name,
 97		description: description,
 98		fn:          fn,
 99		schema:      schema,
100	}
101}
102
// funcToolWrapper wraps a function to implement the AgentTool interface.
//
// NewTypedToolFunc computes schema once from TInput and stores it here; every
// Info call reuses it. Run decodes the call input into a TInput and hands it
// to fn along with the original ToolCall.
type funcToolWrapper[TInput any] struct {
	name        string
	description string
	fn          func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error)
	schema      Schema
}
110
111func (w *funcToolWrapper[TInput]) Info() ToolInfo {
112	return ToolInfo{
113		Name:        w.name,
114		Description: w.description,
115		Parameters:  schemaToParameters(w.schema),
116		Required:    w.schema.Required,
117	}
118}
119
120func (w *funcToolWrapper[TInput]) Run(ctx context.Context, params ToolCall) (ToolResponse, error) {
121	var input TInput
122	if err := json.Unmarshal([]byte(params.Input), &input); err != nil {
123		return NewTextErrorResponse(fmt.Sprintf("invalid parameters: %s", err)), nil
124	}
125
126	return w.fn(ctx, input, params)
127}
128
129// schemaToParameters converts a Schema to the parameters map format expected by ToolInfo.
130func schemaToParameters(schema Schema) map[string]any {
131	if schema.Type != "object" || schema.Properties == nil {
132		return map[string]any{}
133	}
134
135	params := make(map[string]any)
136	for name, propSchema := range schema.Properties {
137		param := map[string]any{
138			"type": propSchema.Type,
139		}
140
141		if propSchema.Description != "" {
142			param["description"] = propSchema.Description
143		}
144
145		if len(propSchema.Enum) > 0 {
146			param["enum"] = propSchema.Enum
147		}
148
149		if propSchema.Format != "" {
150			param["format"] = propSchema.Format
151		}
152
153		if propSchema.Minimum != nil {
154			param["minimum"] = *propSchema.Minimum
155		}
156
157		if propSchema.Maximum != nil {
158			param["maximum"] = *propSchema.Maximum
159		}
160
161		if propSchema.MinLength != nil {
162			param["minLength"] = *propSchema.MinLength
163		}
164
165		if propSchema.MaxLength != nil {
166			param["maxLength"] = *propSchema.MaxLength
167		}
168
169		if propSchema.Items != nil {
170			param["items"] = schemaToParameters(*propSchema.Items)
171		}
172
173		params[name] = param
174	}
175
176	return params
177}
178
179// generateSchema automatically generates a JSON schema from a Go type.
180func generateSchema(t reflect.Type) Schema {
181	return generateSchemaRecursive(t, make(map[reflect.Type]bool))
182}
183
184func generateSchemaRecursive(t reflect.Type, visited map[reflect.Type]bool) Schema {
185	// Handle pointers
186	if t.Kind() == reflect.Pointer {
187		t = t.Elem()
188	}
189
190	// Prevent infinite recursion
191	if visited[t] {
192		return Schema{Type: "object"}
193	}
194	visited[t] = true
195	defer delete(visited, t)
196
197	switch t.Kind() {
198	case reflect.String:
199		return Schema{Type: "string"}
200	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
201		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
202		return Schema{Type: "integer"}
203	case reflect.Float32, reflect.Float64:
204		return Schema{Type: "number"}
205	case reflect.Bool:
206		return Schema{Type: "boolean"}
207	case reflect.Slice, reflect.Array:
208		itemSchema := generateSchemaRecursive(t.Elem(), visited)
209		return Schema{
210			Type:  "array",
211			Items: &itemSchema,
212		}
213	case reflect.Map:
214		if t.Key().Kind() == reflect.String {
215			valueSchema := generateSchemaRecursive(t.Elem(), visited)
216			return Schema{
217				Type: "object",
218				Properties: map[string]*Schema{
219					"*": &valueSchema,
220				},
221			}
222		}
223		return Schema{Type: "object"}
224	case reflect.Struct:
225		schema := Schema{
226			Type:       "object",
227			Properties: make(map[string]*Schema),
228		}
229
230		for i := range t.NumField() {
231			field := t.Field(i)
232
233			// Skip unexported fields
234			if !field.IsExported() {
235				continue
236			}
237
238			jsonTag := field.Tag.Get("json")
239			if jsonTag == "-" {
240				continue
241			}
242
243			fieldName := field.Name
244			required := true
245
246			// Parse JSON tag
247			if jsonTag != "" {
248				parts := strings.Split(jsonTag, ",")
249				if parts[0] != "" {
250					fieldName = parts[0]
251				}
252
253				// Check for omitempty
254				for _, part := range parts[1:] {
255					if part == "omitempty" {
256						required = false
257						break
258					}
259				}
260			} else {
261				// Convert field name to snake_case for JSON
262				fieldName = toSnakeCase(fieldName)
263			}
264
265			fieldSchema := generateSchemaRecursive(field.Type, visited)
266
267			// Add description from struct tag if available
268			if desc := field.Tag.Get("description"); desc != "" {
269				fieldSchema.Description = desc
270			}
271
272			// Add enum values from struct tag if available
273			if enumTag := field.Tag.Get("enum"); enumTag != "" {
274				enumValues := strings.Split(enumTag, ",")
275				fieldSchema.Enum = make([]any, len(enumValues))
276				for i, v := range enumValues {
277					fieldSchema.Enum[i] = strings.TrimSpace(v)
278				}
279			}
280
281			schema.Properties[fieldName] = &fieldSchema
282
283			if required {
284				schema.Required = append(schema.Required, fieldName)
285			}
286		}
287
288		return schema
289	case reflect.Interface:
290		return Schema{Type: "object"}
291	default:
292		return Schema{Type: "object"}
293	}
294}
295
// toSnakeCase converts a PascalCase or camelCase identifier to snake_case.
//
// An underscore is inserted before an ASCII capital letter that starts a new
// word. That is either a capital following any character other than another
// ASCII capital or an underscore ("userName" -> "user_name",
// "Base64Data" -> "base64_data"), or a capital that ends a run of capitals and
// is followed by a lower-case letter ("HTTPServer" -> "http_server"). Acronyms
// therefore stay together ("UserID" -> "user_id"), and existing underscores
// are not doubled. The result is lower-cased.
func toSnakeCase(s string) string {
	isUpper := func(r rune) bool { return 'A' <= r && r <= 'Z' }
	isLower := func(r rune) bool { return 'a' <= r && r <= 'z' }

	runes := []rune(s)
	var b strings.Builder
	for i, r := range runes {
		if i > 0 && isUpper(r) {
			prev := runes[i-1]
			nextIsLower := i+1 < len(runes) && isLower(runes[i+1])
			if (!isUpper(prev) && prev != '_') || (isUpper(prev) && nextIsLower) {
				b.WriteByte('_')
			}
		}
		b.WriteRune(r)
	}
	return strings.ToLower(b.String())
}