tool.go

  1package ai
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"reflect"
  8	"slices"
  9	"strings"
 10)
 11
// Schema represents a JSON schema for tool input validation.
//
// It models the subset of JSON Schema keywords used by this package; each
// field's JSON tag is the corresponding JSON Schema keyword. generateSchema
// builds a Schema from a Go type via reflection, and schemaToParameters turns
// an object Schema into the ToolInfo.Parameters map.
//
// When produced by generateSchema, Type is one of "object", "array",
// "string", "integer", "number" or "boolean"; Properties is set for structs
// and for string-keyed maps (whose value schema is stored under the single
// key "*"); Items is set for slices and arrays; Required lists struct fields
// whose json tag lacks omitempty. Description and Enum come from the
// `description` and `enum` struct tags. Format, Minimum, Maximum, MinLength
// and MaxLength are never set by generateSchema.
type Schema struct {
	Type        string             `json:"type"`
	Properties  map[string]*Schema `json:"properties,omitempty"`
	Required    []string           `json:"required,omitempty"`
	Items       *Schema            `json:"items,omitempty"`
	Description string             `json:"description,omitempty"`
	Enum        []any              `json:"enum,omitempty"`
	Format      string             `json:"format,omitempty"`
	Minimum     *float64           `json:"minimum,omitempty"`
	Maximum     *float64           `json:"maximum,omitempty"`
	MinLength   *int               `json:"minLength,omitempty"`
	MaxLength   *int               `json:"maxLength,omitempty"`
}
 26
// ToolInfo represents tool metadata, matching the existing pattern.
//
// Name and Description identify the tool. For tools created with
// NewAgentTool, Parameters is the map built by schemaToParameters: one entry
// per top-level input property, keyed by property name, whose value is that
// property's schema as a map. Required holds the names of the required
// top-level properties; funcToolWrapper.Info reports it as an empty, non-nil
// slice when nothing is required.
type ToolInfo struct {
	Name        string
	Description string
	Parameters  map[string]any
	Required    []string
}
 34
// ToolCall represents a tool invocation, matching the existing pattern.
//
// Input carries the call's arguments as a JSON-encoded string; for tools
// created with NewAgentTool it is decoded into the tool's TInput by
// funcToolWrapper.Run. Name presumably matches ToolInfo.Name of the target
// tool, and ID presumably identifies the call so a response can be correlated
// with it — NOTE(review): confirm both against the provider code.
type ToolCall struct {
	ID    string `json:"id"`
	Name  string `json:"name"`
	Input string `json:"input"`
}
 41
// ToolResponse represents the response from a tool execution, matching the existing pattern.
//
// The constructors in this file always set Type to "text" and put the payload
// in Content. Metadata, when non-empty, is a JSON-encoded string (see
// WithResponseMetadata). IsError is true for responses built with
// NewTextErrorResponse, including the "invalid parameters" response that
// funcToolWrapper.Run returns for undecodable input.
type ToolResponse struct {
	Type     string `json:"type"`
	Content  string `json:"content"`
	Metadata string `json:"metadata,omitempty"`
	IsError  bool   `json:"is_error"`
}
 49
 50// NewTextResponse creates a text response.
 51func NewTextResponse(content string) ToolResponse {
 52	return ToolResponse{
 53		Type:    "text",
 54		Content: content,
 55	}
 56}
 57
 58// NewTextErrorResponse creates an error response.
 59func NewTextErrorResponse(content string) ToolResponse {
 60	return ToolResponse{
 61		Type:    "text",
 62		Content: content,
 63		IsError: true,
 64	}
 65}
 66
 67// WithResponseMetadata adds metadata to a response.
 68func WithResponseMetadata(response ToolResponse, metadata any) ToolResponse {
 69	if metadata != nil {
 70		metadataBytes, err := json.Marshal(metadata)
 71		if err != nil {
 72			return response
 73		}
 74		response.Metadata = string(metadataBytes)
 75	}
 76	return response
 77}
 78
