1package ai
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "reflect"
8 "slices"
9 "strings"
10)
11
// Schema represents a JSON schema for tool input validation.
//
// It models the subset of JSON Schema this package works with: a type name,
// object properties and their required names, array items, and a few
// per-value constraints. Every field except Type is omitted from the JSON
// encoding when empty.
type Schema struct {
	// Type is the JSON type name, e.g. "object", "array", "string", "integer".
	Type string `json:"type"`
	// Properties holds per-property schemas for objects. generateSchemaRecursive
	// stores a string-keyed map's value schema under the wildcard key "*".
	Properties map[string]*Schema `json:"properties,omitempty"`
	// Required lists the property names that must be present.
	Required []string `json:"required,omitempty"`
	// Items is the element schema for arrays.
	Items *Schema `json:"items,omitempty"`
	// Description is free text shown alongside the schema.
	Description string `json:"description,omitempty"`
	// Enum lists the allowed values.
	Enum []any `json:"enum,omitempty"`
	// Format is an optional format hint for the value.
	Format string `json:"format,omitempty"`
	// Minimum and Maximum bound numeric values; nil means unbounded.
	Minimum *float64 `json:"minimum,omitempty"`
	Maximum *float64 `json:"maximum,omitempty"`
	// MinLength and MaxLength bound string length; nil means unbounded.
	MinLength *int `json:"minLength,omitempty"`
	MaxLength *int `json:"maxLength,omitempty"`
}
26
// ToolInfo represents tool metadata, matching the existing pattern.
//
// Parameters maps each top-level input property name to that property's
// schema fragment (funcToolWrapper.Info fills it via schemaToParameters).
// Required lists the property names that must be supplied.
type ToolInfo struct {
	Name        string
	Description string
	Parameters  map[string]any
	Required    []string
}
34
// ToolCall represents a tool invocation, matching the existing pattern.
type ToolCall struct {
	// ID identifies this invocation. NOTE(review): presumably the
	// provider-assigned call ID — confirm against callers.
	ID string `json:"id"`
	// Name is the name of the tool being invoked.
	Name string `json:"name"`
	// Input is the raw JSON-encoded argument object; funcToolWrapper.Run
	// decodes it with json.Unmarshal.
	Input string `json:"input"`
}
41
// ToolResponse represents the response from a tool execution, matching the existing pattern.
type ToolResponse struct {
	// Type is the content kind; the constructors in this file set "text".
	Type string `json:"type"`
	// Content is the response payload.
	Content string `json:"content"`
	// Metadata optionally holds JSON-encoded metadata (WithResponseMetadata
	// sets it); omitted from the encoding when empty.
	Metadata string `json:"metadata,omitempty"`
	// IsError marks the response as describing a failure.
	IsError bool `json:"is_error"`
}
49
50// NewTextResponse creates a text response.
51func NewTextResponse(content string) ToolResponse {
52 return ToolResponse{
53 Type: "text",
54 Content: content,
55 }
56}
57
58// NewTextErrorResponse creates an error response.
59func NewTextErrorResponse(content string) ToolResponse {
60 return ToolResponse{
61 Type: "text",
62 Content: content,
63 IsError: true,
64 }
65}
66
67// WithResponseMetadata adds metadata to a response.
68func WithResponseMetadata(response ToolResponse, metadata any) ToolResponse {
69 if metadata != nil {
70 metadataBytes, err := json.Marshal(metadata)
71 if err != nil {
72 return response
73 }
74 response.Metadata = string(metadataBytes)
75 }
76 return response
77}
78
// AgentTool represents a tool that can be called by a language model.
// This matches the existing BaseTool interface pattern.
type AgentTool interface {
	// Info returns the tool's name, description and input parameter schema.
	Info() ToolInfo
	// Run executes the tool for a single call. NOTE(review): the
	// funcToolWrapper implementation reports undecodable input as an error
	// ToolResponse with a nil error; other implementations may differ.
	Run(ctx context.Context, params ToolCall) (ToolResponse, error)
}
85
86// NewAgentTool creates a typed tool from a function with automatic schema generation.
87// This is the recommended way to create tools.
88func NewAgentTool[TInput any](
89 name string,
90 description string,
91 fn func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error),
92) AgentTool {
93 var input TInput
94 schema := generateSchema(reflect.TypeOf(input))
95
96 return &funcToolWrapper[TInput]{
97 name: name,
98 description: description,
99 fn: fn,
100 schema: schema,
101 }
102}
103
// funcToolWrapper wraps a function to implement the AgentTool interface.
// NewAgentTool constructs it with a schema generated from TInput.
type funcToolWrapper[TInput any] struct {
	name        string
	description string
	// fn is invoked by Run with the decoded input and the original call.
	fn func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error)
	// schema is generated from TInput when the wrapper is constructed.
	schema Schema
}
111
112func (w *funcToolWrapper[TInput]) Info() ToolInfo {
113 if w.schema.Required == nil {
114 w.schema.Required = []string{}
115 }
116 return ToolInfo{
117 Name: w.name,
118 Description: w.description,
119 Parameters: schemaToParameters(w.schema),
120 Required: w.schema.Required,
121 }
122}
123
124func (w *funcToolWrapper[TInput]) Run(ctx context.Context, params ToolCall) (ToolResponse, error) {
125 var input TInput
126 if err := json.Unmarshal([]byte(params.Input), &input); err != nil {
127 return NewTextErrorResponse(fmt.Sprintf("invalid parameters: %s", err)), nil
128 }
129
130 return w.fn(ctx, input, params)
131}
132
133// schemaToParameters converts a Schema to the parameters map format expected by ToolInfo.
134func schemaToParameters(schema Schema) map[string]any {
135 if schema.Type != "object" || schema.Properties == nil {
136 return map[string]any{}
137 }
138
139 params := make(map[string]any)
140 for name, propSchema := range schema.Properties {
141 param := map[string]any{
142 "type": propSchema.Type,
143 }
144
145 if propSchema.Description != "" {
146 param["description"] = propSchema.Description
147 }
148
149 if len(propSchema.Enum) > 0 {
150 param["enum"] = propSchema.Enum
151 }
152
153 if propSchema.Format != "" {
154 param["format"] = propSchema.Format
155 }
156
157 if propSchema.Minimum != nil {
158 param["minimum"] = *propSchema.Minimum
159 }
160
161 if propSchema.Maximum != nil {
162 param["maximum"] = *propSchema.Maximum
163 }
164
165 if propSchema.MinLength != nil {
166 param["minLength"] = *propSchema.MinLength
167 }
168
169 if propSchema.MaxLength != nil {
170 param["maxLength"] = *propSchema.MaxLength
171 }
172
173 if propSchema.Items != nil {
174 param["items"] = schemaToParameters(*propSchema.Items)
175 }
176
177 params[name] = param
178 }
179
180 return params
181}
182
183// generateSchema automatically generates a JSON schema from a Go type.
184func generateSchema(t reflect.Type) Schema {
185 return generateSchemaRecursive(t, make(map[reflect.Type]bool))
186}
187
188func generateSchemaRecursive(t reflect.Type, visited map[reflect.Type]bool) Schema {
189 // Handle pointers
190 if t.Kind() == reflect.Pointer {
191 t = t.Elem()
192 }
193
194 // Prevent infinite recursion
195 if visited[t] {
196 return Schema{Type: "object"}
197 }
198 visited[t] = true
199 defer delete(visited, t)
200
201 switch t.Kind() {
202 case reflect.String:
203 return Schema{Type: "string"}
204 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
205 reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
206 return Schema{Type: "integer"}
207 case reflect.Float32, reflect.Float64:
208 return Schema{Type: "number"}
209 case reflect.Bool:
210 return Schema{Type: "boolean"}
211 case reflect.Slice, reflect.Array:
212 itemSchema := generateSchemaRecursive(t.Elem(), visited)
213 return Schema{
214 Type: "array",
215 Items: &itemSchema,
216 }
217 case reflect.Map:
218 if t.Key().Kind() == reflect.String {
219 valueSchema := generateSchemaRecursive(t.Elem(), visited)
220 return Schema{
221 Type: "object",
222 Properties: map[string]*Schema{
223 "*": &valueSchema,
224 },
225 }
226 }
227 return Schema{Type: "object"}
228 case reflect.Struct:
229 schema := Schema{
230 Type: "object",
231 Properties: make(map[string]*Schema),
232 }
233
234 for i := range t.NumField() {
235 field := t.Field(i)
236
237 // Skip unexported fields
238 if !field.IsExported() {
239 continue
240 }
241
242 jsonTag := field.Tag.Get("json")
243 if jsonTag == "-" {
244 continue
245 }
246
247 fieldName := field.Name
248 required := true
249
250 // Parse JSON tag
251 if jsonTag != "" {
252 parts := strings.Split(jsonTag, ",")
253 if parts[0] != "" {
254 fieldName = parts[0]
255 }
256
257 // Check for omitempty
258 if slices.Contains(parts[1:], "omitempty") {
259 required = false
260 }
261 } else {
262 // Convert field name to snake_case for JSON
263 fieldName = toSnakeCase(fieldName)
264 }
265
266 fieldSchema := generateSchemaRecursive(field.Type, visited)
267
268 // Add description from struct tag if available
269 if desc := field.Tag.Get("description"); desc != "" {
270 fieldSchema.Description = desc
271 }
272
273 // Add enum values from struct tag if available
274 if enumTag := field.Tag.Get("enum"); enumTag != "" {
275 enumValues := strings.Split(enumTag, ",")
276 fieldSchema.Enum = make([]any, len(enumValues))
277 for i, v := range enumValues {
278 fieldSchema.Enum[i] = strings.TrimSpace(v)
279 }
280 }
281
282 schema.Properties[fieldName] = &fieldSchema
283
284 if required {
285 schema.Required = append(schema.Required, fieldName)
286 }
287 }
288
289 return schema
290 case reflect.Interface:
291 return Schema{Type: "object"}
292 default:
293 return Schema{Type: "object"}
294 }
295}
296
// toSnakeCase converts PascalCase to snake_case. Every ASCII upper-case letter
// after the first rune is prefixed with an underscore, then the whole result
// is lower-cased (so "ID" becomes "i_d").
func toSnakeCase(s string) string {
	var out strings.Builder
	first := true
	for _, ch := range s {
		isUpperASCII := 'A' <= ch && ch <= 'Z'
		if !first && isUpperASCII {
			out.WriteByte('_')
		}
		out.WriteRune(ch)
		first = false
	}
	return strings.ToLower(out.String())
}