1package ai
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "reflect"
8 "slices"
9 "strings"
10)
11
// Schema represents a JSON schema for tool input validation.
//
// Each field corresponds to the JSON Schema keyword named in its json tag;
// keywords tagged omitempty are left out of the marshaled schema when empty.
// Within this file, generateSchema fills in Type, Properties, Required,
// Items, Description and Enum. Format, Minimum, Maximum, MinLength and
// MaxLength are never set here; schemaToParameters only copies them through.
type Schema struct {
	// Type is the JSON type name; generateSchema emits "string", "integer",
	// "number", "boolean", "array" or "object".
	Type string `json:"type"`

	// Properties holds one sub-schema per object property. For Go maps with
	// string keys, generateSchema stores the value schema under the key "*".
	Properties map[string]*Schema `json:"properties,omitempty"`

	// Required lists the property names that must be present.
	Required []string `json:"required,omitempty"`

	// Items is the element schema when Type is "array".
	Items *Schema `json:"items,omitempty"`

	// Description is free text; generateSchema copies it from a field's
	// `description` struct tag.
	Description string `json:"description,omitempty"`

	// Enum lists the allowed values; generateSchema fills it with the
	// trimmed comma-separated strings of a field's `enum` struct tag.
	Enum []any `json:"enum,omitempty"`

	// Format is an optional JSON Schema format hint.
	Format string `json:"format,omitempty"`

	// The bounds below are pointers so that an explicit 0 is still emitted:
	// omitempty drops nil pointers, not pointers to zero.
	Minimum   *float64 `json:"minimum,omitempty"`
	Maximum   *float64 `json:"maximum,omitempty"`
	MinLength *int     `json:"minLength,omitempty"`
	MaxLength *int     `json:"maxLength,omitempty"`
}
26
// ToolInfo represents tool metadata, matching the existing pattern.
//
// For tools built by NewAgentTool, Parameters maps each top-level input
// property name to its schema fragment (as produced by schemaToParameters)
// and Required lists the property names that must be supplied.
type ToolInfo struct {
	// Name is the tool's name.
	Name string
	// Description is the tool's free-text description.
	Description string
	// Parameters maps property name to that property's schema fragment.
	Parameters map[string]any
	// Required lists the mandatory property names.
	Required []string
}
34
// ToolCall represents a tool invocation, matching the existing pattern.
type ToolCall struct {
	// ID identifies this invocation. NOTE(review): presumably the
	// provider-assigned call id — confirm against the provider adapters.
	ID string `json:"id"`
	// Name is the name of the tool being invoked.
	Name string `json:"name"`
	// Input is the raw JSON text of the call's arguments; funcToolWrapper.Run
	// unmarshals it into the tool's input type.
	Input string `json:"input"`
}
41
// ToolResponse represents the response from a tool execution, matching the
// existing pattern.
type ToolResponse struct {
	// Type is the content kind; the constructors in this file set "text".
	Type string `json:"type"`
	// Content is the response body.
	Content string `json:"content"`
	// Metadata optionally holds a JSON-encoded value (see
	// WithResponseMetadata); it is omitted from JSON output when empty.
	Metadata string `json:"metadata,omitempty"`
	// IsError marks the response as reporting a failure. Unlike Metadata it
	// is always serialized, even when false.
	IsError bool `json:"is_error"`
}
49
50// NewTextResponse creates a text response.
51func NewTextResponse(content string) ToolResponse {
52 return ToolResponse{
53 Type: "text",
54 Content: content,
55 }
56}
57
58// NewTextErrorResponse creates an error response.
59func NewTextErrorResponse(content string) ToolResponse {
60 return ToolResponse{
61 Type: "text",
62 Content: content,
63 IsError: true,
64 }
65}
66
// WithResponseMetadata returns a copy of response whose Metadata field holds
// the JSON encoding of metadata.
//
// It is best-effort: when metadata is a nil interface, or json.Marshal fails
// for it, response is returned unchanged and the marshal error is dropped.
// NOTE(review): a typed nil pointer is not a nil interface, so it passes the
// check and is stored as "null".
func WithResponseMetadata(response ToolResponse, metadata any) ToolResponse {
	if metadata == nil {
		return response
	}
	encoded, err := json.Marshal(metadata)
	if err == nil {
		response.Metadata = string(encoded)
	}
	return response
}
78
// AgentTool represents a tool that can be called by a language model.
// This matches the existing BaseTool interface pattern.
type AgentTool interface {
	// Info describes the tool: its name, description and input schema.
	Info() ToolInfo
	// Run executes one invocation described by params. The implementation
	// in this file (funcToolWrapper) reports undecodable input as an error
	// ToolResponse with a nil error, and otherwise returns the wrapped
	// function's results unchanged.
	Run(ctx context.Context, params ToolCall) (ToolResponse, error)
	// ProviderOptions returns the options stored on the tool (type
	// ProviderOptions is declared elsewhere in the package).
	ProviderOptions() ProviderOptions
	// SetProviderOptions replaces the options stored on the tool.
	SetProviderOptions(opts ProviderOptions)
}
87
88// NewAgentTool creates a typed tool from a function with automatic schema generation.
89// This is the recommended way to create tools.
90func NewAgentTool[TInput any](
91 name string,
92 description string,
93 fn func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error),
94) AgentTool {
95 var input TInput
96 schema := generateSchema(reflect.TypeOf(input))
97
98 return &funcToolWrapper[TInput]{
99 name: name,
100 description: description,
101 fn: fn,
102 schema: schema,
103 }
104}
105
// funcToolWrapper wraps a function to implement the AgentTool interface.
//
// Values are created by NewAgentTool, which generates schema once from
// TInput; providerOptions starts at its zero value.
type funcToolWrapper[TInput any] struct {
	name        string
	description string
	// fn is invoked by Run with the decoded input and the original call.
	fn func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error)
	// schema is the input schema derived from TInput; Info exposes it.
	schema Schema
	// providerOptions is stored by SetProviderOptions and returned as-is by
	// ProviderOptions.
	providerOptions ProviderOptions
}
114
115func (w *funcToolWrapper[TInput]) SetProviderOptions(opts ProviderOptions) {
116 w.providerOptions = opts
117}
118
119func (w *funcToolWrapper[TInput]) ProviderOptions() ProviderOptions {
120 return w.providerOptions
121}
122
// Info reports the tool's name, description, per-property input schema
// (see schemaToParameters) and the list of required property names.
//
// Required is never nil, so it marshals as [] rather than null. Info writes
// nothing to w. It normalizes a clone of the required list instead of
// replacing w.schema.Required in place, so concurrent Info calls do not
// race. Because the result is a copy, a caller that mutates the returned
// slice cannot alter the wrapper's schema.
func (w *funcToolWrapper[TInput]) Info() ToolInfo {
	required := slices.Clone(w.schema.Required)
	if required == nil {
		required = []string{}
	}
	return ToolInfo{
		Name:        w.name,
		Description: w.description,
		Parameters:  schemaToParameters(w.schema),
		Required:    required,
	}
}
134
// Run decodes params.Input as JSON into a fresh TInput and calls the wrapped
// function with it, ctx, and the original params.
//
// Input that fails to decode is reported as an error ToolResponse with a nil
// Go error. Whatever the wrapped function returns is passed through as-is.
func (w *funcToolWrapper[TInput]) Run(ctx context.Context, params ToolCall) (ToolResponse, error) {
	var decoded TInput
	err := json.Unmarshal([]byte(params.Input), &decoded)
	if err != nil {
		msg := fmt.Sprintf("invalid parameters: %s", err)
		return NewTextErrorResponse(msg), nil
	}
	return w.fn(ctx, decoded, params)
}
143
144// schemaToParameters converts a Schema to the parameters map format expected by ToolInfo.
145func schemaToParameters(schema Schema) map[string]any {
146 if schema.Type != "object" || schema.Properties == nil {
147 return map[string]any{}
148 }
149
150 params := make(map[string]any)
151 for name, propSchema := range schema.Properties {
152 param := map[string]any{
153 "type": propSchema.Type,
154 }
155
156 if propSchema.Description != "" {
157 param["description"] = propSchema.Description
158 }
159
160 if len(propSchema.Enum) > 0 {
161 param["enum"] = propSchema.Enum
162 }
163
164 if propSchema.Format != "" {
165 param["format"] = propSchema.Format
166 }
167
168 if propSchema.Minimum != nil {
169 param["minimum"] = *propSchema.Minimum
170 }
171
172 if propSchema.Maximum != nil {
173 param["maximum"] = *propSchema.Maximum
174 }
175
176 if propSchema.MinLength != nil {
177 param["minLength"] = *propSchema.MinLength
178 }
179
180 if propSchema.MaxLength != nil {
181 param["maxLength"] = *propSchema.MaxLength
182 }
183
184 if propSchema.Items != nil {
185 param["items"] = schemaToParameters(*propSchema.Items)
186 }
187
188 params[name] = param
189 }
190
191 return params
192}
193
// generateSchema derives a JSON schema from the Go type t by reflection.
// Every call starts with a fresh, empty set of in-progress types, so no
// recursion state leaks between calls.
func generateSchema(t reflect.Type) Schema {
	inProgress := map[reflect.Type]bool{}
	return generateSchemaRecursive(t, inProgress)
}
198
// generateSchemaRecursive builds the Schema for t.
//
// visited holds the types currently being expanded on the recursion stack.
// A type that refers back to itself is cut off with a bare {"type":"object"}
// instead of recursing forever. Entries are removed on the way back out, so
// a type that merely appears twice (e.g. two fields of the same struct type)
// is expanded both times. visited must be a non-nil map.
//
// Kind mapping:
//   - string → "string"; signed and unsigned integers → "integer";
//     float32/float64 → "number"; bool → "boolean"
//   - slices and arrays → "array" with Items describing the element type
//   - map[string]V → "object" with a single "*" property holding V's
//     schema; maps with other key kinds → bare "object"
//   - structs → see generateStructSchema
//   - a nil type, interfaces and every other kind → bare "object"
//
// Pointers of any depth are looked through, as encoding/json does.
func generateSchemaRecursive(t reflect.Type, visited map[reflect.Type]bool) Schema {
	// reflect.TypeOf on a nil interface value yields a nil Type. Calling any
	// method on it would panic, and there is no structure to describe.
	if t == nil {
		return Schema{Type: "object"}
	}

	for t.Kind() == reflect.Pointer {
		t = t.Elem()
	}

	// Prevent infinite recursion on self-referential types.
	if visited[t] {
		return Schema{Type: "object"}
	}
	visited[t] = true
	defer delete(visited, t)

	switch t.Kind() {
	case reflect.String:
		return Schema{Type: "string"}
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return Schema{Type: "integer"}
	case reflect.Float32, reflect.Float64:
		return Schema{Type: "number"}
	case reflect.Bool:
		return Schema{Type: "boolean"}
	case reflect.Slice, reflect.Array:
		itemSchema := generateSchemaRecursive(t.Elem(), visited)
		return Schema{
			Type:  "array",
			Items: &itemSchema,
		}
	case reflect.Map:
		if t.Key().Kind() != reflect.String {
			return Schema{Type: "object"}
		}
		valueSchema := generateSchemaRecursive(t.Elem(), visited)
		return Schema{
			Type: "object",
			Properties: map[string]*Schema{
				"*": &valueSchema,
			},
		}
	case reflect.Struct:
		return generateStructSchema(t, visited)
	default:
		// Interfaces and kinds with no direct JSON counterpart.
		return Schema{Type: "object"}
	}
}

