1package ai
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"reflect"
  8	"slices"
  9	"strings"
 10)
 11
// Schema represents a JSON schema for tool input validation.
//
// It models the subset of JSON Schema keywords used by this package:
//   - Type is the JSON type name; generateSchema emits "object", "array",
//     "string", "integer", "number" or "boolean".
//   - Properties holds per-property schemas keyed by JSON name. For a
//     map[string]T, generateSchema stores T's schema under the single key "*".
//   - Required lists the property names that must be present.
//   - Items is the element schema of an array.
//   - Description and Enum come from the `description` and `enum` struct tags
//     when the schema is generated from a Go type.
//   - Format, Minimum, Maximum, MinLength and MaxLength are optional
//     constraints. generateSchema never sets them, but schemaToParameters
//     copies them through when present. The pointer fields let "unset" be
//     told apart from a zero bound.
type Schema struct {
	Type        string             `json:"type"`
	Properties  map[string]*Schema `json:"properties,omitempty"`
	Required    []string           `json:"required,omitempty"`
	Items       *Schema            `json:"items,omitempty"`
	Description string             `json:"description,omitempty"`
	Enum        []any              `json:"enum,omitempty"`
	Format      string             `json:"format,omitempty"`
	Minimum     *float64           `json:"minimum,omitempty"`
	Maximum     *float64           `json:"maximum,omitempty"`
	MinLength   *int               `json:"minLength,omitempty"`
	MaxLength   *int               `json:"maxLength,omitempty"`
}
 26
// ToolInfo represents tool metadata, matching the existing pattern.
//
// Name and Description identify the tool. For tools built with NewAgentTool:
//   - Parameters maps each top-level input property name to a
//     JSON-schema-style description of that property, as built by
//     schemaToParameters.
//   - Required lists the property names the input must include.
type ToolInfo struct {
	Name        string
	Description string
	Parameters  map[string]any
	Required    []string
}
 34
// ToolCall represents a tool invocation, matching the existing pattern.
//
// Input carries the call's arguments as a JSON document. The NewAgentTool
// implementation decodes it into the tool's input type before running the
// tool function. ID presumably identifies the call so that a response can be
// matched back to it; confirm this against the provider code that fills it in.
type ToolCall struct {
	ID    string `json:"id"`
	Name  string `json:"name"`
	Input string `json:"input"`
}
 41
// ToolResponse represents the response from a tool execution, matching the existing pattern.
//
// Fields:
//   - Type is "text" for responses built by NewTextResponse and
//     NewTextErrorResponse.
//   - Content is the response body.
//   - IsError marks Content as an error report rather than a normal result.
//   - Metadata, when non-empty, is a JSON document attached by
//     WithResponseMetadata.
type ToolResponse struct {
	Type     string `json:"type"`
	Content  string `json:"content"`
	Metadata string `json:"metadata,omitempty"`
	IsError  bool   `json:"is_error"`
}
 49
 50// NewTextResponse creates a text response.
 51func NewTextResponse(content string) ToolResponse {
 52	return ToolResponse{
 53		Type:    "text",
 54		Content: content,
 55	}
 56}
 57
 58// NewTextErrorResponse creates an error response.
 59func NewTextErrorResponse(content string) ToolResponse {
 60	return ToolResponse{
 61		Type:    "text",
 62		Content: content,
 63		IsError: true,
 64	}
 65}
 66
 67// WithResponseMetadata adds metadata to a response.
 68func WithResponseMetadata(response ToolResponse, metadata any) ToolResponse {
 69	if metadata != nil {
 70		metadataBytes, err := json.Marshal(metadata)
 71		if err != nil {
 72			return response
 73		}
 74		response.Metadata = string(metadataBytes)
 75	}
 76	return response
 77}
 78