// AgentTool represents a tool that can be called by a language model.
// This matches the existing BaseTool interface pattern.
//
// Info returns the tool's name, description and input parameter schema.
// Run executes one ToolCall. The NewAgentTool implementation reports
// undecodable input as an error ToolResponse together with a nil error, so a
// non-nil error presumably signals a failure the caller must handle itself —
// NOTE(review): confirm this contract with callers.
// ProviderOptions and SetProviderOptions read and replace the
// provider-specific options attached to the tool.
type AgentTool interface {
	Info() ToolInfo
	Run(ctx context.Context, params ToolCall) (ToolResponse, error)
	ProviderOptions() ProviderOptions
	SetProviderOptions(opts ProviderOptions)
}
 87
 88// NewAgentTool creates a typed tool from a function with automatic schema generation.
 89// This is the recommended way to create tools.
 90func NewAgentTool[TInput any](
 91	name string,
 92	description string,
 93	fn func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error),
 94) AgentTool {
 95	var input TInput
 96	schema := generateSchema(reflect.TypeOf(input))
 97
 98	return &funcToolWrapper[TInput]{
 99		name:        name,
100		description: description,
101		fn:          fn,
102		schema:      schema,
103	}
104}
105
// funcToolWrapper wraps a function to implement the AgentTool interface.
//
// NewAgentTool fills in name, description, fn and the schema generated from
// TInput. providerOptions starts at its zero value and is replaced by
// SetProviderOptions. Run decodes each call's JSON input into a TInput before
// invoking fn.
type funcToolWrapper[TInput any] struct {
	name            string
	description     string
	fn              func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error)
	schema          Schema
	providerOptions ProviderOptions
}
114
115func (w *funcToolWrapper[TInput]) SetProviderOptions(opts ProviderOptions) {
116	w.providerOptions = opts
117}
118
119func (w *funcToolWrapper[TInput]) ProviderOptions() ProviderOptions {
120	return w.providerOptions
121}
122
123func (w *funcToolWrapper[TInput]) Info() ToolInfo {
124	if w.schema.Required == nil {
125		w.schema.Required = []string{}
126	}
127	return ToolInfo{
128		Name:        w.name,
129		Description: w.description,
130		Parameters:  schemaToParameters(w.schema),
131		Required:    w.schema.Required,
132	}
133}
134
135func (w *funcToolWrapper[TInput]) Run(ctx context.Context, params ToolCall) (ToolResponse, error) {
136	var input TInput
137	if err := json.Unmarshal([]byte(params.Input), &input); err != nil {
138		return NewTextErrorResponse(fmt.Sprintf("invalid parameters: %s", err)), nil
139	}
140
141	return w.fn(ctx, input, params)
142}
143
144// schemaToParameters converts a Schema to the parameters map format expected by ToolInfo.
145func schemaToParameters(schema Schema) map[string]any {
146	if schema.Type != "object" || schema.Properties == nil {
147		return map[string]any{}
148	}
149
150	params := make(map[string]any)
151	for name, propSchema := range schema.Properties {
152		param := map[string]any{
153			"type": propSchema.Type,
154		}
155
156		if propSchema.Description != "" {
157			param["description"] = propSchema.Description
158		}
159
160		if len(propSchema.Enum) > 0 {
161			param["enum"] = propSchema.Enum
162		}
163
164		if propSchema.Format != "" {
165			param["format"] = propSchema.Format
166		}
167
168		if propSchema.Minimum != nil {
169			param["minimum"] = *propSchema.Minimum
170		}
171
172		if propSchema.Maximum != nil {
173			param["maximum"] = *propSchema.Maximum
174		}
175
176		if propSchema.MinLength != nil {
177			param["minLength"] = *propSchema.MinLength
178		}
179
180		if propSchema.MaxLength != nil {
181			param["maxLength"] = *propSchema.MaxLength
182		}
183
184		if propSchema.Items != nil {
185			param["items"] = schemaToParameters(*propSchema.Items)
186		}
187
188		params[name] = param
189	}
190
191	return params
192}
193
194// generateSchema automatically generates a JSON schema from a Go type.
195func generateSchema(t reflect.Type) Schema {
196	return generateSchemaRecursive(t, make(map[reflect.Type]bool))
197}
198
199func generateSchemaRecursive(t reflect.Type, visited map[reflect.Type]bool) Schema {
200	// Handle pointers
201	if t.Kind() == reflect.Pointer {
202		t = t.Elem()
203	}
204
205	// Prevent infinite recursion
206	if visited[t] {
207		return Schema{Type: "object"}
208	}
209	visited[t] = true
210	defer delete(visited, t)
211
212	switch t.Kind() {
213	case reflect.String:
214		return Schema{Type: "string"}
215	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
216		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
217		return Schema{Type: "integer"}
218	case reflect.Float32, reflect.Float64:
219		return Schema{Type: "number"}
220	case reflect.Bool:
221		return Schema{Type: "boolean"}
222	case reflect.Slice, reflect.Array:
223		itemSchema := generateSchemaRecursive(t.Elem(), visited)
224		return Schema{
225			Type:  "array",
226			Items: &itemSchema,
227		}
228	case reflect.Map:
229		if t.Key().Kind() == reflect.String {
230			valueSchema := generateSchemaRecursive(t.Elem(), visited)
231			return Schema{
232				Type: "object",
233				Properties: map[string]*Schema{
234					"*": &valueSchema,
235				},
236			}
237		}
238		return Schema{Type: "object"}
239	case reflect.Struct:
240		schema := Schema{
241			Type:       "object",
242			Properties: make(map[string]*Schema),
243		}
244
245		for i := range t.NumField() {
246			field := t.Field(i)
247
248			// Skip unexported fields
249			if !field.IsExported() {
250				continue
251			}
252
253			jsonTag := field.Tag.Get("json")
254			if jsonTag == "-" {
255				continue
256			}
257
258			fieldName := field.Name
259			required := true
260
261			// Parse JSON tag
262			if jsonTag != "" {
263				parts := strings.Split(jsonTag, ",")
264				if parts[0] != "" {
265					fieldName = parts[0]
266				}
267
268				// Check for omitempty
269				if slices.Contains(parts[1:], "omitempty") {
270					required = false
271				}
272			} else {
273				// Convert field name to snake_case for JSON
274				fieldName = toSnakeCase(fieldName)
275			}
276
277			fieldSchema := generateSchemaRecursive(field.Type, visited)
278
279			// Add description from struct tag if available
280			if desc := field.Tag.Get("description"); desc != "" {
281				fieldSchema.Description = desc
282			}
283
284			// Add enum values from struct tag if available
285			if enumTag := field.Tag.Get("enum"); enumTag != "" {
286				enumValues := strings.Split(enumTag, ",")
287				fieldSchema.Enum = make([]any, len(enumValues))
288				for i, v := range enumValues {
289					fieldSchema.Enum[i] = strings.TrimSpace(v)
290				}
291			}
292
293			schema.Properties[fieldName] = &fieldSchema
294
295			if required {
296				schema.Required = append(schema.Required, fieldName)
297			}
298		}
299
300		return schema
301	case reflect.Interface:
302		return Schema{Type: "object"}
303	default:
304		return Schema{Type: "object"}
305	}
306}
307
// toSnakeCase converts a PascalCase or camelCase Go identifier to snake_case,
// keeping initialisms together: "UserID" -> "user_id",
// "HTTPServer" -> "http_server", "APIKey" -> "api_key".
//
// An underscore is inserted before an upper-case ASCII letter (other than the
// first character) unless it continues a run of capitals or follows an
// existing '_'. The last capital of a run starts a new word when a lower-case
// letter follows it. The result is lower-cased with strings.ToLower.
func toSnakeCase(s string) string {
	isUpper := func(r rune) bool { return r >= 'A' && r <= 'Z' }
	isLower := func(r rune) bool { return r >= 'a' && r <= 'z' }

	runes := []rune(s)
	var b strings.Builder
	b.Grow(len(s) + len(s)/2)
	for i, r := range runes {
		if i > 0 && isUpper(r) {
			prev := runes[i-1]
			nextIsLower := i+1 < len(runes) && isLower(runes[i+1])
			if prev != '_' && (!isUpper(prev) || nextIsLower) {
				b.WriteByte('_')
			}
		}
		b.WriteRune(r)
	}
	return strings.ToLower(b.String())
}