// generateStructSchema describes struct type t as an object with one
// property per exported field that encoding/json would serialize. Fields
// tagged omitempty or omitzero are optional; all others are listed in
// Required, in field order. A field's `description` tag becomes its
// Description. A comma-separated `enum` tag becomes Enum, with each value
// trimmed of surrounding whitespace.
func generateStructSchema(t reflect.Type, visited map[reflect.Type]bool) Schema {
	schema := Schema{
		Type:       "object",
		Properties: make(map[string]*Schema, t.NumField()),
	}

	for i := range t.NumField() {
		field := t.Field(i)
		if !field.IsExported() {
			continue
		}

		name, optional, ok := jsonFieldName(field)
		if !ok {
			continue
		}

		fieldSchema := generateSchemaRecursive(field.Type, visited)
		if desc := field.Tag.Get("description"); desc != "" {
			fieldSchema.Description = desc
		}
		if enumTag := field.Tag.Get("enum"); enumTag != "" {
			values := strings.Split(enumTag, ",")
			fieldSchema.Enum = make([]any, len(values))
			for j, v := range values {
				fieldSchema.Enum[j] = strings.TrimSpace(v)
			}
		}

		schema.Properties[name] = &fieldSchema
		if !optional {
			schema.Required = append(schema.Required, name)
		}
	}

	return schema
}

