feat: add new tool definitions

Kujtim Hoxha created

Change summary

agent.go             |  25 +-
agent_stream_test.go |  13 
agent_test.go        | 132 +++++++-------
content.go           |   3 
tool.go              | 414 +++++++++++++++++++++++++++-------------------
tool_test.go         | 211 +++++++++++++++++++++++
tools/common.go      |  15 +
tools/ls.go          | 258 ++++++++++++++++++++++++++++
8 files changed, 814 insertions(+), 257 deletions(-)

Detailed changes

agent.go 🔗

@@ -8,8 +8,6 @@ import (
 	"maps"
 	"slices"
 	"sync"
-
-	"github.com/charmbracelet/crush/internal/llm/tools"
 )
 
 type StepResult struct {
@@ -100,7 +98,7 @@ type PrepareStepResult struct {
 type ToolCallRepairOptions struct {
 	OriginalToolCall ToolCallContent
 	ValidationError  error
-	AvailableTools   []tools.BaseTool
+	AvailableTools   []AgentTool
 	SystemPrompt     string
 	Messages         []Message
 }
@@ -123,7 +121,7 @@ type AgentSettings struct {
 	providerOptions  ProviderOptions
 
 	// TODO: add support for provider tools
-	tools      []tools.BaseTool
+	tools      []AgentTool
 	maxRetries *int
 
 	model LanguageModel
@@ -548,13 +546,13 @@ func toResponseMessages(content []Content) []Message {
 	return messages
 }
 
-func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent)) ([]ToolResultContent, error) {
+func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent)) ([]ToolResultContent, error) {
 	if len(toolCalls) == 0 {
 		return nil, nil
 	}
 
 	// Create a map for quick tool lookup
-	toolMap := make(map[string]tools.BaseTool)
+	toolMap := make(map[string]AgentTool)
 	for _, tool := range allTools {
 		toolMap[tool.Info().Name] = tool
 	}
@@ -604,7 +602,7 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too
 			}
 
 			// Execute the tool
