Detailed changes
@@ -8,8 +8,6 @@ import (
"maps"
"slices"
"sync"
-
- "github.com/charmbracelet/crush/internal/llm/tools"
)
type StepResult struct {
@@ -100,7 +98,7 @@ type PrepareStepResult struct {
type ToolCallRepairOptions struct {
OriginalToolCall ToolCallContent
ValidationError error
- AvailableTools []tools.BaseTool
+ AvailableTools []AgentTool
SystemPrompt string
Messages []Message
}
@@ -123,7 +121,7 @@ type AgentSettings struct {
providerOptions ProviderOptions
// TODO: add support for provider tools
- tools []tools.BaseTool
+ tools []AgentTool
maxRetries *int
model LanguageModel
@@ -548,13 +546,13 @@ func toResponseMessages(content []Content) []Message {
return messages
}
-func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent)) ([]ToolResultContent, error) {
+func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent)) ([]ToolResultContent, error) {
if len(toolCalls) == 0 {
return nil, nil
}
// Create a map for quick tool lookup
- toolMap := make(map[string]tools.BaseTool)
+ toolMap := make(map[string]AgentTool)
for _, tool := range allTools {
toolMap[tool.Info().Name] = tool
}
@@ -604,7 +602,7 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too
}
// Execute the tool
- result, err := tool.Run(ctx, tools.ToolCall{
+ result, err := tool.Run(ctx, ToolCall{
ID: call.ToolCallID,
Name: call.ToolName,
Input: call.Input,
@@ -616,6 +614,7 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too
Result: ToolResultOutputContentError{
Error: err,
},
+ ClientMetadata: result.Metadata,
ProviderExecuted: false,
}
if toolResultCallback != nil {
@@ -632,6 +631,7 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too
Result: ToolResultOutputContentError{
Error: errors.New(result.Content),
},
+ ClientMetadata: result.Metadata,
ProviderExecuted: false,
}
@@ -645,6 +645,7 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too
Result: ToolResultOutputContentText{
Text: result.Content,
},
+ ClientMetadata: result.Metadata,
ProviderExecuted: false,
}
if toolResultCallback != nil {
@@ -821,7 +822,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
return agentResult, nil
}
-func (a *agent) prepareTools(tools []tools.BaseTool, activeTools []string, disableAllTools bool) []Tool {
+func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
var preparedTools []Tool
// If explicitly disabling all tools, return no tools
@@ -850,7 +851,7 @@ func (a *agent) prepareTools(tools []tools.BaseTool, activeTools []string, disab
}
// validateAndRepairToolCall validates a tool call and attempts repair if validation fails
-func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []tools.BaseTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
+func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
if err := a.validateToolCall(toolCall, availableTools); err == nil {
return toolCall
} else {
@@ -878,8 +879,8 @@ func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCall
}
// validateToolCall validates a tool call against available tools and their schemas
-func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []tools.BaseTool) error {
- var tool tools.BaseTool
+func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
+ var tool AgentTool
for _, t := range availableTools {
if t.Info().Name == toolCall.ToolName {
tool = t
@@ -966,7 +967,7 @@ func WithFrequencyPenalty(penalty float64) agentOption {
}
}
-func WithTools(tools ...tools.BaseTool) agentOption {
+func WithTools(tools ...AgentTool) agentOption {
return func(s *AgentSettings) {
s.tools = append(s.tools, tools...)
}
@@ -6,7 +6,6 @@ import (
"fmt"
"testing"
- "github.com/charmbracelet/crush/internal/llm/tools"
"github.com/stretchr/testify/require"
)
@@ -14,8 +13,8 @@ import (
type EchoTool struct{}
// Info returns the tool information
-func (e *EchoTool) Info() tools.ToolInfo {
- return tools.ToolInfo{
+func (e *EchoTool) Info() ToolInfo {
+ return ToolInfo{
Name: "echo",
Description: "Echo back the provided message",
Parameters: map[string]any{
@@ -29,20 +28,20 @@ func (e *EchoTool) Info() tools.ToolInfo {
}
// Run executes the echo tool
-func (e *EchoTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
+func (e *EchoTool) Run(ctx context.Context, params ToolCall) (ToolResponse, error) {
var input struct {
Message string `json:"message"`
}
if err := json.Unmarshal([]byte(params.Input), &input); err != nil {
- return tools.NewTextErrorResponse("Invalid input: " + err.Error()), nil
+ return NewTextErrorResponse("Invalid input: " + err.Error()), nil
}
if input.Message == "" {
- return tools.NewTextErrorResponse("Message cannot be empty"), nil
+ return NewTextErrorResponse("Message cannot be empty"), nil
}
- return tools.NewTextResponse("Echo: " + input.Message), nil
+ return NewTextResponse("Echo: " + input.Message), nil
}
// TestStreamingAgentCallbacks tests that all streaming callbacks are called correctly
@@ -7,7 +7,6 @@ import (
"fmt"
"testing"
- "github.com/charmbracelet/crush/internal/llm/tools"
"github.com/stretchr/testify/require"
)
@@ -17,11 +16,11 @@ type mockTool struct {
description string
parameters map[string]any
required []string
- executeFunc func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error)
+ executeFunc func(ctx context.Context, call ToolCall) (ToolResponse, error)
}
-func (m *mockTool) Info() tools.ToolInfo {
- return tools.ToolInfo{
+func (m *mockTool) Info() ToolInfo {
+ return ToolInfo{
Name: m.name,
Description: m.description,
Parameters: m.parameters,
@@ -29,11 +28,11 @@ func (m *mockTool) Info() tools.ToolInfo {
}
}
-func (m *mockTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
+func (m *mockTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
if m.executeFunc != nil {
return m.executeFunc(ctx, call)
}
- return tools.ToolResponse{Content: "mock result", IsError: false}, nil
+ return ToolResponse{Content: "mock result", IsError: false}, nil
}
// Mock language model for testing
@@ -78,22 +77,20 @@ func (m *mockLanguageModel) Model() string {
func TestAgent_Generate_ResultContent_AllTypes(t *testing.T) {
t.Parallel()
- tool1 := &mockTool{
- name: "tool1",
- description: "Test tool",
- parameters: map[string]any{
- "value": map[string]any{"type": "string"},
- },
- required: []string{"value"},
- executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
- var input map[string]any
- err := json.Unmarshal([]byte(call.Input), &input)
- require.NoError(t, err)
- require.Equal(t, "value", input["value"])
- return tools.ToolResponse{Content: "result1", IsError: false}, nil
- },
+ // Create a type-safe tool using the new API
+ type TestInput struct {
+ Value string `json:"value" description:"Test value"`
}
+ tool1 := NewTypedToolFunc(
+ "tool1",
+ "Test tool",
+ func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
+ require.Equal(t, "value", input.Value)
+ return ToolResponse{Content: "result1", IsError: false}, nil
+ },
+ )
+
model := &mockLanguageModel{
generateFunc: func(ctx context.Context, call Call) (*Response, error) {
return &Response{
@@ -213,24 +210,31 @@ func TestAgent_Generate_ResultText(t *testing.T) {
func TestAgent_Generate_ResultToolCalls(t *testing.T) {
t.Parallel()
- tool1 := &mockTool{
- name: "tool1",
- description: "Test tool 1",
- parameters: map[string]any{
- "value": map[string]any{"type": "string"},
- },
- required: []string{"value"},
+ // Create type-safe tools using the new API
+ type Tool1Input struct {
+ Value string `json:"value" description:"Test value"`
}
- tool2 := &mockTool{
- name: "tool2",
- description: "Test tool 2",
- parameters: map[string]any{
- "somethingElse": map[string]any{"type": "string"},
- },
- required: []string{"somethingElse"},
+ type Tool2Input struct {
+ SomethingElse string `json:"somethingElse" description:"Another test value"`
}
+ tool1 := NewTypedToolFunc(
+ "tool1",
+ "Test tool 1",
+ func(ctx context.Context, input Tool1Input, _ ToolCall) (ToolResponse, error) {
+ return ToolResponse{Content: "result1", IsError: false}, nil
+ },
+ )
+
+ tool2 := NewTypedToolFunc(
+ "tool2",
+ "Test tool 2",
+ func(ctx context.Context, input Tool2Input, _ ToolCall) (ToolResponse, error) {
+ return ToolResponse{Content: "result2", IsError: false}, nil
+ },
+ )
+
model := &mockLanguageModel{
generateFunc: func(ctx context.Context, call Call) (*Response, error) {
// Verify tools are passed correctly
@@ -291,22 +295,20 @@ func TestAgent_Generate_ResultToolCalls(t *testing.T) {
func TestAgent_Generate_ResultToolResults(t *testing.T) {
t.Parallel()
- tool1 := &mockTool{
- name: "tool1",
- description: "Test tool",
- parameters: map[string]any{
- "value": map[string]any{"type": "string"},
- },
- required: []string{"value"},
- executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
- var input map[string]any
- err := json.Unmarshal([]byte(call.Input), &input)
- require.NoError(t, err)
- require.Equal(t, "value", input["value"])
- return tools.ToolResponse{Content: "result1", IsError: false}, nil
- },
+ // Create type-safe tool using the new API
+ type TestInput struct {
+ Value string `json:"value" description:"Test value"`
}
+ tool1 := NewTypedToolFunc(
+ "tool1",
+ "Test tool",
+ func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
+ require.Equal(t, "value", input.Value)
+ return ToolResponse{Content: "result1", IsError: false}, nil
+ },
+ )
+
model := &mockLanguageModel{
generateFunc: func(ctx context.Context, call Call) (*Response, error) {
// Verify tools and tool choice
@@ -366,22 +368,20 @@ func TestAgent_Generate_ResultToolResults(t *testing.T) {
func TestAgent_Generate_MultipleSteps(t *testing.T) {
t.Parallel()
- tool1 := &mockTool{
- name: "tool1",
- description: "Test tool",
- parameters: map[string]any{
- "value": map[string]any{"type": "string"},
- },
- required: []string{"value"},
- executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
- var input map[string]any
- err := json.Unmarshal([]byte(call.Input), &input)
- require.NoError(t, err)
- require.Equal(t, "value", input["value"])
- return tools.ToolResponse{Content: "result1", IsError: false}, nil
- },
+ // Create type-safe tool using the new API
+ type TestInput struct {
+ Value string `json:"value" description:"Test value"`
}
+ tool1 := NewTypedToolFunc(
+ "tool1",
+ "Test tool",
+ func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
+ require.Equal(t, "value", input.Value)
+ return ToolResponse{Content: "result1", IsError: false}, nil
+ },
+ )
+
callCount := 0
model := &mockLanguageModel{
generateFunc: func(ctx context.Context, call Call) (*Response, error) {
@@ -1289,8 +1289,8 @@ func TestToolCallRepair(t *testing.T) {
"value": map[string]any{"type": "string"},
},
required: []string{"value"},
- executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
- return tools.ToolResponse{Content: "success", IsError: false}, nil
+ executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
+ return ToolResponse{Content: "success", IsError: false}, nil
},
}
@@ -1379,8 +1379,8 @@ func TestToolCallRepair(t *testing.T) {
"value": map[string]any{"type": "string"},
},
required: []string{"value"},
- executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
- return tools.ToolResponse{Content: "repaired_success", IsError: false}, nil
+ executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
+ return ToolResponse{Content: "repaired_success", IsError: false}, nil
},
}
@@ -334,7 +334,8 @@ type ToolResultContent struct {
// Name of the tool that generated this result.
ToolName string `json:"tool_name"`
// Result of the tool call. This is a JSON-serializable object.
- Result ToolResultOutputContent `json:"result"`
+ Result ToolResultOutputContent `json:"result"`
+ ClientMetadata string `json:"client_metadata"` // Metadata from the client that executed the tool
// Whether the tool result was generated by the provider.
// If this flag is set to true, the tool result was generated by the provider.
// If this flag is not set or is false, the tool result was generated by the client.
@@ -1,20 +1,13 @@
-// WIP NEED TO REVISIT
package ai
import (
"context"
"encoding/json"
"fmt"
+ "reflect"
+ "strings"
)
-// AgentTool represents a function that can be called by a language model.
-type AgentTool interface {
- Name() string
- Description() string
- InputSchema() Schema
- Execute(ctx context.Context, input json.RawMessage) (json.RawMessage, error)
-}
-
// Schema represents a JSON schema for tool input validation.
type Schema struct {
Type string `json:"type"`
@@ -30,205 +23,284 @@ type Schema struct {
MaxLength *int `json:"maxLength,omitempty"`
}
-// BasicTool provides a basic implementation of the Tool interface
-//
-// Example usage:
-//
-// calculator := &tools.BasicTool{
-// ToolName: "calculate",
-// ToolDescription: "Evaluates mathematical expressions",
-// ToolInputSchema: tools.Schema{
-// Type: "object",
-// Properties: map[string]*tools.Schema{
-// "expression": {
-// Type: "string",
-// Description: "Mathematical expression to evaluate",
-// },
-// },
-// Required: []string{"expression"},
-// },
-// ExecuteFunc: func(ctx context.Context, input json.RawMessage) (json.RawMessage, error) {
-// var req struct {
-// Expression string `json:"expression"`
-// }
-// if err := json.Unmarshal(input, &req); err != nil {
-// return nil, err
-// }
-// result := evaluateExpression(req.Expression)
-// return json.Marshal(map[string]any{"result": result})
-// },
-// }
-type BasicTool struct {
- ToolName string
- ToolDescription string
- ToolInputSchema Schema
- ExecuteFunc func(context.Context, json.RawMessage) (json.RawMessage, error)
-}
-
-// Name returns the tool name.
-func (t *BasicTool) Name() string {
- return t.ToolName
-}
-
-// Description returns the tool description.
-func (t *BasicTool) Description() string {
- return t.ToolDescription
-}
-
-// InputSchema returns the tool input schema.
-func (t *BasicTool) InputSchema() Schema {
- return t.ToolInputSchema
-}
-
-// Execute executes the tool with the given input.
-func (t *BasicTool) Execute(ctx context.Context, input json.RawMessage) (json.RawMessage, error) {
- if t.ExecuteFunc == nil {
- return nil, fmt.Errorf("tool %s has no execute function", t.ToolName)
- }
- return t.ExecuteFunc(ctx, input)
+// ToolInfo represents tool metadata, matching the existing pattern.
+type ToolInfo struct {
+ Name string
+ Description string
+ Parameters map[string]any
+ Required []string
}
-// ToolBuilder provides a fluent interface for building tools.
-type ToolBuilder struct {
- tool *BasicTool
+// ToolCall represents a tool invocation, matching the existing pattern.
+type ToolCall struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+ Input string `json:"input"`
}
-// NewTool creates a new tool builder.
-func NewTool(name string) *ToolBuilder {
- return &ToolBuilder{
- tool: &BasicTool{
- ToolName: name,
- },
+// ToolResponse represents the response from a tool execution, matching the existing pattern.
+type ToolResponse struct {
+ Type string `json:"type"`
+ Content string `json:"content"`
+ Metadata string `json:"metadata,omitempty"`
+ IsError bool `json:"is_error"`
+}
+
+// NewTextResponse creates a text response.
+func NewTextResponse(content string) ToolResponse {
+ return ToolResponse{
+ Type: "text",
+ Content: content,
}
}
-// Description sets the tool description.
-func (b *ToolBuilder) Description(desc string) *ToolBuilder {
- b.tool.ToolDescription = desc
- return b
+// NewTextErrorResponse creates an error response.
+func NewTextErrorResponse(content string) ToolResponse {
+ return ToolResponse{
+ Type: "text",
+ Content: content,
+ IsError: true,
+ }
}
-// InputSchema sets the tool input schema.
-func (b *ToolBuilder) InputSchema(schema Schema) *ToolBuilder {
- b.tool.ToolInputSchema = schema
- return b
+// WithResponseMetadata adds metadata to a response.
+func WithResponseMetadata(response ToolResponse, metadata any) ToolResponse {
+ if metadata != nil {
+ metadataBytes, err := json.Marshal(metadata)
+ if err != nil {
+ return response
+ }
+ response.Metadata = string(metadataBytes)
+ }
+ return response
}
-// Execute sets the tool execution function.
-func (b *ToolBuilder) Execute(fn func(context.Context, json.RawMessage) (json.RawMessage, error)) *ToolBuilder {
- b.tool.ExecuteFunc = fn
- return b
+// AgentTool represents a tool that can be called by a language model.
+// This matches the existing BaseTool interface pattern.
+type AgentTool interface {
+ Info() ToolInfo
+ Run(ctx context.Context, params ToolCall) (ToolResponse, error)
}
-// Build creates the final tool.
-func (b *ToolBuilder) Build() AgentTool {
- return b.tool
+// NewTypedToolFunc creates a typed tool from a function with automatic schema generation.
+// This is the recommended way to create tools.
+func NewTypedToolFunc[TInput any](
+ name string,
+ description string,
+ fn func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error),
+) AgentTool {
+ var input TInput
+ schema := generateSchema(reflect.TypeOf(input))
+
+ return &funcToolWrapper[TInput]{
+ name: name,
+ description: description,
+ fn: fn,
+ schema: schema,
+ }
}
-// SchemaBuilder provides a fluent interface for building JSON schemas.
-type SchemaBuilder struct {
- schema Schema
+// funcToolWrapper wraps a function to implement the AgentTool interface.
+type funcToolWrapper[TInput any] struct {
+ name string
+ description string
+ fn func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error)
+ schema Schema
}
-// NewSchema creates a new schema builder.
-func NewSchema(schemaType string) *SchemaBuilder {
- return &SchemaBuilder{
- schema: Schema{
- Type: schemaType,
- },
+func (w *funcToolWrapper[TInput]) Info() ToolInfo {
+ return ToolInfo{
+ Name: w.name,
+ Description: w.description,
+ Parameters: schemaToParameters(w.schema),
+ Required: w.schema.Required,
}
}
-// Object creates a schema builder for an object type.
-func Object() *SchemaBuilder {
- return NewSchema("object")
-}
+func (w *funcToolWrapper[TInput]) Run(ctx context.Context, params ToolCall) (ToolResponse, error) {
+ var input TInput
+ if err := json.Unmarshal([]byte(params.Input), &input); err != nil {
+ return NewTextErrorResponse(fmt.Sprintf("invalid parameters: %s", err)), nil
+ }
-// String creates a schema builder for a string type.
-func String() *SchemaBuilder {
- return NewSchema("string")
+ return w.fn(ctx, input, params)
}
-// Number creates a schema builder for a number type.
-func Number() *SchemaBuilder {
- return NewSchema("number")
-}
+// schemaToParameters converts a Schema to the parameters map format expected by ToolInfo.
+func schemaToParameters(schema Schema) map[string]any {
+ if schema.Type != "object" || schema.Properties == nil {
+ return map[string]any{}
+ }
-// Array creates a schema builder for an array type.
-func Array() *SchemaBuilder {
- return NewSchema("array")
-}
+ params := make(map[string]any)
+ for name, propSchema := range schema.Properties {
+ param := map[string]any{
+ "type": propSchema.Type,
+ }
-// Description sets the schema description.
-func (b *SchemaBuilder) Description(desc string) *SchemaBuilder {
- b.schema.Description = desc
- return b
-}
+ if propSchema.Description != "" {
+ param["description"] = propSchema.Description
+ }
-// Properties sets the schema properties.
-func (b *SchemaBuilder) Properties(props map[string]*Schema) *SchemaBuilder {
- b.schema.Properties = props
- return b
-}
+ if len(propSchema.Enum) > 0 {
+ param["enum"] = propSchema.Enum
+ }
-// Property adds a property to the schema.
-func (b *SchemaBuilder) Property(name string, schema *Schema) *SchemaBuilder {
- if b.schema.Properties == nil {
- b.schema.Properties = make(map[string]*Schema)
- }
- b.schema.Properties[name] = schema
- return b
-}
+ if propSchema.Format != "" {
+ param["format"] = propSchema.Format
+ }
-// Required marks fields as required.
-func (b *SchemaBuilder) Required(fields ...string) *SchemaBuilder {
- b.schema.Required = append(b.schema.Required, fields...)
- return b
-}
+ if propSchema.Minimum != nil {
+ param["minimum"] = *propSchema.Minimum
+ }
-// Items sets the schema for array items.
-func (b *SchemaBuilder) Items(schema *Schema) *SchemaBuilder {
- b.schema.Items = schema
- return b
-}
+ if propSchema.Maximum != nil {
+ param["maximum"] = *propSchema.Maximum
+ }
-// Enum sets allowed values for the schema.
-func (b *SchemaBuilder) Enum(values ...any) *SchemaBuilder {
- b.schema.Enum = values
- return b
-}
+ if propSchema.MinLength != nil {
+ param["minLength"] = *propSchema.MinLength
+ }
-// Format sets the string format.
-func (b *SchemaBuilder) Format(format string) *SchemaBuilder {
- b.schema.Format = format
- return b
-}
+ if propSchema.MaxLength != nil {
+ param["maxLength"] = *propSchema.MaxLength
+ }
-// Min sets the minimum value.
-func (b *SchemaBuilder) Min(minimum float64) *SchemaBuilder {
- b.schema.Minimum = &minimum
- return b
-}
+ if propSchema.Items != nil {
+ param["items"] = schemaToParameters(*propSchema.Items)
+ }
-// Max sets the maximum value.
-func (b *SchemaBuilder) Max(maximum float64) *SchemaBuilder {
- b.schema.Maximum = &maximum
- return b
+ params[name] = param
+ }
+
+ return params
}
-// MinLength sets the minimum string length.
-func (b *SchemaBuilder) MinLength(minimum int) *SchemaBuilder {
- b.schema.MinLength = &minimum
- return b
+// generateSchema automatically generates a JSON schema from a Go type.
+func generateSchema(t reflect.Type) Schema {
+ return generateSchemaRecursive(t, make(map[reflect.Type]bool))
}
-// MaxLength sets the maximum string length.
-func (b *SchemaBuilder) MaxLength(maximum int) *SchemaBuilder {
- b.schema.MaxLength = &maximum
- return b
+func generateSchemaRecursive(t reflect.Type, visited map[reflect.Type]bool) Schema {
+ // Handle pointers
+ if t.Kind() == reflect.Pointer {
+ t = t.Elem()
+ }
+
+ // Prevent infinite recursion
+ if visited[t] {
+ return Schema{Type: "object"}
+ }
+ visited[t] = true
+ defer delete(visited, t)
+
+ switch t.Kind() {
+ case reflect.String:
+ return Schema{Type: "string"}
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
+ reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ return Schema{Type: "integer"}
+ case reflect.Float32, reflect.Float64:
+ return Schema{Type: "number"}
+ case reflect.Bool:
+ return Schema{Type: "boolean"}
+ case reflect.Slice, reflect.Array:
+ itemSchema := generateSchemaRecursive(t.Elem(), visited)
+ return Schema{
+ Type: "array",
+ Items: &itemSchema,
+ }
+ case reflect.Map:
+ if t.Key().Kind() == reflect.String {
+ valueSchema := generateSchemaRecursive(t.Elem(), visited)
+ return Schema{
+ Type: "object",
+ Properties: map[string]*Schema{
+ "*": &valueSchema,
+ },
+ }
+ }
+ return Schema{Type: "object"}
+ case reflect.Struct:
+ schema := Schema{
+ Type: "object",
+ Properties: make(map[string]*Schema),
+ }
+
+ for i := range t.NumField() {
+ field := t.Field(i)
+
+ // Skip unexported fields
+ if !field.IsExported() {
+ continue
+ }
+
+ jsonTag := field.Tag.Get("json")
+ if jsonTag == "-" {
+ continue
+ }
+
+ fieldName := field.Name
+ required := true
+
+ // Parse JSON tag
+ if jsonTag != "" {
+ parts := strings.Split(jsonTag, ",")
+ if parts[0] != "" {
+ fieldName = parts[0]
+ }
+
+ // Check for omitempty
+ for _, part := range parts[1:] {
+ if part == "omitempty" {
+ required = false
+ break
+ }
+ }
+ } else {
+ // Convert field name to snake_case for JSON
+ fieldName = toSnakeCase(fieldName)
+ }
+
+ fieldSchema := generateSchemaRecursive(field.Type, visited)
+
+ // Add description from struct tag if available
+ if desc := field.Tag.Get("description"); desc != "" {
+ fieldSchema.Description = desc
+ }
+
+ // Add enum values from struct tag if available
+ if enumTag := field.Tag.Get("enum"); enumTag != "" {
+ enumValues := strings.Split(enumTag, ",")
+ fieldSchema.Enum = make([]any, len(enumValues))
+ for i, v := range enumValues {
+ fieldSchema.Enum[i] = strings.TrimSpace(v)
+ }
+ }
+
+ schema.Properties[fieldName] = &fieldSchema
+
+ if required {
+ schema.Required = append(schema.Required, fieldName)
+ }
+ }
+
+ return schema
+ case reflect.Interface:
+ return Schema{Type: "object"}
+ default:
+ return Schema{Type: "object"}
+ }
}
-// Build creates the final schema.
-func (b *SchemaBuilder) Build() *Schema {
- return &b.schema
+// toSnakeCase converts PascalCase to snake_case.
+func toSnakeCase(s string) string {
+ var result strings.Builder
+ for i, r := range s {
+ if i > 0 && r >= 'A' && r <= 'Z' {
+ result.WriteByte('_')
+ }
+ result.WriteRune(r)
+ }
+ return strings.ToLower(result.String())
}
@@ -0,0 +1,211 @@
+package ai
+
+import (
+ "context"
+ "fmt"
+ "reflect"
+ "strings"
+ "testing"
+)
+
+// Example of a simple typed tool using the function approach
+type CalculatorInput struct {
+ Expression string `json:"expression" description:"Mathematical expression to evaluate"`
+}
+
+func TestTypedToolFuncExample(t *testing.T) {
+ // Create a typed tool using the function API
+ tool := NewTypedToolFunc(
+ "calculator",
+ "Evaluates simple mathematical expressions",
+ func(ctx context.Context, input CalculatorInput, _ ToolCall) (ToolResponse, error) {
+ if input.Expression == "2+2" {
+ return NewTextResponse("4"), nil
+ }
+ return NewTextErrorResponse("unsupported expression"), nil
+ },
+ )
+
+ // Check the tool info
+ info := tool.Info()
+ if info.Name != "calculator" {
+ t.Errorf("Expected tool name 'calculator', got %s", info.Name)
+ }
+ if len(info.Required) != 1 || info.Required[0] != "expression" {
+ t.Errorf("Expected required field 'expression', got %v", info.Required)
+ }
+
+ // Test execution
+ call := ToolCall{
+ ID: "test-1",
+ Name: "calculator",
+ Input: `{"expression": "2+2"}`,
+ }
+
+ result, err := tool.Run(context.Background(), call)
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ if result.Content != "4" {
+ t.Errorf("Expected result '4', got %s", result.Content)
+ }
+ if result.IsError {
+ t.Errorf("Expected successful result, got error")
+ }
+}
+
+func TestEnumToolExample(t *testing.T) {
+ type WeatherInput struct {
+ Location string `json:"location" description:"City name"`
+ Units string `json:"units" enum:"celsius,fahrenheit" description:"Temperature units"`
+ }
+
+ // Create a weather tool with enum support
+ tool := NewTypedToolFunc(
+ "weather",
+ "Gets current weather for a location",
+ func(ctx context.Context, input WeatherInput, _ ToolCall) (ToolResponse, error) {
+ temp := "22°C"
+ if input.Units == "fahrenheit" {
+ temp = "72°F"
+ }
+ return NewTextResponse(fmt.Sprintf("Weather in %s: %s, sunny", input.Location, temp)), nil
+ },
+ )
+
+ // Check that the schema includes enum values
+ info := tool.Info()
+ unitsParam, ok := info.Parameters["units"].(map[string]any)
+ if !ok {
+ t.Fatal("Expected units parameter to exist")
+ }
+ enumValues, ok := unitsParam["enum"].([]any)
+ if !ok || len(enumValues) != 2 {
+ t.Errorf("Expected 2 enum values, got %v", enumValues)
+ }
+
+ // Test execution with enum value
+ call := ToolCall{
+ ID: "test-2",
+ Name: "weather",
+ Input: `{"location": "San Francisco", "units": "fahrenheit"}`,
+ }
+
+ result, err := tool.Run(context.Background(), call)
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ if !strings.Contains(result.Content, "San Francisco") {
+ t.Errorf("Expected result to contain 'San Francisco', got %s", result.Content)
+ }
+ if !strings.Contains(result.Content, "72°F") {
+ t.Errorf("Expected result to contain '72°F', got %s", result.Content)
+ }
+}
+
+func TestEnumSupport(t *testing.T) {
+ // Test enum via struct tags
+ type WeatherInput struct {
+ Location string `json:"location" description:"City name"`
+ Units string `json:"units" enum:"celsius,fahrenheit,kelvin" description:"Temperature units"`
+ Format string `json:"format,omitempty" enum:"json,xml,text"`
+ }
+
+ schema := generateSchema(reflect.TypeOf(WeatherInput{}))
+
+ if schema.Type != "object" {
+ t.Errorf("Expected schema type 'object', got %s", schema.Type)
+ }
+
+ // Check units field has enum values
+ unitsSchema := schema.Properties["units"]
+ if unitsSchema == nil {
+ t.Fatal("Expected units property to exist")
+ }
+ if len(unitsSchema.Enum) != 3 {
+ t.Errorf("Expected 3 enum values for units, got %d", len(unitsSchema.Enum))
+ }
+ expectedUnits := []string{"celsius", "fahrenheit", "kelvin"}
+ for i, expected := range expectedUnits {
+ if unitsSchema.Enum[i] != expected {
+ t.Errorf("Expected enum value %s, got %v", expected, unitsSchema.Enum[i])
+ }
+ }
+
+ // Check required fields (format should not be required due to omitempty)
+ expectedRequired := []string{"location", "units"}
+ if len(schema.Required) != len(expectedRequired) {
+ t.Errorf("Expected %d required fields, got %d", len(expectedRequired), len(schema.Required))
+ }
+}
+
+func TestSchemaToParameters(t *testing.T) {
+ schema := Schema{
+ Type: "object",
+ Properties: map[string]*Schema{
+ "name": {
+ Type: "string",
+ Description: "The name field",
+ },
+ "age": {
+ Type: "integer",
+ Minimum: func() *float64 { v := 0.0; return &v }(),
+ Maximum: func() *float64 { v := 120.0; return &v }(),
+ },
+ "tags": {
+ Type: "array",
+ Items: &Schema{
+ Type: "string",
+ },
+ },
+ "priority": {
+ Type: "string",
+ Enum: []any{"low", "medium", "high"},
+ },
+ },
+ Required: []string{"name"},
+ }
+
+ params := schemaToParameters(schema)
+
+ // Check name parameter
+ nameParam, ok := params["name"].(map[string]any)
+ if !ok {
+ t.Fatal("Expected name parameter to exist")
+ }
+ if nameParam["type"] != "string" {
+ t.Errorf("Expected name type 'string', got %v", nameParam["type"])
+ }
+ if nameParam["description"] != "The name field" {
+ t.Errorf("Expected name description 'The name field', got %v", nameParam["description"])
+ }
+
+ // Check age parameter with min/max
+ ageParam, ok := params["age"].(map[string]any)
+ if !ok {
+ t.Fatal("Expected age parameter to exist")
+ }
+ if ageParam["type"] != "integer" {
+ t.Errorf("Expected age type 'integer', got %v", ageParam["type"])
+ }
+ if ageParam["minimum"] != 0.0 {
+ t.Errorf("Expected age minimum 0, got %v", ageParam["minimum"])
+ }
+ if ageParam["maximum"] != 120.0 {
+ t.Errorf("Expected age maximum 120, got %v", ageParam["maximum"])
+ }
+
+ // Check priority parameter with enum
+ priorityParam, ok := params["priority"].(map[string]any)
+ if !ok {
+ t.Fatal("Expected priority parameter to exist")
+ }
+ if priorityParam["type"] != "string" {
+ t.Errorf("Expected priority type 'string', got %v", priorityParam["type"])
+ }
+ enumValues, ok := priorityParam["enum"].([]any)
+ if !ok || len(enumValues) != 3 {
+ t.Errorf("Expected 3 enum values, got %v", enumValues)
+ }
+}
+
@@ -0,0 +1,15 @@
+package tools
+
+import "errors"
+
+type Permission struct {
+ ToolCallID string
+ ToolName string
+ Path string
+ Action string
+ Description string
+ Params any
+}
+type PermissionAsk = func(ask Permission) bool
+
+var ErrorPermissionDenied = errors.New("permission denied")
@@ -0,0 +1,258 @@
+package tools
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/charmbracelet/crush/internal/ai"
+ "github.com/charmbracelet/crush/internal/fsext"
+)
+
+const (
+ MaxLSFiles = 1000
+ LSToolName = "ls"
+)
+
+type LSParams struct {
+ Path string `json:"path" description:"The path to the directory to list (defaults to current working directory)"`
+ Ignore []string `json:"ignore,omitempty" description:"List of glob patterns to ignore"`
+}
+
+type LSPermissionsParams struct {
+ Path string `json:"path"`
+ Ignore []string `json:"ignore"`
+}
+
+type TreeNode struct {
+ Name string `json:"name"`
+ Path string `json:"path"`
+ Type string `json:"type"`
+ Children []*TreeNode `json:"children,omitempty"`
+}
+
+type LSResponseMetadata struct {
+ NumberOfFiles int `json:"number_of_files"`
+ Truncated bool `json:"truncated"`
+}
+
+func NewLSTool(permissionAsk PermissionAsk, workingDir string) ai.AgentTool {
+ return ai.NewTypedToolFunc(
+ LSToolName,
+ `Directory listing tool that shows files and subdirectories in a tree structure, helping you explore and understand the project organization.
+
+WHEN TO USE THIS TOOL:
+- Use when you need to explore the structure of a directory
+- Helpful for understanding the organization of a project
+- Good first step when getting familiar with a new codebase
+
+HOW TO USE:
+- Provide a path to list (defaults to current working directory)
+- Optionally specify glob patterns to ignore
+- Results are displayed in a tree structure
+
+FEATURES:
+- Displays a hierarchical view of files and directories
+- Automatically skips hidden files/directories (starting with '.')
+- Skips common system directories like __pycache__
+- Can filter out files matching specific patterns
+
+LIMITATIONS:
+- Results are limited to 1000 files
+- Very large directories will be truncated
+- Does not show file sizes or permissions
+- Cannot recursively list all directories in a large project
+
+WINDOWS NOTES:
+- Hidden file detection uses Unix convention (files starting with '.')
+- Windows-specific hidden files (with hidden attribute) are not automatically skipped
+- Common Windows directories like System32, Program Files are not in default ignore list
+- Path separators are handled automatically (both / and \ work)
+
+TIPS:
+- Use Glob tool for finding files by name patterns instead of browsing
+- Use Grep tool for searching file contents
+- Combine with other tools for more effective exploration`,
+ func(ctx context.Context, params LSParams, call ai.ToolCall) (ai.ToolResponse, error) {
+ searchPath := params.Path
+ if searchPath == "" {
+ searchPath = workingDir
+ }
+
+ var err error
+ searchPath, err = fsext.Expand(searchPath)
+ if err != nil {
+ return ai.ToolResponse{}, fmt.Errorf("error expanding path: %w", err)
+ }
+
+ if !filepath.IsAbs(searchPath) {
+ searchPath = filepath.Join(workingDir, searchPath)
+ }
+
+ // Check if directory is outside working directory and request permission if needed
+ absWorkingDir, err := filepath.Abs(workingDir)
+ if err != nil {
+ return ai.ToolResponse{}, fmt.Errorf("error resolving working directory: %w", err)
+ }
+
+ absSearchPath, err := filepath.Abs(searchPath)
+ if err != nil {
+ return ai.ToolResponse{}, fmt.Errorf("error resolving search path: %w", err)
+ }
+
+ relPath, err := filepath.Rel(absWorkingDir, absSearchPath)
+ if err != nil || strings.HasPrefix(relPath, "..") {
+ granted := permissionAsk(Permission{
+ ToolCallID: call.ID,
+ ToolName: LSToolName,
+ Path: absSearchPath,
+ Action: "list",
+ Description: fmt.Sprintf("List directory outside working directory: %s", absSearchPath),
+ Params: LSPermissionsParams(params),
+ })
+
+ if !granted {
+ return ai.ToolResponse{}, ErrorPermissionDenied
+ }
+ }
+ output, err := ListDirectoryTree(searchPath, params.Ignore)
+ if err != nil {
+ return ai.ToolResponse{}, err
+ }
+
+ // Get file count for metadata
+ files, truncated, err := fsext.ListDirectory(searchPath, params.Ignore, MaxLSFiles)
+ if err != nil {
+ return ai.ToolResponse{}, fmt.Errorf("error listing directory for metadata: %w", err)
+ }
+
+ return ai.WithResponseMetadata(
+ ai.NewTextResponse(output),
+ LSResponseMetadata{
+ NumberOfFiles: len(files),
+ Truncated: truncated,
+ },
+ ), nil
+ },
+ )
+}
+
+func ListDirectoryTree(searchPath string, ignore []string) (string, error) {
+ if _, err := os.Stat(searchPath); os.IsNotExist(err) {
+ return "", fmt.Errorf("path does not exist: %s", searchPath)
+ }
+
+ files, truncated, err := fsext.ListDirectory(searchPath, ignore, MaxLSFiles)
+ if err != nil {
+ return "", fmt.Errorf("error listing directory: %w", err)
+ }
+
+ tree := createFileTree(files, searchPath)
+ output := printTree(tree, searchPath)
+
+ if truncated {
+ output = fmt.Sprintf("There are more than %d files in the directory. Use a more specific path or use the Glob tool to find specific files. The first %d files and directories are included below:\n\n%s", MaxLSFiles, MaxLSFiles, output)
+ }
+
+ return output, nil
+}
+
+func createFileTree(sortedPaths []string, rootPath string) []*TreeNode {
+ root := []*TreeNode{}
+ pathMap := make(map[string]*TreeNode)
+
+ for _, path := range sortedPaths {
+ relativePath := strings.TrimPrefix(path, rootPath)
+ parts := strings.Split(relativePath, string(filepath.Separator))
+ currentPath := ""
+ var parentPath string
+
+ var cleanParts []string
+ for _, part := range parts {
+ if part != "" {
+ cleanParts = append(cleanParts, part)
+ }
+ }
+ parts = cleanParts
+
+ if len(parts) == 0 {
+ continue
+ }
+
+ for i, part := range parts {
+ if currentPath == "" {
+ currentPath = part
+ } else {
+ currentPath = filepath.Join(currentPath, part)
+ }
+
+ if _, exists := pathMap[currentPath]; exists {
+ parentPath = currentPath
+ continue
+ }
+
+ isLastPart := i == len(parts)-1
+ isDir := !isLastPart || strings.HasSuffix(relativePath, string(filepath.Separator))
+ nodeType := "file"
+ if isDir {
+ nodeType = "directory"
+ }
+ newNode := &TreeNode{
+ Name: part,
+ Path: currentPath,
+ Type: nodeType,
+ Children: []*TreeNode{},
+ }
+
+ pathMap[currentPath] = newNode
+
+ if i > 0 && parentPath != "" {
+ if parent, ok := pathMap[parentPath]; ok {
+ parent.Children = append(parent.Children, newNode)
+ }
+ } else {
+ root = append(root, newNode)
+ }
+
+ parentPath = currentPath
+ }
+ }
+
+ return root
+}
+
+func printTree(tree []*TreeNode, rootPath string) string {
+ var result strings.Builder
+
+ result.WriteString("- ")
+ result.WriteString(rootPath)
+ if rootPath[len(rootPath)-1] != '/' {
+ result.WriteByte(filepath.Separator)
+ }
+ result.WriteByte('\n')
+
+ for _, node := range tree {
+ printNode(&result, node, 1)
+ }
+
+ return result.String()
+}
+
+func printNode(builder *strings.Builder, node *TreeNode, level int) {
+ indent := strings.Repeat(" ", level)
+
+ nodeName := node.Name
+ if node.Type == "directory" {
+ nodeName = nodeName + string(filepath.Separator)
+ }
+
+ fmt.Fprintf(builder, "%s- %s\n", indent, nodeName)
+
+ if node.Type == "directory" && len(node.Children) > 0 {
+ for _, child := range node.Children {
+ printNode(builder, child, level+1)
+ }
+ }
+}