// AgentTool represents a tool that can be called by a language model.
// This matches the existing BaseTool interface pattern.
//
// Info describes the tool and its parameters, and Run executes one ToolCall.
// The implementation returned by NewAgentTool reports undecodable input as an
// error ToolResponse with a nil error. Its error result is whatever the
// wrapped tool function returns.
type AgentTool interface {
	Info() ToolInfo
	Run(ctx context.Context, params ToolCall) (ToolResponse, error)
}
 85
 86// NewAgentTool creates a typed tool from a function with automatic schema generation.
 87// This is the recommended way to create tools.
 88func NewAgentTool[TInput any](
 89	name string,
 90	description string,
 91	fn func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error),
 92) AgentTool {
 93	var input TInput
 94	schema := generateSchema(reflect.TypeOf(input))
 95
 96	return &funcToolWrapper[TInput]{
 97		name:        name,
 98		description: description,
 99		fn:          fn,
100		schema:      schema,
101	}
102}
103
// funcToolWrapper wraps a function to implement the AgentTool interface.
//
// It is constructed by NewAgentTool. schema is generated from TInput once at
// construction time and is exposed through Info. fn is invoked by Run with
// the decoded input.
type funcToolWrapper[TInput any] struct {
	name        string
	description string
	fn          func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error)
	schema      Schema
}
111
112func (w *funcToolWrapper[TInput]) Info() ToolInfo {
113	if w.schema.Required == nil {
114		w.schema.Required = []string{}
115	}
116	return ToolInfo{
117		Name:        w.name,
118		Description: w.description,
119		Parameters:  schemaToParameters(w.schema),
120		Required:    w.schema.Required,
121	}
122}
123
124func (w *funcToolWrapper[TInput]) Run(ctx context.Context, params ToolCall) (ToolResponse, error) {
125	var input TInput
126	if err := json.Unmarshal([]byte(params.Input), &input); err != nil {
127		return NewTextErrorResponse(fmt.Sprintf("invalid parameters: %s", err)), nil
128	}
129
130	return w.fn(ctx, input, params)
131}
132
133// schemaToParameters converts a Schema to the parameters map format expected by ToolInfo.
134func schemaToParameters(schema Schema) map[string]any {
135	if schema.Type != "object" || schema.Properties == nil {
136		return map[string]any{}
137	}
138
139	params := make(map[string]any)
140	for name, propSchema := range schema.Properties {
141		param := map[string]any{
142			"type": propSchema.Type,
143		}
144
145		if propSchema.Description != "" {
146			param["description"] = propSchema.Description
147		}
148
149		if len(propSchema.Enum) > 0 {
150			param["enum"] = propSchema.Enum
151		}
152
153		if propSchema.Format != "" {
154			param["format"] = propSchema.Format
155		}
156
157		if propSchema.Minimum != nil {
158			param["minimum"] = *propSchema.Minimum
159		}
160
161		if propSchema.Maximum != nil {
162			param["maximum"] = *propSchema.Maximum
163		}
164
165		if propSchema.MinLength != nil {
166			param["minLength"] = *propSchema.MinLength
167		}
168
169		if propSchema.MaxLength != nil {
170			param["maxLength"] = *propSchema.MaxLength
171		}
172
173		if propSchema.Items != nil {
174			param["items"] = schemaToParameters(*propSchema.Items)
175		}
176
177		params[name] = param
178	}
179
180	return params
181}
182
183// generateSchema automatically generates a JSON schema from a Go type.
184func generateSchema(t reflect.Type) Schema {
185	return generateSchemaRecursive(t, make(map[reflect.Type]bool))
186}
187
188func generateSchemaRecursive(t reflect.Type, visited map[reflect.Type]bool) Schema {
189	// Handle pointers
190	if t.Kind() == reflect.Pointer {
191		t = t.Elem()
192	}
193
194	// Prevent infinite recursion
195	if visited[t] {
196		return Schema{Type: "object"}
197	}
198	visited[t] = true
199	defer delete(visited, t)
200
201	switch t.Kind() {
202	case reflect.String:
203		return Schema{Type: "string"}
204	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
205		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
206		return Schema{Type: "integer"}
207	case reflect.Float32, reflect.Float64:
208		return Schema{Type: "number"}
209	case reflect.Bool:
210		return Schema{Type: "boolean"}
211	case reflect.Slice, reflect.Array:
212		itemSchema := generateSchemaRecursive(t.Elem(), visited)
213		return Schema{
214			Type:  "array",
215			Items: &itemSchema,
216		}
217	case reflect.Map:
218		if t.Key().Kind() == reflect.String {
219			valueSchema := generateSchemaRecursive(t.Elem(), visited)
220			return Schema{
221				Type: "object",
222				Properties: map[string]*Schema{
223					"*": &valueSchema,
224				},
225			}
226		}
227		return Schema{Type: "object"}
228	case reflect.Struct:
229		schema := Schema{
230			Type:       "object",
231			Properties: make(map[string]*Schema),
232		}
233
234		for i := range t.NumField() {
235			field := t.Field(i)
236
237			// Skip unexported fields
238			if !field.IsExported() {
239				continue
240			}
241
242			jsonTag := field.Tag.Get("json")
243			if jsonTag == "-" {
244				continue
245			}
246
247			fieldName := field.Name
248			required := true
249
250			// Parse JSON tag
251			if jsonTag != "" {
252				parts := strings.Split(jsonTag, ",")
253				if parts[0] != "" {
254					fieldName = parts[0]
255				}
256
257				// Check for omitempty
258				if slices.Contains(parts[1:], "omitempty") {
259					required = false
260				}
261			} else {
262				// Convert field name to snake_case for JSON
263				fieldName = toSnakeCase(fieldName)
264			}
265
266			fieldSchema := generateSchemaRecursive(field.Type, visited)
267
268			// Add description from struct tag if available
269			if desc := field.Tag.Get("description"); desc != "" {
270				fieldSchema.Description = desc
271			}
272
273			// Add enum values from struct tag if available
274			if enumTag := field.Tag.Get("enum"); enumTag != "" {
275				enumValues := strings.Split(enumTag, ",")
276				fieldSchema.Enum = make([]any, len(enumValues))
277				for i, v := range enumValues {
278					fieldSchema.Enum[i] = strings.TrimSpace(v)
279				}
280			}
281
282			schema.Properties[fieldName] = &fieldSchema
283
284			if required {
285				schema.Required = append(schema.Required, fieldName)
286			}
287		}
288
289		return schema
290	case reflect.Interface:
291		return Schema{Type: "object"}
292	default:
293		return Schema{Type: "object"}
294	}
295}
296
// toSnakeCase converts a PascalCase or camelCase identifier to snake_case.
//
// A run of capitals is treated as a single word, so acronyms stay intact:
// "UserID" becomes "user_id" and "HTTPServer" becomes "http_server". An
// underscore already in s is not doubled. Only ASCII capitals start a new
// word, but the whole result is lower-cased.
func toSnakeCase(s string) string {
	isUpper := func(r rune) bool { return r >= 'A' && r <= 'Z' }
	isLower := func(r rune) bool { return r >= 'a' && r <= 'z' }

	runes := []rune(s)
	var result strings.Builder
	result.Grow(len(s) + 4)
	for i, r := range runes {
		if i > 0 && isUpper(r) {
			prev := runes[i-1]
			nextIsLower := i+1 < len(runes) && isLower(runes[i+1])
			// Start a new word after a non-capital ("fileP" -> "file_p"), or at
			// the last capital of an acronym followed by a lowercase word
			// ("HTTPServer" -> "http_server").
			if prev != '_' && (!isUpper(prev) || nextIsLower) {
				result.WriteByte('_')
			}
		}
		result.WriteRune(r)
	}
	return strings.ToLower(result.String())
}