tool.go

  1// WIP NEED TO REVISIT
  2package ai
  3
  4import (
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8)
  9
 10// AgentTool represents a function that can be called by a language model.
 11type AgentTool interface {
 12	Name() string
 13	Description() string
 14	InputSchema() Schema
 15	Execute(ctx context.Context, input json.RawMessage) (json.RawMessage, error)
 16}
 17
 18// Schema represents a JSON schema for tool input validation.
 19type Schema struct {
 20	Type        string             `json:"type"`
 21	Properties  map[string]*Schema `json:"properties,omitempty"`
 22	Required    []string           `json:"required,omitempty"`
 23	Items       *Schema            `json:"items,omitempty"`
 24	Description string             `json:"description,omitempty"`
 25	Enum        []any              `json:"enum,omitempty"`
 26	Format      string             `json:"format,omitempty"`
 27	Minimum     *float64           `json:"minimum,omitempty"`
 28	Maximum     *float64           `json:"maximum,omitempty"`
 29	MinLength   *int               `json:"minLength,omitempty"`
 30	MaxLength   *int               `json:"maxLength,omitempty"`
 31}
 32
 33// BasicTool provides a basic implementation of the Tool interface
 34//
 35// Example usage:
 36//
 37//	calculator := &tools.BasicTool{
 38//	    ToolName:        "calculate",
 39//	    ToolDescription: "Evaluates mathematical expressions",
 40//	    ToolInputSchema: tools.Schema{
 41//	        Type: "object",
 42//	        Properties: map[string]*tools.Schema{
 43//	            "expression": {
 44//	                Type:        "string",
 45//	                Description: "Mathematical expression to evaluate",
 46//	            },
 47//	        },
 48//	        Required: []string{"expression"},
 49//	    },
 50//	    ExecuteFunc: func(ctx context.Context, input json.RawMessage) (json.RawMessage, error) {
 51//	        var req struct {
 52//	            Expression string `json:"expression"`
 53//	        }
 54//	        if err := json.Unmarshal(input, &req); err != nil {
 55//	            return nil, err
 56//	        }
 57//	        result := evaluateExpression(req.Expression)
 58//	        return json.Marshal(map[string]any{"result": result})
 59//	    },
 60//	}
 61type BasicTool struct {
 62	ToolName        string
 63	ToolDescription string
 64	ToolInputSchema Schema
 65	ExecuteFunc     func(context.Context, json.RawMessage) (json.RawMessage, error)
 66}
 67
 68// Name returns the tool name.
 69func (t *BasicTool) Name() string {
 70	return t.ToolName
 71}
 72
 73// Description returns the tool description.
 74func (t *BasicTool) Description() string {
 75	return t.ToolDescription
 76}
 77
 78// InputSchema returns the tool input schema.
 79func (t *BasicTool) InputSchema() Schema {
 80	return t.ToolInputSchema
 81}
 82
 83// Execute executes the tool with the given input.
 84func (t *BasicTool) Execute(ctx context.Context, input json.RawMessage) (json.RawMessage, error) {
 85	if t.ExecuteFunc == nil {
 86		return nil, fmt.Errorf("tool %s has no execute function", t.ToolName)
 87	}
 88	return t.ExecuteFunc(ctx, input)
 89}
 90
 91// ToolBuilder provides a fluent interface for building tools.
 92type ToolBuilder struct {
 93	tool *BasicTool
 94}
 95
 96// NewTool creates a new tool builder.
 97func NewTool(name string) *ToolBuilder {
 98	return &ToolBuilder{
 99		tool: &BasicTool{
100			ToolName: name,
101		},
102	}
103}
104
105// Description sets the tool description.
106func (b *ToolBuilder) Description(desc string) *ToolBuilder {
107	b.tool.ToolDescription = desc
108	return b
109}
110
111// InputSchema sets the tool input schema.
112func (b *ToolBuilder) InputSchema(schema Schema) *ToolBuilder {
113	b.tool.ToolInputSchema = schema
114	return b
115}
116
117// Execute sets the tool execution function.
118func (b *ToolBuilder) Execute(fn func(context.Context, json.RawMessage) (json.RawMessage, error)) *ToolBuilder {
119	b.tool.ExecuteFunc = fn
120	return b
121}
122
123// Build creates the final tool.
124func (b *ToolBuilder) Build() AgentTool {
125	return b.tool
126}
127
128// SchemaBuilder provides a fluent interface for building JSON schemas.
129type SchemaBuilder struct {
130	schema Schema
131}
132
133// NewSchema creates a new schema builder.
134func NewSchema(schemaType string) *SchemaBuilder {
135	return &SchemaBuilder{
136		schema: Schema{
137			Type: schemaType,
138		},
139	}
140}
141
142// Object creates a schema builder for an object type.
143func Object() *SchemaBuilder {
144	return NewSchema("object")
145}
146
147// String creates a schema builder for a string type.
148func String() *SchemaBuilder {
149	return NewSchema("string")
150}
151
152// Number creates a schema builder for a number type.
153func Number() *SchemaBuilder {
154	return NewSchema("number")
155}
156
157// Array creates a schema builder for an array type.
158func Array() *SchemaBuilder {
159	return NewSchema("array")
160}
161
162// Description sets the schema description.
163func (b *SchemaBuilder) Description(desc string) *SchemaBuilder {
164	b.schema.Description = desc
165	return b
166}
167
168// Properties sets the schema properties.
169func (b *SchemaBuilder) Properties(props map[string]*Schema) *SchemaBuilder {
170	b.schema.Properties = props
171	return b
172}
173
174// Property adds a property to the schema.
175func (b *SchemaBuilder) Property(name string, schema *Schema) *SchemaBuilder {
176	if b.schema.Properties == nil {
177		b.schema.Properties = make(map[string]*Schema)
178	}
179	b.schema.Properties[name] = schema
180	return b
181}
182
183// Required marks fields as required.
184func (b *SchemaBuilder) Required(fields ...string) *SchemaBuilder {
185	b.schema.Required = append(b.schema.Required, fields...)
186	return b
187}
188
189// Items sets the schema for array items.
190func (b *SchemaBuilder) Items(schema *Schema) *SchemaBuilder {
191	b.schema.Items = schema
192	return b
193}
194
195// Enum sets allowed values for the schema.
196func (b *SchemaBuilder) Enum(values ...any) *SchemaBuilder {
197	b.schema.Enum = values
198	return b
199}
200
201// Format sets the string format.
202func (b *SchemaBuilder) Format(format string) *SchemaBuilder {
203	b.schema.Format = format
204	return b
205}
206
207// Min sets the minimum value.
208func (b *SchemaBuilder) Min(minimum float64) *SchemaBuilder {
209	b.schema.Minimum = &minimum
210	return b
211}
212
213// Max sets the maximum value.
214func (b *SchemaBuilder) Max(maximum float64) *SchemaBuilder {
215	b.schema.Maximum = &maximum
216	return b
217}
218
219// MinLength sets the minimum string length.
220func (b *SchemaBuilder) MinLength(minimum int) *SchemaBuilder {
221	b.schema.MinLength = &minimum
222	return b
223}
224
225// MaxLength sets the maximum string length.
226func (b *SchemaBuilder) MaxLength(maximum int) *SchemaBuilder {
227	b.schema.MaxLength = &maximum
228	return b
229}
230
231// Build creates the final schema.
232func (b *SchemaBuilder) Build() *Schema {
233	return &b.schema
234}