tool.go

  1package fantasy
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"reflect"
  8	"slices"
  9	"strings"
 10)
 11
 12// Schema represents a JSON schema for tool input validation.
 13type Schema struct {
 14	Type        string             `json:"type,omitempty"`
 15	Properties  map[string]*Schema `json:"properties,omitempty"`
 16	Required    []string           `json:"required,omitempty"`
 17	Items       *Schema            `json:"items,omitempty"`
 18	Description string             `json:"description,omitempty"`
 19	Enum        []any              `json:"enum,omitempty"`
 20	Format      string             `json:"format,omitempty"`
 21	Minimum     *float64           `json:"minimum,omitempty"`
 22	Maximum     *float64           `json:"maximum,omitempty"`
 23	MinLength   *int               `json:"minLength,omitempty"`
 24	MaxLength   *int               `json:"maxLength,omitempty"`
 25}
 26
 27// ToolInfo represents tool metadata, matching the existing pattern.
 28type ToolInfo struct {
 29	Name        string
 30	Description string
 31	Parameters  map[string]any
 32	Required    []string
 33}
 34
 35// ToolCall represents a tool invocation, matching the existing pattern.
 36type ToolCall struct {
 37	ID    string `json:"id"`
 38	Name  string `json:"name"`
 39	Input string `json:"input"`
 40}
 41
 42// ToolResponse represents the response from a tool execution, matching the existing pattern.
 43type ToolResponse struct {
 44	Type     string `json:"type"`
 45	Content  string `json:"content"`
 46	Metadata string `json:"metadata,omitempty"`
 47	IsError  bool   `json:"is_error"`
 48}
 49
 50// NewTextResponse creates a text response.
 51func NewTextResponse(content string) ToolResponse {
 52	return ToolResponse{
 53		Type:    "text",
 54		Content: content,
 55	}
 56}
 57
 58// NewTextErrorResponse creates an error response.
 59func NewTextErrorResponse(content string) ToolResponse {
 60	return ToolResponse{
 61		Type:    "text",
 62		Content: content,
 63		IsError: true,
 64	}
 65}
 66
 67// WithResponseMetadata adds metadata to a response.
 68func WithResponseMetadata(response ToolResponse, metadata any) ToolResponse {
 69	if metadata != nil {
 70		metadataBytes, err := json.Marshal(metadata)
 71		if err != nil {
 72			return response
 73		}
 74		response.Metadata = string(metadataBytes)
 75	}
 76	return response
 77}
 78
 79// AgentTool represents a tool that can be called by a language model.
 80// This matches the existing BaseTool interface pattern.
 81type AgentTool interface {
 82	Info() ToolInfo
 83	Run(ctx context.Context, params ToolCall) (ToolResponse, error)
 84	ProviderOptions() ProviderOptions
 85	SetProviderOptions(opts ProviderOptions)
 86}
 87
 88// NewAgentTool creates a typed tool from a function with automatic schema generation.
 89// This is the recommended way to create tools.
 90func NewAgentTool[TInput any](
 91	name string,
 92	description string,
 93	fn func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error),
 94) AgentTool {
 95	var input TInput
 96	schema := generateSchema(reflect.TypeOf(input))
 97
 98	return &funcToolWrapper[TInput]{
 99		name:        name,
100		description: description,
101		fn:          fn,
102		schema:      schema,
103	}
104}
105
106// funcToolWrapper wraps a function to implement the AgentTool interface.
107type funcToolWrapper[TInput any] struct {
108	name            string
109	description     string
110	fn              func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error)
111	schema          Schema
112	providerOptions ProviderOptions
113}
114
115func (w *funcToolWrapper[TInput]) SetProviderOptions(opts ProviderOptions) {
116	w.providerOptions = opts
117}
118
119func (w *funcToolWrapper[TInput]) ProviderOptions() ProviderOptions {
120	return w.providerOptions
121}
122
123func (w *funcToolWrapper[TInput]) Info() ToolInfo {
124	if w.schema.Required == nil {
125		w.schema.Required = []string{}
126	}
127	return ToolInfo{
128		Name:        w.name,
129		Description: w.description,
130		Parameters:  schemaToParameters(w.schema),
131		Required:    w.schema.Required,
132	}
133}
134
135func (w *funcToolWrapper[TInput]) Run(ctx context.Context, params ToolCall) (ToolResponse, error) {
136	var input TInput
137	if err := json.Unmarshal([]byte(params.Input), &input); err != nil {
138		return NewTextErrorResponse(fmt.Sprintf("invalid parameters: %s", err)), nil
139	}
140
141	return w.fn(ctx, input, params)
142}
143
144// schemaToParameters converts a Schema to the parameters map format expected by ToolInfo.
145func schemaToParameters(schema Schema) map[string]any {
146	if schema.Type != "object" || schema.Properties == nil {
147		return map[string]any{}
148	}
149
150	params := make(map[string]any)
151	for name, propSchema := range schema.Properties {
152		param := map[string]any{
153			"type": propSchema.Type,
154		}
155
156		if propSchema.Description != "" {
157			param["description"] = propSchema.Description
158		}
159
160		if len(propSchema.Enum) > 0 {
161			param["enum"] = propSchema.Enum
162		}
163
164		if propSchema.Format != "" {
165			param["format"] = propSchema.Format
166		}
167
168		if propSchema.Minimum != nil {
169			param["minimum"] = *propSchema.Minimum
170		}
171
172		if propSchema.Maximum != nil {
173			param["maximum"] = *propSchema.Maximum
174		}
175
176		if propSchema.MinLength != nil {
177			param["minLength"] = *propSchema.MinLength
178		}
179
180		if propSchema.MaxLength != nil {
181			param["maxLength"] = *propSchema.MaxLength
182		}
183
184		if propSchema.Items != nil {
185			param["items"] = schemaToParameters(*propSchema.Items)
186		}
187
188		params[name] = param
189	}
190
191	return params
192}
193
194// generateSchema automatically generates a JSON schema from a Go type.
195func generateSchema(t reflect.Type) Schema {
196	return generateSchemaRecursive(t, nil, make(map[reflect.Type]bool))
197}
198
199func generateSchemaRecursive(t, parent reflect.Type, visited map[reflect.Type]bool) Schema {
200	// Handle pointers
201	if t.Kind() == reflect.Pointer {
202		t = t.Elem()
203	}
204
205	// Prevent infinite recursion
206	if visited[t] {
207		return Schema{Type: "object"}
208	}
209	visited[t] = true
210	defer delete(visited, t)
211
212	switch t.Kind() {
213	case reflect.String:
214		return Schema{Type: "string"}
215	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
216		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
217		return Schema{Type: "integer"}
218	case reflect.Float32, reflect.Float64:
219		return Schema{Type: "number"}
220	case reflect.Bool:
221		return Schema{Type: "boolean"}
222	case reflect.Slice, reflect.Array:
223		itemSchema := generateSchemaRecursive(t.Elem(), t, visited)
224		return Schema{
225			Type:  "array",
226			Items: &itemSchema,
227		}
228	case reflect.Map:
229		if t.Key().Kind() == reflect.String {
230			valueSchema := generateSchemaRecursive(t.Elem(), t, visited)
231			schema := Schema{
232				Type: "object",
233				Properties: map[string]*Schema{
234					"*": &valueSchema,
235				},
236			}
237			if useBlankType(parent) {
238				schema.Type = ""
239			}
240			return schema
241		}
242		return Schema{Type: "object"}
243	case reflect.Struct:
244		schema := Schema{
245			Type:       "object",
246			Properties: make(map[string]*Schema),
247		}
248		if useBlankType(parent) {
249			schema.Type = ""
250		}
251
252		for i := range t.NumField() {
253			field := t.Field(i)
254
255			// Skip unexported fields
256			if !field.IsExported() {
257				continue
258			}
259
260			jsonTag := field.Tag.Get("json")
261			if jsonTag == "-" {
262				continue
263			}
264
265			fieldName := field.Name
266			required := true
267
268			// Parse JSON tag
269			if jsonTag != "" {
270				parts := strings.Split(jsonTag, ",")
271				if parts[0] != "" {
272					fieldName = parts[0]
273				}
274
275				// Check for omitempty
276				if slices.Contains(parts[1:], "omitempty") {
277					required = false
278				}
279			} else {
280				// Convert field name to snake_case for JSON
281				fieldName = toSnakeCase(fieldName)
282			}
283
284			fieldSchema := generateSchemaRecursive(field.Type, t, visited)
285
286			// Add description from struct tag if available
287			if desc := field.Tag.Get("description"); desc != "" {
288				fieldSchema.Description = desc
289			}
290
291			// Add enum values from struct tag if available
292			if enumTag := field.Tag.Get("enum"); enumTag != "" {
293				enumValues := strings.Split(enumTag, ",")
294				fieldSchema.Enum = make([]any, len(enumValues))
295				for i, v := range enumValues {
296					fieldSchema.Enum[i] = strings.TrimSpace(v)
297				}
298			}
299
300			schema.Properties[fieldName] = &fieldSchema
301
302			if required {
303				schema.Required = append(schema.Required, fieldName)
304			}
305		}
306
307		return schema
308	case reflect.Interface:
309		return Schema{Type: "object"}
310	default:
311		return Schema{Type: "object"}
312	}
313}
314
// toSnakeCase converts a Go identifier in PascalCase or camelCase to
// snake_case.
//
// An underscore is inserted before an ASCII upper-case letter that starts a
// new word. That is one that follows a non-upper-case character other than
// '_', or one that ends a run of capitals and is followed by a lower-case
// letter. Runs of capitals therefore stay together, so initialisms convert
// naturally:
//
//	FilePath   -> file_path
//	UserID     -> user_id
//	HTTPServer -> http_server
//
// The result is lower-cased with strings.ToLower.
func toSnakeCase(s string) string {
	isUpper := func(r rune) bool { return r >= 'A' && r <= 'Z' }
	isLower := func(r rune) bool { return r >= 'a' && r <= 'z' }

	runes := []rune(s)
	var b strings.Builder
	b.Grow(len(s) + len(s)/2)
	for i, r := range runes {
		if i > 0 && isUpper(r) {
			prev := runes[i-1]
			nextLower := i+1 < len(runes) && isLower(runes[i+1])
			// Never double an existing separator.
			if prev != '_' && (!isUpper(prev) || nextLower) {
				b.WriteByte('_')
			}
		}
		b.WriteRune(r)
	}
	return strings.ToLower(b.String())
}
326
// NOTE(@andreynering): This is a hacky workaround for llama.cpp.
// Ideally, we should always output `type: object` for objects, but
// llama.cpp complains if we do for arrays of objects.
//
// useBlankType reports whether an object schema whose directly enclosing
// type is parent should be emitted with an empty Type. That holds only when
// parent is a slice or an array. The root type (nil parent) and every other
// container kind keep "object".
func useBlankType(parent reflect.Type) bool {
	return parent != nil && (parent.Kind() == reflect.Slice || parent.Kind() == reflect.Array)
}