tool.go

  1package ai
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"reflect"
  8	"slices"
  9	"strings"
 10)
 11
 12// Schema represents a JSON schema for tool input validation.
 13type Schema struct {
 14	Type        string             `json:"type"`
 15	Properties  map[string]*Schema `json:"properties,omitempty"`
 16	Required    []string           `json:"required,omitempty"`
 17	Items       *Schema            `json:"items,omitempty"`
 18	Description string             `json:"description,omitempty"`
 19	Enum        []any              `json:"enum,omitempty"`
 20	Format      string             `json:"format,omitempty"`
 21	Minimum     *float64           `json:"minimum,omitempty"`
 22	Maximum     *float64           `json:"maximum,omitempty"`
 23	MinLength   *int               `json:"minLength,omitempty"`
 24	MaxLength   *int               `json:"maxLength,omitempty"`
 25}
 26
 27// ToolInfo represents tool metadata, matching the existing pattern.
 28type ToolInfo struct {
 29	Name        string
 30	Description string
 31	Parameters  map[string]any
 32	Required    []string
 33}
 34
 35// ToolCall represents a tool invocation, matching the existing pattern.
 36type ToolCall struct {
 37	ID    string `json:"id"`
 38	Name  string `json:"name"`
 39	Input string `json:"input"`
 40}
 41
 42// ToolResponse represents the response from a tool execution, matching the existing pattern.
 43type ToolResponse struct {
 44	Type     string `json:"type"`
 45	Content  string `json:"content"`
 46	Metadata string `json:"metadata,omitempty"`
 47	IsError  bool   `json:"is_error"`
 48}
 49
 50// NewTextResponse creates a text response.
 51func NewTextResponse(content string) ToolResponse {
 52	return ToolResponse{
 53		Type:    "text",
 54		Content: content,
 55	}
 56}
 57
 58// NewTextErrorResponse creates an error response.
 59func NewTextErrorResponse(content string) ToolResponse {
 60	return ToolResponse{
 61		Type:    "text",
 62		Content: content,
 63		IsError: true,
 64	}
 65}
 66
 67// WithResponseMetadata adds metadata to a response.
 68func WithResponseMetadata(response ToolResponse, metadata any) ToolResponse {
 69	if metadata != nil {
 70		metadataBytes, err := json.Marshal(metadata)
 71		if err != nil {
 72			return response
 73		}
 74		response.Metadata = string(metadataBytes)
 75	}
 76	return response
 77}
 78
 79// AgentTool represents a tool that can be called by a language model.
 80// This matches the existing BaseTool interface pattern.
 81type AgentTool interface {
 82	Info() ToolInfo
 83	Run(ctx context.Context, params ToolCall) (ToolResponse, error)
 84}
 85
 86// NewTypedToolFunc creates a typed tool from a function with automatic schema generation.
 87// This is the recommended way to create tools.
 88func NewTypedToolFunc[TInput any](
 89	name string,
 90	description string,
 91	fn func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error),
 92) AgentTool {
 93	var input TInput
 94	schema := generateSchema(reflect.TypeOf(input))
 95
 96	return &funcToolWrapper[TInput]{
 97		name:        name,
 98		description: description,
 99		fn:          fn,
100		schema:      schema,
101	}
102}
103
104// funcToolWrapper wraps a function to implement the AgentTool interface.
105type funcToolWrapper[TInput any] struct {
106	name        string
107	description string
108	fn          func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error)
109	schema      Schema
110}
111
112func (w *funcToolWrapper[TInput]) Info() ToolInfo {
113	return ToolInfo{
114		Name:        w.name,
115		Description: w.description,
116		Parameters:  schemaToParameters(w.schema),
117		Required:    w.schema.Required,
118	}
119}
120
121func (w *funcToolWrapper[TInput]) Run(ctx context.Context, params ToolCall) (ToolResponse, error) {
122	var input TInput
123	if err := json.Unmarshal([]byte(params.Input), &input); err != nil {
124		return NewTextErrorResponse(fmt.Sprintf("invalid parameters: %s", err)), nil
125	}
126
127	return w.fn(ctx, input, params)
128}
129
130// schemaToParameters converts a Schema to the parameters map format expected by ToolInfo.
131func schemaToParameters(schema Schema) map[string]any {
132	if schema.Type != "object" || schema.Properties == nil {
133		return map[string]any{}
134	}
135
136	params := make(map[string]any)
137	for name, propSchema := range schema.Properties {
138		param := map[string]any{
139			"type": propSchema.Type,
140		}
141
142		if propSchema.Description != "" {
143			param["description"] = propSchema.Description
144		}
145
146		if len(propSchema.Enum) > 0 {
147			param["enum"] = propSchema.Enum
148		}
149
150		if propSchema.Format != "" {
151			param["format"] = propSchema.Format
152		}
153
154		if propSchema.Minimum != nil {
155			param["minimum"] = *propSchema.Minimum
156		}
157
158		if propSchema.Maximum != nil {
159			param["maximum"] = *propSchema.Maximum
160		}
161
162		if propSchema.MinLength != nil {
163			param["minLength"] = *propSchema.MinLength
164		}
165
166		if propSchema.MaxLength != nil {
167			param["maxLength"] = *propSchema.MaxLength
168		}
169
170		if propSchema.Items != nil {
171			param["items"] = schemaToParameters(*propSchema.Items)
172		}
173
174		params[name] = param
175	}
176
177	return params
178}
179
180// generateSchema automatically generates a JSON schema from a Go type.
181func generateSchema(t reflect.Type) Schema {
182	return generateSchemaRecursive(t, make(map[reflect.Type]bool))
183}
184
185func generateSchemaRecursive(t reflect.Type, visited map[reflect.Type]bool) Schema {
186	// Handle pointers
187	if t.Kind() == reflect.Pointer {
188		t = t.Elem()
189	}
190
191	// Prevent infinite recursion
192	if visited[t] {
193		return Schema{Type: "object"}
194	}
195	visited[t] = true
196	defer delete(visited, t)
197
198	switch t.Kind() {
199	case reflect.String:
200		return Schema{Type: "string"}
201	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
202		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
203		return Schema{Type: "integer"}
204	case reflect.Float32, reflect.Float64:
205		return Schema{Type: "number"}
206	case reflect.Bool:
207		return Schema{Type: "boolean"}
208	case reflect.Slice, reflect.Array:
209		itemSchema := generateSchemaRecursive(t.Elem(), visited)
210		return Schema{
211			Type:  "array",
212			Items: &itemSchema,
213		}
214	case reflect.Map:
215		if t.Key().Kind() == reflect.String {
216			valueSchema := generateSchemaRecursive(t.Elem(), visited)
217			return Schema{
218				Type: "object",
219				Properties: map[string]*Schema{
220					"*": &valueSchema,
221				},
222			}
223		}
224		return Schema{Type: "object"}
225	case reflect.Struct:
226		schema := Schema{
227			Type:       "object",
228			Properties: make(map[string]*Schema),
229		}
230
231		for i := range t.NumField() {
232			field := t.Field(i)
233
234			// Skip unexported fields
235			if !field.IsExported() {
236				continue
237			}
238
239			jsonTag := field.Tag.Get("json")
240			if jsonTag == "-" {
241				continue
242			}
243
244			fieldName := field.Name
245			required := true
246
247			// Parse JSON tag
248			if jsonTag != "" {
249				parts := strings.Split(jsonTag, ",")
250				if parts[0] != "" {
251					fieldName = parts[0]
252				}
253
254				// Check for omitempty
255				if slices.Contains(parts[1:], "omitempty") {
256					required = false
257				}
258			} else {
259				// Convert field name to snake_case for JSON
260				fieldName = toSnakeCase(fieldName)
261			}
262
263			fieldSchema := generateSchemaRecursive(field.Type, visited)
264
265			// Add description from struct tag if available
266			if desc := field.Tag.Get("description"); desc != "" {
267				fieldSchema.Description = desc
268			}
269
270			// Add enum values from struct tag if available
271			if enumTag := field.Tag.Get("enum"); enumTag != "" {
272				enumValues := strings.Split(enumTag, ",")
273				fieldSchema.Enum = make([]any, len(enumValues))
274				for i, v := range enumValues {
275					fieldSchema.Enum[i] = strings.TrimSpace(v)
276				}
277			}
278
279			schema.Properties[fieldName] = &fieldSchema
280
281			if required {
282				schema.Required = append(schema.Required, fieldName)
283			}
284		}
285
286		return schema
287	case reflect.Interface:
288		return Schema{Type: "object"}
289	default:
290		return Schema{Type: "object"}
291	}
292}
293
// toSnakeCase lower-cases s and inserts an underscore before every ASCII
// upper-case letter except one at the very start, e.g. "FilePath" becomes
// "file_path". Each capital is handled on its own, so "ID" becomes "i_d".
// Non-ASCII capitals are lower-cased but get no underscore.
func toSnakeCase(s string) string {
	runes := make([]rune, 0, len(s))
	for pos, r := range s {
		isASCIIUpper := 'A' <= r && r <= 'Z'
		if isASCIIUpper && pos > 0 {
			runes = append(runes, '_')
		}
		runes = append(runes, r)
	}
	return strings.ToLower(string(runes))
}