// jsonFieldName returns the property name used for field, whether the field
// is optional (json option omitempty or omitzero), and ok=false when the
// field is excluded with `json:"-"`.
func jsonFieldName(field reflect.StructField) (name string, optional bool, ok bool) {
	tag := field.Tag.Get("json")
	if tag == "-" {
		return "", false, false
	}

	tagName, opts, _ := strings.Cut(tag, ",")
	switch {
	case tag == "":
		// Untagged fields are keyed by the Go field name, and Unmarshal
		// matches object keys case-insensitively. Lower-casing keeps the
		// lowercase property style while still decoding. A snake_case key
		// such as "file_path" does NOT case-fold to FilePath, so models
		// following the schema would have that argument silently dropped.
		name = strings.ToLower(field.Name)
	case tagName == "":
		// e.g. `json:",omitempty"`: encoding/json keeps the Go field name.
		name = field.Name
	default:
		name = tagName
	}

	optionList := strings.Split(opts, ",")
	optional = slices.Contains(optionList, "omitempty") || slices.Contains(optionList, "omitzero")
	return name, optional, true
}
307
// toSnakeCase converts a PascalCase or camelCase identifier to snake_case.
//
// An underscore is inserted before an ASCII upper-case letter that starts a
// new word. That is the case when the previous rune is not upper-case
// ("FilePath" → "file_path"), or when the letter ends a run of capitals
// that is followed by a lower-case letter ("HTTPServer" → "http_server").
// Runs of capitals therefore stay together ("UserID" → "user_id") rather
// than being split letter by letter. No underscore is added directly after
// an existing one. The result is lower-cased with strings.ToLower.
func toSnakeCase(s string) string {
	isUpper := func(r rune) bool { return r >= 'A' && r <= 'Z' }
	isLower := func(r rune) bool { return r >= 'a' && r <= 'z' }

	runes := []rune(s)
	var result strings.Builder
	for i, r := range runes {
		if i > 0 && isUpper(r) {
			prev := runes[i-1]
			endsAcronym := isUpper(prev) && i+1 < len(runes) && isLower(runes[i+1])
			if prev != '_' && (!isUpper(prev) || endsAcronym) {
				result.WriteByte('_')
			}
		}
		result.WriteRune(r)
	}
	return strings.ToLower(result.String())
}