-			result, err := tool.Run(ctx, tools.ToolCall{
+			result, err := tool.Run(ctx, ToolCall{
 				ID:    call.ToolCallID,
 				Name:  call.ToolName,
 				Input: call.Input,
@@ -616,6 +614,7 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too
 					Result: ToolResultOutputContentError{
 						Error: err,
 					},
+					ClientMetadata:   result.Metadata,
 					ProviderExecuted: false,
 				}
 				if toolResultCallback != nil {
@@ -632,6 +631,7 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too
 					Result: ToolResultOutputContentError{
 						Error: errors.New(result.Content),
 					},
+					ClientMetadata:   result.Metadata,
 					ProviderExecuted: false,
 				}
 
@@ -645,6 +645,7 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too
 					Result: ToolResultOutputContentText{
 						Text: result.Content,
 					},
+					ClientMetadata:   result.Metadata,
 					ProviderExecuted: false,
 				}
 				if toolResultCallback != nil {
@@ -821,7 +822,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
 	return agentResult, nil
 }
 
-func (a *agent) prepareTools(tools []tools.BaseTool, activeTools []string, disableAllTools bool) []Tool {
+func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
 	var preparedTools []Tool
 
 	// If explicitly disabling all tools, return no tools
@@ -850,7 +851,7 @@ func (a *agent) prepareTools(tools []tools.BaseTool, activeTools []string, disab
 }
 
 // validateAndRepairToolCall validates a tool call and attempts repair if validation fails
-func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []tools.BaseTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
+func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
 	if err := a.validateToolCall(toolCall, availableTools); err == nil {
 		return toolCall
 	} else {
@@ -878,8 +879,8 @@ func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCall
 }
 
 // validateToolCall validates a tool call against available tools and their schemas
-func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []tools.BaseTool) error {
-	var tool tools.BaseTool
+func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
+	var tool AgentTool
 	for _, t := range availableTools {
 		if t.Info().Name == toolCall.ToolName {
 			tool = t
@@ -966,7 +967,7 @@ func WithFrequencyPenalty(penalty float64) agentOption {
 	}
 }
 
-func WithTools(tools ...tools.BaseTool) agentOption {
+func WithTools(tools ...AgentTool) agentOption {
 	return func(s *AgentSettings) {
 		s.tools = append(s.tools, tools...)
 	}

agent_stream_test.go 🔗

@@ -6,7 +6,6 @@ import (
 	"fmt"
 	"testing"
 
-	"github.com/charmbracelet/crush/internal/llm/tools"
 	"github.com/stretchr/testify/require"
 )
 
@@ -14,8 +13,8 @@ import (
 type EchoTool struct{}
 
 // Info returns the tool information
-func (e *EchoTool) Info() tools.ToolInfo {
-	return tools.ToolInfo{
+func (e *EchoTool) Info() ToolInfo {
+	return ToolInfo{
 		Name:        "echo",
 		Description: "Echo back the provided message",
 		Parameters: map[string]any{
@@ -29,20 +28,20 @@ func (e *EchoTool) Info() tools.ToolInfo {
 }
 
 // Run executes the echo tool
-func (e *EchoTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
+func (e *EchoTool) Run(ctx context.Context, params ToolCall) (ToolResponse, error) {
 	var input struct {
 		Message string `json:"message"`
 	}
 
 	if err := json.Unmarshal([]byte(params.Input), &input); err != nil {
-		return tools.NewTextErrorResponse("Invalid input: " + err.Error()), nil
+		return NewTextErrorResponse("Invalid input: " + err.Error()), nil
 	}
 
 	if input.Message == "" {
-		return tools.NewTextErrorResponse("Message cannot be empty"), nil
+		return NewTextErrorResponse("Message cannot be empty"), nil
 	}
 
-	return tools.NewTextResponse("Echo: " + input.Message), nil
+	return NewTextResponse("Echo: " + input.Message), nil
 }
 
 // TestStreamingAgentCallbacks tests that all streaming callbacks are called correctly

agent_test.go 🔗

@@ -7,7 +7,6 @@ import (
 	"fmt"
 	"testing"
 
-	"github.com/charmbracelet/crush/internal/llm/tools"
 	"github.com/stretchr/testify/require"
 )
 
@@ -17,11 +16,11 @@ type mockTool struct {
 	description string
 	parameters  map[string]any
 	required    []string
-	executeFunc func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error)
+	executeFunc func(ctx context.Context, call ToolCall) (ToolResponse, error)
 }
 
-func (m *mockTool) Info() tools.ToolInfo {
-	return tools.ToolInfo{
+func (m *mockTool) Info() ToolInfo {
+	return ToolInfo{
 		Name:        m.name,
 		Description: m.description,
 		Parameters:  m.parameters,
@@ -29,11 +28,11 @@ func (m *mockTool) Info() tools.ToolInfo {
 	}
 }
 
-func (m *mockTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
+func (m *mockTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
 	if m.executeFunc != nil {
 		return m.executeFunc(ctx, call)
 	}
-	return tools.ToolResponse{Content: "mock result", IsError: false}, nil
+	return ToolResponse{Content: "mock result", IsError: false}, nil
 }
 
 // Mock language model for testing
@@ -78,22 +77,20 @@ func (m *mockLanguageModel) Model() string {
 func TestAgent_Generate_ResultContent_AllTypes(t *testing.T) {
 	t.Parallel()
 
-	tool1 := &mockTool{
-		name:        "tool1",
-		description: "Test tool",
-		parameters: map[string]any{
-			"value": map[string]any{"type": "string"},
-		},
-		required: []string{"value"},
-		executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
-			var input map[string]any
-			err := json.Unmarshal([]byte(call.Input), &input)
-			require.NoError(t, err)
-			require.Equal(t, "value", input["value"])
-			return tools.ToolResponse{Content: "result1", IsError: false}, nil
-		},
+	// Create a type-safe tool using the new API
+	type TestInput struct {
+		Value string `json:"value" description:"Test value"`
 	}
 
+	tool1 := NewTypedToolFunc(
+		"tool1",
+		"Test tool",
+		func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
+			require.Equal(t, "value", input.Value)
+			return ToolResponse{Content: "result1", IsError: false}, nil
+		},
+	)
+
 	model := &mockLanguageModel{
 		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 			return &Response{
@@ -213,24 +210,31 @@ func TestAgent_Generate_ResultText(t *testing.T) {
 func TestAgent_Generate_ResultToolCalls(t *testing.T) {
 	t.Parallel()
 
-	tool1 := &mockTool{
-		name:        "tool1",
-		description: "Test tool 1",
-		parameters: map[string]any{
-			"value": map[string]any{"type": "string"},
-		},
-		required: []string{"value"},
+	// Create type-safe tools using the new API
+	type Tool1Input struct {
+		Value string `json:"value" description:"Test value"`
 	}
 
-	tool2 := &mockTool{
-		name:        "tool2",
-		description: "Test tool 2",
-		parameters: map[string]any{
-			"somethingElse": map[string]any{"type": "string"},
-		},
-		required: []string{"somethingElse"},
+	type Tool2Input struct {
+		SomethingElse string `json:"somethingElse" description:"Another test value"`
 	}
 
+	tool1 := NewTypedToolFunc(
+		"tool1",
+		"Test tool 1",
+		func(ctx context.Context, input Tool1Input, _ ToolCall) (ToolResponse, error) {
+			return ToolResponse{Content: "result1", IsError: false}, nil
+		},
+	)
+
+	tool2 := NewTypedToolFunc(
+		"tool2",
+		"Test tool 2",
+		func(ctx context.Context, input Tool2Input, _ ToolCall) (ToolResponse, error) {
+			return ToolResponse{Content: "result2", IsError: false}, nil
+		},
+	)
+
 	model := &mockLanguageModel{
 		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 			// Verify tools are passed correctly
@@ -291,22 +295,20 @@ func TestAgent_Generate_ResultToolCalls(t *testing.T) {
 func TestAgent_Generate_ResultToolResults(t *testing.T) {
 	t.Parallel()
 
-	tool1 := &mockTool{
-		name:        "tool1",
-		description: "Test tool",
-		parameters: map[string]any{
-			"value": map[string]any{"type": "string"},
-		},
-		required: []string{"value"},
-		executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
-			var input map[string]any
-			err := json.Unmarshal([]byte(call.Input), &input)
-			require.NoError(t, err)
-			require.Equal(t, "value", input["value"])
-			return tools.ToolResponse{Content: "result1", IsError: false}, nil
-		},
+	// Create type-safe tool using the new API
+	type TestInput struct {
+		Value string `json:"value" description:"Test value"`
 	}
 
+	tool1 := NewTypedToolFunc(
+		"tool1",
+		"Test tool",
+		func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
+			require.Equal(t, "value", input.Value)
+			return ToolResponse{Content: "result1", IsError: false}, nil
+		},
+	)
+
 	model := &mockLanguageModel{
 		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 			// Verify tools and tool choice
@@ -366,22 +368,20 @@ func TestAgent_Generate_ResultToolResults(t *testing.T) {
 func TestAgent_Generate_MultipleSteps(t *testing.T) {
 	t.Parallel()
 
-	tool1 := &mockTool{
-		name:        "tool1",
-		description: "Test tool",
-		parameters: map[string]any{
-			"value": map[string]any{"type": "string"},
-		},
-		required: []string{"value"},
-		executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
-			var input map[string]any
-			err := json.Unmarshal([]byte(call.Input), &input)
-			require.NoError(t, err)
-			require.Equal(t, "value", input["value"])
-			return tools.ToolResponse{Content: "result1", IsError: false}, nil
-		},
+	// Create type-safe tool using the new API
+	type TestInput struct {
+		Value string `json:"value" description:"Test value"`
 	}
 
+	tool1 := NewTypedToolFunc(
+		"tool1",
+		"Test tool",
+		func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
+			require.Equal(t, "value", input.Value)
+			return ToolResponse{Content: "result1", IsError: false}, nil
+		},
+	)
+
 	callCount := 0
 	model := &mockLanguageModel{
 		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
@@ -1289,8 +1289,8 @@ func TestToolCallRepair(t *testing.T) {
 				"value": map[string]any{"type": "string"},
 			},
 			required: []string{"value"},
-			executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
-				return tools.ToolResponse{Content: "success", IsError: false}, nil
+			executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
+				return ToolResponse{Content: "success", IsError: false}, nil
 			},
 		}
 
@@ -1379,8 +1379,8 @@ func TestToolCallRepair(t *testing.T) {
 				"value": map[string]any{"type": "string"},
 			},
 			required: []string{"value"},
-			executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
-				return tools.ToolResponse{Content: "repaired_success", IsError: false}, nil
+			executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
+				return ToolResponse{Content: "repaired_success", IsError: false}, nil
 			},
 		}
 

content.go 🔗

@@ -334,7 +334,8 @@ type ToolResultContent struct {
 	// Name of the tool that generated this result.
 	ToolName string `json:"tool_name"`
 	// Result of the tool call. This is a JSON-serializable object.
-	Result ToolResultOutputContent `json:"result"`
+	Result         ToolResultOutputContent `json:"result"`
+	ClientMetadata string                  `json:"client_metadata"` // Metadata from the client that executed the tool
 	// Whether the tool result was generated by the provider.
 	// If this flag is set to true, the tool result was generated by the provider.
 	// If this flag is not set or is false, the tool result was generated by the client.

tool.go 🔗

@@ -1,20 +1,13 @@
-// WIP NEED TO REVISIT
 package ai
 
 import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"reflect"
+	"strings"
 )
 
-// AgentTool represents a function that can be called by a language model.
-type AgentTool interface {
-	Name() string
-	Description() string
-	InputSchema() Schema
-	Execute(ctx context.Context, input json.RawMessage) (json.RawMessage, error)
-}
-
 // Schema represents a JSON schema for tool input validation.
 type Schema struct {
 	Type        string             `json:"type"`
@@ -30,205 +23,284 @@ type Schema struct {
 	MaxLength   *int               `json:"maxLength,omitempty"`
 }
 
-// BasicTool provides a basic implementation of the Tool interface
-//
-// Example usage:
-//
-//	calculator := &tools.BasicTool{
-//	    ToolName:        "calculate",
-//	    ToolDescription: "Evaluates mathematical expressions",
-//	    ToolInputSchema: tools.Schema{
-//	        Type: "object",
-//	        Properties: map[string]*tools.Schema{
-//	            "expression": {
-//	                Type:        "string",
-//	                Description: "Mathematical expression to evaluate",
-//	            },
-//	        },
-//	        Required: []string{"expression"},
-//	    },
-//	    ExecuteFunc: func(ctx context.Context, input json.RawMessage) (json.RawMessage, error) {
-//	        var req struct {
-//	            Expression string `json:"expression"`
-//	        }
-//	        if err := json.Unmarshal(input, &req); err != nil {
-//	            return nil, err
-//	        }
-//	        result := evaluateExpression(req.Expression)
-//	        return json.Marshal(map[string]any{"result": result})
-//	    },
-//	}
-type BasicTool struct {
-	ToolName        string
-	ToolDescription string
-	ToolInputSchema Schema
-	ExecuteFunc     func(context.Context, json.RawMessage) (json.RawMessage, error)
-}
-
-// Name returns the tool name.
-func (t *BasicTool) Name() string {
-	return t.ToolName
-}
-
-// Description returns the tool description.
-func (t *BasicTool) Description() string {
-	return t.ToolDescription
-}
-
-// InputSchema returns the tool input schema.
-func (t *BasicTool) InputSchema() Schema {
-	return t.ToolInputSchema
-}
-
-// Execute executes the tool with the given input.
-func (t *BasicTool) Execute(ctx context.Context, input json.RawMessage) (json.RawMessage, error) {
-	if t.ExecuteFunc == nil {
-		return nil, fmt.Errorf("tool %s has no execute function", t.ToolName)
-	}
-	return t.ExecuteFunc(ctx, input)
+// ToolInfo represents tool metadata, matching the existing pattern.
+type ToolInfo struct {
+	Name        string
+	Description string
+	Parameters  map[string]any
+	Required    []string
 }
 
-// ToolBuilder provides a fluent interface for building tools.
-type ToolBuilder struct {
-	tool *BasicTool
+// ToolCall represents a tool invocation, matching the existing pattern.
+type ToolCall struct {
+	ID    string `json:"id"`
+	Name  string `json:"name"`
+	Input string `json:"input"`
 }
 
-// NewTool creates a new tool builder.
-func NewTool(name string) *ToolBuilder {
-	return &ToolBuilder{
-		tool: &BasicTool{
-			ToolName: name,
-		},
+// ToolResponse represents the response from a tool execution, matching the existing pattern.
+type ToolResponse struct {
+	Type     string `json:"type"`
+	Content  string `json:"content"`
+	Metadata string `json:"metadata,omitempty"`
+	IsError  bool   `json:"is_error"`
+}
+
+// NewTextResponse creates a text response.
+func NewTextResponse(content string) ToolResponse {
+	return ToolResponse{
+		Type:    "text",
+		Content: content,
 	}
 }
 
-// Description sets the tool description.
-func (b *ToolBuilder) Description(desc string) *ToolBuilder {
-	b.tool.ToolDescription = desc
-	return b
+// NewTextErrorResponse creates an error response.
+func NewTextErrorResponse(content string) ToolResponse {
+	return ToolResponse{
+		Type:    "text",
+		Content: content,
+		IsError: true,
+	}
 }
 
-// InputSchema sets the tool input schema.
-func (b *ToolBuilder) InputSchema(schema Schema) *ToolBuilder {
-	b.tool.ToolInputSchema = schema
-	return b
+// WithResponseMetadata adds metadata to a response.
+func WithResponseMetadata(response ToolResponse, metadata any) ToolResponse {
+	if metadata != nil {
+		metadataBytes, err := json.Marshal(metadata)
+		if err != nil {
+			return response
+		}
+		response.Metadata = string(metadataBytes)
+	}
+	return response
 }
 
-// Execute sets the tool execution function.
-func (b *ToolBuilder) Execute(fn func(context.Context, json.RawMessage) (json.RawMessage, error)) *ToolBuilder {
-	b.tool.ExecuteFunc = fn
-	return b
+// AgentTool represents a tool that can be called by a language model.
+// This matches the existing BaseTool interface pattern.
+type AgentTool interface {
+	Info() ToolInfo
+	Run(ctx context.Context, params ToolCall) (ToolResponse, error)
 }
 
-// Build creates the final tool.
-func (b *ToolBuilder) Build() AgentTool {
-	return b.tool
+// NewTypedToolFunc creates a typed tool from a function with automatic schema generation.
+// This is the recommended way to create tools.
+func NewTypedToolFunc[TInput any](
+	name string,
+	description string,
+	fn func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error),
+) AgentTool {
+	var input TInput
+	schema := generateSchema(reflect.TypeOf(input))
+
+	return &funcToolWrapper[TInput]{
+		name:        name,
+		description: description,
+		fn:          fn,
+		schema:      schema,
+	}
 }
 
-// SchemaBuilder provides a fluent interface for building JSON schemas.
-type SchemaBuilder struct {
-	schema Schema
+// funcToolWrapper wraps a function to implement the AgentTool interface.
+type funcToolWrapper[TInput any] struct {
+	name        string
+	description string
+	fn          func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error)
+	schema      Schema
 }
 
-// NewSchema creates a new schema builder.
-func NewSchema(schemaType string) *SchemaBuilder {
-	return &SchemaBuilder{
-		schema: Schema{
-			Type: schemaType,
-		},
+func (w *funcToolWrapper[TInput]) Info() ToolInfo {
+	return ToolInfo{
+		Name:        w.name,
+		Description: w.description,
+		Parameters:  schemaToParameters(w.schema),
+		Required:    w.schema.Required,
 	}
 }
 
-// Object creates a schema builder for an object type.
-func Object() *SchemaBuilder {
-	return NewSchema("object")
-}
+func (w *funcToolWrapper[TInput]) Run(ctx context.Context, params ToolCall) (ToolResponse, error) {
+	var input TInput
+	if err := json.Unmarshal([]byte(params.Input), &input); err != nil {
+		return NewTextErrorResponse(fmt.Sprintf("invalid parameters: %s", err)), nil
+	}
 
-// String creates a schema builder for a string type.
-func String() *SchemaBuilder {
-	return NewSchema("string")
+	return w.fn(ctx, input, params)
 }
 
-// Number creates a schema builder for a number type.
-func Number() *SchemaBuilder {
-	return NewSchema("number")
-}
+// schemaToParameters converts a Schema to the parameters map format expected by ToolInfo.
+func schemaToParameters(schema Schema) map[string]any {
+	if schema.Type != "object" || schema.Properties == nil {
+		return map[string]any{}
+	}
 
-// Array creates a schema builder for an array type.
-func Array() *SchemaBuilder {
-	return NewSchema("array")
-}
+	params := make(map[string]any)
+	for name, propSchema := range schema.Properties {
+		param := map[string]any{
+			"type": propSchema.Type,
+		}
 
-// Description sets the schema description.
-func (b *SchemaBuilder) Description(desc string) *SchemaBuilder {
-	b.schema.Description = desc
-	return b
-}
+		if propSchema.Description != "" {
+			param["description"] = propSchema.Description
+		}
 
-// Properties sets the schema properties.
-func (b *SchemaBuilder) Properties(props map[string]*Schema) *SchemaBuilder {
-	b.schema.Properties = props
-	return b
-}
+		if len(propSchema.Enum) > 0 {
+			param["enum"] = propSchema.Enum
+		}
 
-// Property adds a property to the schema.
-func (b *SchemaBuilder) Property(name string, schema *Schema) *SchemaBuilder {
-	if b.schema.Properties == nil {
-		b.schema.Properties = make(map[string]*Schema)
-	}
-	b.schema.Properties[name] = schema
-	return b
-}
+		if propSchema.Format != "" {
+			param["format"] = propSchema.Format
+		}
 
-// Required marks fields as required.
-func (b *SchemaBuilder) Required(fields ...string) *SchemaBuilder {
-	b.schema.Required = append(b.schema.Required, fields...)
-	return b
-}
+		if propSchema.Minimum != nil {
+			param["minimum"] = *propSchema.Minimum
+		}
 
-// Items sets the schema for array items.
-func (b *SchemaBuilder) Items(schema *Schema) *SchemaBuilder {
-	b.schema.Items = schema
-	return b
-}
+		if propSchema.Maximum != nil {
+			param["maximum"] = *propSchema.Maximum
+		}
 
-// Enum sets allowed values for the schema.
-func (b *SchemaBuilder) Enum(values ...any) *SchemaBuilder {
-	b.schema.Enum = values
-	return b
-}
+		if propSchema.MinLength != nil {
+			param["minLength"] = *propSchema.MinLength
+		}
 
-// Format sets the string format.
-func (b *SchemaBuilder) Format(format string) *SchemaBuilder {
-	b.schema.Format = format
-	return b
-}
+		if propSchema.MaxLength != nil {
+			param["maxLength"] = *propSchema.MaxLength
+		}
 
-// Min sets the minimum value.
-func (b *SchemaBuilder) Min(minimum float64) *SchemaBuilder {
-	b.schema.Minimum = &minimum
-	return b
-}
+		if propSchema.Items != nil {
+			param["items"] = schemaToParameters(*propSchema.Items)
+		}
 
-// Max sets the maximum value.
-func (b *SchemaBuilder) Max(maximum float64) *SchemaBuilder {
-	b.schema.Maximum = &maximum
-	return b
+		params[name] = param
+	}
+
+	return params
 }
 
-// MinLength sets the minimum string length.
-func (b *SchemaBuilder) MinLength(minimum int) *SchemaBuilder {
-	b.schema.MinLength = &minimum
-	return b
+// generateSchema automatically generates a JSON schema from a Go type.
+func generateSchema(t reflect.Type) Schema {
+	return generateSchemaRecursive(t, make(map[reflect.Type]bool))
 }
 
-// MaxLength sets the maximum string length.
-func (b *SchemaBuilder) MaxLength(maximum int) *SchemaBuilder {
-	b.schema.MaxLength = &maximum
-	return b
+func generateSchemaRecursive(t reflect.Type, visited map[reflect.Type]bool) Schema {
+	// Handle pointers
+	if t.Kind() == reflect.Pointer {
+		t = t.Elem()
+	}
+
+	// Prevent infinite recursion
+	if visited[t] {
+		return Schema{Type: "object"}
+	}
+	visited[t] = true
+	defer delete(visited, t)
+
+	switch t.Kind() {
+	case reflect.String:
+		return Schema{Type: "string"}
+	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
+		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+		return Schema{Type: "integer"}
+	case reflect.Float32, reflect.Float64:
+		return Schema{Type: "number"}
+	case reflect.Bool:
+		return Schema{Type: "boolean"}
+	case reflect.Slice, reflect.Array:
+		itemSchema := generateSchemaRecursive(t.Elem(), visited)
+		return Schema{
+			Type:  "array",
+			Items: &itemSchema,
+		}
+	case reflect.Map:
+		if t.Key().Kind() == reflect.String {
+			valueSchema := generateSchemaRecursive(t.Elem(), visited)
+			return Schema{
+				Type: "object",
+				Properties: map[string]*Schema{
+					"*": &valueSchema,
+				},
+			}
+		}
+		return Schema{Type: "object"}
+	case reflect.Struct:
+		schema := Schema{
+			Type:       "object",
+			Properties: make(map[string]*Schema),
+		}
+
+		for i := range t.NumField() {
+			field := t.Field(i)
+
+			// Skip unexported fields
+			if !field.IsExported() {
+				continue
+			}
+
+			jsonTag := field.Tag.Get("json")
+			if jsonTag == "-" {
+				continue
+			}
+
+			fieldName := field.Name
+			required := true
+
+			// Parse JSON tag
+			if jsonTag != "" {
+				parts := strings.Split(jsonTag, ",")
+				if parts[0] != "" {
+					fieldName = parts[0]
+				}
+
+				// Check for omitempty
+				for _, part := range parts[1:] {
+					if part == "omitempty" {
+						required = false
+						break
+					}
+				}
+			} else {
+				// Convert field name to snake_case for JSON
+				fieldName = toSnakeCase(fieldName)
+			}
+
+			fieldSchema := generateSchemaRecursive(field.Type, visited)
+
+			// Add description from struct tag if available
+			if desc := field.Tag.Get("description"); desc != "" {
+				fieldSchema.Description = desc
+			}
+
+			// Add enum values from struct tag if available
+			if enumTag := field.Tag.Get("enum"); enumTag != "" {
+				enumValues := strings.Split(enumTag, ",")
+				fieldSchema.Enum = make([]any, len(enumValues))
+				for i, v := range enumValues {
+					fieldSchema.Enum[i] = strings.TrimSpace(v)
+				}
+			}
+
+			schema.Properties[fieldName] = &fieldSchema
+
+			if required {
+				schema.Required = append(schema.Required, fieldName)
+			}
+		}
+
+		return schema
+	case reflect.Interface:
+		return Schema{Type: "object"}
+	default:
+		return Schema{Type: "object"}
+	}
 }
 
-// Build creates the final schema.
-func (b *SchemaBuilder) Build() *Schema {
-	return &b.schema
+// toSnakeCase converts PascalCase to snake_case.
+func toSnakeCase(s string) string {
+	var result strings.Builder
+	for i, r := range s {
+		if i > 0 && r >= 'A' && r <= 'Z' {
+			result.WriteByte('_')
+		}
+		result.WriteRune(r)
+	}
+	return strings.ToLower(result.String())
 }

tool_test.go 🔗

@@ -0,0 +1,211 @@
+package ai
+
+import (
+	"context"
+	"fmt"
+	"reflect"
+	"strings"
+	"testing"
+)
+
+// Example of a simple typed tool using the function approach
+type CalculatorInput struct {
+	Expression string `json:"expression" description:"Mathematical expression to evaluate"`
+}
+
+func TestTypedToolFuncExample(t *testing.T) {
+	// Create a typed tool using the function API
+	tool := NewTypedToolFunc(
+		"calculator",
+		"Evaluates simple mathematical expressions",
+		func(ctx context.Context, input CalculatorInput, _ ToolCall) (ToolResponse, error) {
+			if input.Expression == "2+2" {
+				return NewTextResponse("4"), nil
+			}
+			return NewTextErrorResponse("unsupported expression"), nil
+		},
+	)
+
+	// Check the tool info
+	info := tool.Info()
+	if info.Name != "calculator" {
+		t.Errorf("Expected tool name 'calculator', got %s", info.Name)
+	}
+	if len(info.Required) != 1 || info.Required[0] != "expression" {
+		t.Errorf("Expected required field 'expression', got %v", info.Required)
+	}
+
+	// Test execution
+	call := ToolCall{
+		ID:    "test-1",
+		Name:  "calculator",
+		Input: `{"expression": "2+2"}`,
+	}
+
+	result, err := tool.Run(context.Background(), call)
+	if err != nil {
+		t.Errorf("Unexpected error: %v", err)
+	}
+	if result.Content != "4" {
+		t.Errorf("Expected result '4', got %s", result.Content)
+	}
+	if result.IsError {
+		t.Errorf("Expected successful result, got error")
+	}
+}
+
+func TestEnumToolExample(t *testing.T) {
+	type WeatherInput struct {
+		Location string `json:"location" description:"City name"`
+		Units    string `json:"units" enum:"celsius,fahrenheit" description:"Temperature units"`
+	}
+
+	// Create a weather tool with enum support
+	tool := NewTypedToolFunc(
+		"weather",
+		"Gets current weather for a location",
+		func(ctx context.Context, input WeatherInput, _ ToolCall) (ToolResponse, error) {
+			temp := "22°C"
+			if input.Units == "fahrenheit" {
+				temp = "72°F"
+			}
+			return NewTextResponse(fmt.Sprintf("Weather in %s: %s, sunny", input.Location, temp)), nil
+		},
+	)
+
+	// Check that the schema includes enum values
+	info := tool.Info()
+	unitsParam, ok := info.Parameters["units"].(map[string]any)
+	if !ok {
+		t.Fatal("Expected units parameter to exist")
+	}
+	enumValues, ok := unitsParam["enum"].([]any)
+	if !ok || len(enumValues) != 2 {
+		t.Errorf("Expected 2 enum values, got %v", enumValues)
+	}
+
+	// Test execution with enum value
+	call := ToolCall{
+		ID:    "test-2",
+		Name:  "weather",
+		Input: `{"location": "San Francisco", "units": "fahrenheit"}`,
+	}
+
+	result, err := tool.Run(context.Background(), call)
+	if err != nil {
+		t.Errorf("Unexpected error: %v", err)
+	}
+	if !strings.Contains(result.Content, "San Francisco") {
+		t.Errorf("Expected result to contain 'San Francisco', got %s", result.Content)
+	}
+	if !strings.Contains(result.Content, "72°F") {
+		t.Errorf("Expected result to contain '72°F', got %s", result.Content)
+	}
+}
+
+func TestEnumSupport(t *testing.T) {
+	// Test enum via struct tags
+	type WeatherInput struct {
+		Location string `json:"location" description:"City name"`
+		Units    string `json:"units" enum:"celsius,fahrenheit,kelvin" description:"Temperature units"`
+		Format   string `json:"format,omitempty" enum:"json,xml,text"`
+	}
+
+	schema := generateSchema(reflect.TypeOf(WeatherInput{}))
+
+	if schema.Type != "object" {
+		t.Errorf("Expected schema type 'object', got %s", schema.Type)
+	}
+
+	// Check units field has enum values
+	unitsSchema := schema.Properties["units"]
+	if unitsSchema == nil {
+		t.Fatal("Expected units property to exist")
+	}
+	if len(unitsSchema.Enum) != 3 {
+		t.Errorf("Expected 3 enum values for units, got %d", len(unitsSchema.Enum))
+	}
+	expectedUnits := []string{"celsius", "fahrenheit", "kelvin"}
+	for i, expected := range expectedUnits {
+		if unitsSchema.Enum[i] != expected {
+			t.Errorf("Expected enum value %s, got %v", expected, unitsSchema.Enum[i])
+		}
+	}
+
+	// Check required fields (format should not be required due to omitempty)
+	expectedRequired := []string{"location", "units"}
+	if len(schema.Required) != len(expectedRequired) {
+		t.Errorf("Expected %d required fields, got %d", len(expectedRequired), len(schema.Required))
+	}
+}
+
+func TestSchemaToParameters(t *testing.T) {
+	schema := Schema{
+		Type: "object",
+		Properties: map[string]*Schema{
+			"name": {
+				Type:        "string",
+				Description: "The name field",
+			},
+			"age": {
+				Type:    "integer",
+				Minimum: func() *float64 { v := 0.0; return &v }(),
+				Maximum: func() *float64 { v := 120.0; return &v }(),
+			},
+			"tags": {
+				Type: "array",
+				Items: &Schema{
+					Type: "string",
+				},
+			},
+			"priority": {
+				Type: "string",
+				Enum: []any{"low", "medium", "high"},
+			},
+		},
+		Required: []string{"name"},
+	}
+
+	params := schemaToParameters(schema)
+
+	// Check name parameter
+	nameParam, ok := params["name"].(map[string]any)
+	if !ok {
+		t.Fatal("Expected name parameter to exist")
+	}
+	if nameParam["type"] != "string" {
+		t.Errorf("Expected name type 'string', got %v", nameParam["type"])
+	}
+	if nameParam["description"] != "The name field" {
+		t.Errorf("Expected name description 'The name field', got %v", nameParam["description"])
+	}
+
+	// Check age parameter with min/max
+	ageParam, ok := params["age"].(map[string]any)
+	if !ok {
+		t.Fatal("Expected age parameter to exist")
+	}
+	if ageParam["type"] != "integer" {
+		t.Errorf("Expected age type 'integer', got %v", ageParam["type"])
+	}
+	if ageParam["minimum"] != 0.0 {
+		t.Errorf("Expected age minimum 0, got %v", ageParam["minimum"])
+	}
+	if ageParam["maximum"] != 120.0 {
+		t.Errorf("Expected age maximum 120, got %v", ageParam["maximum"])
+	}
+
+	// Check priority parameter with enum
+	priorityParam, ok := params["priority"].(map[string]any)
+	if !ok {
+		t.Fatal("Expected priority parameter to exist")
+	}
+	if priorityParam["type"] != "string" {
+		t.Errorf("Expected priority type 'string', got %v", priorityParam["type"])
+	}
+	enumValues, ok := priorityParam["enum"].([]any)
+	if !ok || len(enumValues) != 3 {
+		t.Errorf("Expected 3 enum values, got %v", enumValues)
+	}
+}
+

tools/common.go 🔗

@@ -0,0 +1,15 @@
+package tools
+
+import "errors"
+
+type Permission struct {
+	ToolCallID  string
+	ToolName    string
+	Path        string
+	Action      string
+	Description string
+	Params      any
+}
+type PermissionAsk = func(ask Permission) bool
+
+var ErrorPermissionDenied = errors.New("permission denied")

tools/ls.go 🔗

@@ -0,0 +1,258 @@
+package tools
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"path/filepath"
+	"strings"
+
+	"github.com/charmbracelet/crush/internal/ai"
+	"github.com/charmbracelet/crush/internal/fsext"
+)
+
+const (
+	MaxLSFiles = 1000
+	LSToolName = "ls"
+)
+
+type LSParams struct {
+	Path   string   `json:"path" description:"The path to the directory to list (defaults to current working directory)"`
+	Ignore []string `json:"ignore,omitempty" description:"List of glob patterns to ignore"`
+}
+
+type LSPermissionsParams struct {
+	Path   string   `json:"path"`
+	Ignore []string `json:"ignore"`
+}
+
+type TreeNode struct {
+	Name     string      `json:"name"`
+	Path     string      `json:"path"`
+	Type     string      `json:"type"`
+	Children []*TreeNode `json:"children,omitempty"`
+}
+
+type LSResponseMetadata struct {
+	NumberOfFiles int  `json:"number_of_files"`
+	Truncated     bool `json:"truncated"`
+}
+
+func NewLSTool(permissionAsk PermissionAsk, workingDir string) ai.AgentTool {
+	return ai.NewTypedToolFunc(
+		LSToolName,
+		`Directory listing tool that shows files and subdirectories in a tree structure, helping you explore and understand the project organization.
+
+WHEN TO USE THIS TOOL:
+- Use when you need to explore the structure of a directory
+- Helpful for understanding the organization of a project
+- Good first step when getting familiar with a new codebase
+
+HOW TO USE:
+- Provide a path to list (defaults to current working directory)
+- Optionally specify glob patterns to ignore
+- Results are displayed in a tree structure
+
+FEATURES:
+- Displays a hierarchical view of files and directories
+- Automatically skips hidden files/directories (starting with '.')
+- Skips common system directories like __pycache__
+- Can filter out files matching specific patterns
+
+LIMITATIONS:
+- Results are limited to 1000 files
+- Very large directories will be truncated
+- Does not show file sizes or permissions
+- Cannot recursively list all directories in a large project
+
+WINDOWS NOTES:
+- Hidden file detection uses Unix convention (files starting with '.')
+- Windows-specific hidden files (with hidden attribute) are not automatically skipped
+- Common Windows directories like System32, Program Files are not in default ignore list
+- Path separators are handled automatically (both / and \ work)
+
+TIPS:
+- Use Glob tool for finding files by name patterns instead of browsing
+- Use Grep tool for searching file contents
+- Combine with other tools for more effective exploration`,
+		func(ctx context.Context, params LSParams, call ai.ToolCall) (ai.ToolResponse, error) {
+			searchPath := params.Path
+			if searchPath == "" {
+				searchPath = workingDir
+			}
+
+			var err error
+			searchPath, err = fsext.Expand(searchPath)
+			if err != nil {
+				return ai.ToolResponse{}, fmt.Errorf("error expanding path: %w", err)
+			}
+
+			if !filepath.IsAbs(searchPath) {
+				searchPath = filepath.Join(workingDir, searchPath)
+			}
+
+			// Check if directory is outside working directory and request permission if needed
+			absWorkingDir, err := filepath.Abs(workingDir)
+			if err != nil {
+				return ai.ToolResponse{}, fmt.Errorf("error resolving working directory: %w", err)
+			}
+
+			absSearchPath, err := filepath.Abs(searchPath)
+			if err != nil {
+				return ai.ToolResponse{}, fmt.Errorf("error resolving search path: %w", err)
+			}
+
+			relPath, err := filepath.Rel(absWorkingDir, absSearchPath)
+			if err != nil || strings.HasPrefix(relPath, "..") {
+				granted := permissionAsk(Permission{
+					ToolCallID:  call.ID,
+					ToolName:    LSToolName,
+					Path:        absSearchPath,
+					Action:      "list",
+					Description: fmt.Sprintf("List directory outside working directory: %s", absSearchPath),
+					Params:      LSPermissionsParams(params),
+				})
+
+				if !granted {
+					return ai.ToolResponse{}, ErrorPermissionDenied
+				}
+			}
+			output, err := ListDirectoryTree(searchPath, params.Ignore)
+			if err != nil {
+				return ai.ToolResponse{}, err
+			}
+
+			// Get file count for metadata
+			files, truncated, err := fsext.ListDirectory(searchPath, params.Ignore, MaxLSFiles)
+			if err != nil {
+				return ai.ToolResponse{}, fmt.Errorf("error listing directory for metadata: %w", err)
+			}
+
+			return ai.WithResponseMetadata(
+				ai.NewTextResponse(output),
+				LSResponseMetadata{
+					NumberOfFiles: len(files),
+					Truncated:     truncated,
+				},
+			), nil
+		},
+	)
+}
+
+func ListDirectoryTree(searchPath string, ignore []string) (string, error) {
+	if _, err := os.Stat(searchPath); os.IsNotExist(err) {
+		return "", fmt.Errorf("path does not exist: %s", searchPath)
+	}
+
+	files, truncated, err := fsext.ListDirectory(searchPath, ignore, MaxLSFiles)
+	if err != nil {
+		return "", fmt.Errorf("error listing directory: %w", err)
+	}
+
+	tree := createFileTree(files, searchPath)
+	output := printTree(tree, searchPath)
+
+	if truncated {
+		output = fmt.Sprintf("There are more than %d files in the directory. Use a more specific path or use the Glob tool to find specific files. The first %d files and directories are included below:\n\n%s", MaxLSFiles, MaxLSFiles, output)
+	}
+
+	return output, nil
+}
+
+func createFileTree(sortedPaths []string, rootPath string) []*TreeNode {
+	root := []*TreeNode{}
+	pathMap := make(map[string]*TreeNode)
+
+	for _, path := range sortedPaths {
+		relativePath := strings.TrimPrefix(path, rootPath)
+		parts := strings.Split(relativePath, string(filepath.Separator))
+		currentPath := ""
+		var parentPath string
+
+		var cleanParts []string
+		for _, part := range parts {
+			if part != "" {
+				cleanParts = append(cleanParts, part)
+			}
+		}
+		parts = cleanParts
+
+		if len(parts) == 0 {
+			continue
+		}
+
+		for i, part := range parts {
+			if currentPath == "" {
+				currentPath = part
+			} else {
+				currentPath = filepath.Join(currentPath, part)
+			}
+
+			if _, exists := pathMap[currentPath]; exists {
+				parentPath = currentPath
+				continue
+			}
+
+			isLastPart := i == len(parts)-1
+			isDir := !isLastPart || strings.HasSuffix(relativePath, string(filepath.Separator))
+			nodeType := "file"
+			if isDir {
+				nodeType = "directory"
+			}
+			newNode := &TreeNode{
+				Name:     part,
+				Path:     currentPath,
+				Type:     nodeType,
+				Children: []*TreeNode{},
+			}
+
+			pathMap[currentPath] = newNode
+
+			if i > 0 && parentPath != "" {
+				if parent, ok := pathMap[parentPath]; ok {
+					parent.Children = append(parent.Children, newNode)
+				}
+			} else {
+				root = append(root, newNode)
+			}
+
+			parentPath = currentPath
+		}
+	}
+
+	return root
+}
+
+func printTree(tree []*TreeNode, rootPath string) string {
+	var result strings.Builder
+
+	result.WriteString("- ")
+	result.WriteString(rootPath)
+	if rootPath[len(rootPath)-1] != '/' {
+		result.WriteByte(filepath.Separator)
+	}
+	result.WriteByte('\n')
+
+	for _, node := range tree {
+		printNode(&result, node, 1)
+	}
+
+	return result.String()
+}
+
+func printNode(builder *strings.Builder, node *TreeNode, level int) {
+	indent := strings.Repeat("  ", level)
+
+	nodeName := node.Name
+	if node.Type == "directory" {
+		nodeName = nodeName + string(filepath.Separator)
+	}
+
+	fmt.Fprintf(builder, "%s- %s\n", indent, nodeName)
+
+	if node.Type == "directory" && len(node.Children) > 0 {
+		for _, child := range node.Children {
+			printNode(builder, child, level+1)
+		}
+	}
+}