1package ai
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "reflect"
8 "slices"
9 "strings"
10)
11
// Schema is a JSON Schema fragment used to describe a tool's input (or any value
// nested inside it). Each field corresponds to the JSON Schema keyword named in
// its tag; every keyword except "type" is left out of the encoding when empty.
type Schema struct {
	// Type is the JSON Schema type, e.g. "object", "array", "string".
	Type string `json:"type"`
	// Properties describes the members of an object, keyed by member name.
	Properties map[string]*Schema `json:"properties,omitempty"`
	// Required lists the member names that must be present.
	Required []string `json:"required,omitempty"`
	// Items describes every element of an array.
	Items *Schema `json:"items,omitempty"`
	// Description is free-form guidance shown to the model.
	Description string `json:"description,omitempty"`
	// Enum restricts the value to one of the listed constants.
	Enum []any `json:"enum,omitempty"`
	// Format is a JSON Schema format hint.
	Format string `json:"format,omitempty"`
	// Minimum and Maximum bound numeric values; nil means no bound.
	Minimum *float64 `json:"minimum,omitempty"`
	Maximum *float64 `json:"maximum,omitempty"`
	// MinLength and MaxLength bound string length; nil means no bound.
	MinLength *int `json:"minLength,omitempty"`
	MaxLength *int `json:"maxLength,omitempty"`
}
26
// ToolInfo is the metadata a tool publishes about itself: its name, a
// description, its parameter schemas and the names of mandatory parameters.
// It mirrors the existing tool-metadata shape.
type ToolInfo struct {
	Name, Description string
	// Parameters maps each parameter name to its JSON Schema, rendered as a map.
	Parameters map[string]any
	// Required lists the parameter names that must be supplied.
	Required []string
}
34
// ToolCall is one request to invoke a tool, following the existing pattern.
// Input carries the call's arguments as a JSON-encoded string.
type ToolCall struct {
	ID    string `json:"id"`    // call identifier
	Name  string `json:"name"`  // name of the tool to invoke
	Input string `json:"input"` // JSON-encoded arguments
}
41
// ToolResponse is the result a tool returns, following the existing pattern.
// Type labels the payload kind, Content holds the payload, Metadata optionally
// carries a JSON document (omitted from the encoding when empty), and IsError
// marks the content as an error report. "is_error" is always encoded.
type ToolResponse struct {
	Type     string `json:"type"`
	Content  string `json:"content"`
	Metadata string `json:"metadata,omitempty"`
	IsError  bool   `json:"is_error"`
}
49
50// NewTextResponse creates a text response.
51func NewTextResponse(content string) ToolResponse {
52 return ToolResponse{
53 Type: "text",
54 Content: content,
55 }
56}
57
58// NewTextErrorResponse creates an error response.
59func NewTextErrorResponse(content string) ToolResponse {
60 return ToolResponse{
61 Type: "text",
62 Content: content,
63 IsError: true,
64 }
65}
66
67// WithResponseMetadata adds metadata to a response.
68func WithResponseMetadata(response ToolResponse, metadata any) ToolResponse {
69 if metadata != nil {
70 metadataBytes, err := json.Marshal(metadata)
71 if err != nil {
72 return response
73 }
74 response.Metadata = string(metadataBytes)
75 }
76 return response
77}
78
79// AgentTool represents a tool that can be called by a language model.
80// This matches the existing BaseTool interface pattern.
81type AgentTool interface {
82 Info() ToolInfo
83 Run(ctx context.Context, params ToolCall) (ToolResponse, error)
84}
85
86// NewTypedToolFunc creates a typed tool from a function with automatic schema generation.
87// This is the recommended way to create tools.
88func NewTypedToolFunc[TInput any](
89 name string,
90 description string,
91 fn func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error),
92) AgentTool {
93 var input TInput
94 schema := generateSchema(reflect.TypeOf(input))
95
96 return &funcToolWrapper[TInput]{
97 name: name,
98 description: description,
99 fn: fn,
100 schema: schema,
101 }
102}
103
104// funcToolWrapper wraps a function to implement the AgentTool interface.
105type funcToolWrapper[TInput any] struct {
106 name string
107 description string
108 fn func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error)
109 schema Schema
110}
111
112func (w *funcToolWrapper[TInput]) Info() ToolInfo {
113 return ToolInfo{
114 Name: w.name,
115 Description: w.description,
116 Parameters: schemaToParameters(w.schema),
117 Required: w.schema.Required,
118 }
119}
120
121func (w *funcToolWrapper[TInput]) Run(ctx context.Context, params ToolCall) (ToolResponse, error) {
122 var input TInput
123 if err := json.Unmarshal([]byte(params.Input), &input); err != nil {
124 return NewTextErrorResponse(fmt.Sprintf("invalid parameters: %s", err)), nil
125 }
126
127 return w.fn(ctx, input, params)
128}
129
130// schemaToParameters converts a Schema to the parameters map format expected by ToolInfo.
131func schemaToParameters(schema Schema) map[string]any {
132 if schema.Type != "object" || schema.Properties == nil {
133 return map[string]any{}
134 }
135
136 params := make(map[string]any)
137 for name, propSchema := range schema.Properties {
138 param := map[string]any{
139 "type": propSchema.Type,
140 }
141
142 if propSchema.Description != "" {
143 param["description"] = propSchema.Description
144 }
145
146 if len(propSchema.Enum) > 0 {
147 param["enum"] = propSchema.Enum
148 }
149
150 if propSchema.Format != "" {
151 param["format"] = propSchema.Format
152 }
153
154 if propSchema.Minimum != nil {
155 param["minimum"] = *propSchema.Minimum
156 }
157
158 if propSchema.Maximum != nil {
159 param["maximum"] = *propSchema.Maximum
160 }
161
162 if propSchema.MinLength != nil {
163 param["minLength"] = *propSchema.MinLength
164 }
165
166 if propSchema.MaxLength != nil {
167 param["maxLength"] = *propSchema.MaxLength
168 }
169
170 if propSchema.Items != nil {
171 param["items"] = schemaToParameters(*propSchema.Items)
172 }
173
174 params[name] = param
175 }
176
177 return params
178}
179
180// generateSchema automatically generates a JSON schema from a Go type.
181func generateSchema(t reflect.Type) Schema {
182 return generateSchemaRecursive(t, make(map[reflect.Type]bool))
183}
184
185func generateSchemaRecursive(t reflect.Type, visited map[reflect.Type]bool) Schema {
186 // Handle pointers
187 if t.Kind() == reflect.Pointer {
188 t = t.Elem()
189 }
190
191 // Prevent infinite recursion
192 if visited[t] {
193 return Schema{Type: "object"}
194 }
195 visited[t] = true
196 defer delete(visited, t)
197
198 switch t.Kind() {
199 case reflect.String:
200 return Schema{Type: "string"}
201 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
202 reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
203 return Schema{Type: "integer"}
204 case reflect.Float32, reflect.Float64:
205 return Schema{Type: "number"}
206 case reflect.Bool:
207 return Schema{Type: "boolean"}
208 case reflect.Slice, reflect.Array:
209 itemSchema := generateSchemaRecursive(t.Elem(), visited)
210 return Schema{
211 Type: "array",
212 Items: &itemSchema,
213 }
214 case reflect.Map:
215 if t.Key().Kind() == reflect.String {
216 valueSchema := generateSchemaRecursive(t.Elem(), visited)
217 return Schema{
218 Type: "object",
219 Properties: map[string]*Schema{
220 "*": &valueSchema,
221 },
222 }
223 }
224 return Schema{Type: "object"}
225 case reflect.Struct:
226 schema := Schema{
227 Type: "object",
228 Properties: make(map[string]*Schema),
229 }
230
231 for i := range t.NumField() {
232 field := t.Field(i)
233
234 // Skip unexported fields
235 if !field.IsExported() {
236 continue
237 }
238
239 jsonTag := field.Tag.Get("json")
240 if jsonTag == "-" {
241 continue
242 }
243
244 fieldName := field.Name
245 required := true
246
247 // Parse JSON tag
248 if jsonTag != "" {
249 parts := strings.Split(jsonTag, ",")
250 if parts[0] != "" {
251 fieldName = parts[0]
252 }
253
254 // Check for omitempty
255 if slices.Contains(parts[1:], "omitempty") {
256 required = false
257 }
258 } else {
259 // Convert field name to snake_case for JSON
260 fieldName = toSnakeCase(fieldName)
261 }
262
263 fieldSchema := generateSchemaRecursive(field.Type, visited)
264
265 // Add description from struct tag if available
266 if desc := field.Tag.Get("description"); desc != "" {
267 fieldSchema.Description = desc
268 }
269
270 // Add enum values from struct tag if available
271 if enumTag := field.Tag.Get("enum"); enumTag != "" {
272 enumValues := strings.Split(enumTag, ",")
273 fieldSchema.Enum = make([]any, len(enumValues))
274 for i, v := range enumValues {
275 fieldSchema.Enum[i] = strings.TrimSpace(v)
276 }
277 }
278
279 schema.Properties[fieldName] = &fieldSchema
280
281 if required {
282 schema.Required = append(schema.Required, fieldName)
283 }
284 }
285
286 return schema
287 case reflect.Interface:
288 return Schema{Type: "object"}
289 default:
290 return Schema{Type: "object"}
291 }
292}
293
// toSnakeCase converts a PascalCase or camelCase identifier to snake_case.
//
// A run of capitals is kept together as one word (an acronym). A new word
// starts at a lower/digit→upper transition, or at the last capital of a run
// that is followed by a lowercase letter. So "UserID" becomes "user_id" and
// "HTTPServer" becomes "http_server". No separator is inserted right after an
// existing underscore. Only ASCII capitals are treated as word boundaries; the
// whole result is lowercased.
func toSnakeCase(s string) string {
	isUpper := func(r rune) bool { return 'A' <= r && r <= 'Z' }
	isLower := func(r rune) bool { return 'a' <= r && r <= 'z' }

	runes := []rune(s)
	var b strings.Builder
	b.Grow(len(s) + 4)
	for i, r := range runes {
		if i > 0 && isUpper(r) {
			prev := runes[i-1]
			endsWord := !isUpper(prev) && prev != '_'
			endsAcronym := isUpper(prev) && i+1 < len(runes) && isLower(runes[i+1])
			if endsWord || endsAcronym {
				b.WriteByte('_')
			}
		}
		b.WriteRune(r)
	}
	return strings.ToLower(b.String())
}