From d8f2f21449b1eb5e1088b82330c66b86a576c802 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 22 Aug 2025 14:18:14 +0200 Subject: [PATCH] feat: add new tool definitions --- agent.go | 25 +-- agent_stream_test.go | 13 +- agent_test.go | 132 +++++++------- content.go | 3 +- tool.go | 414 +++++++++++++++++++++++++------------------ tool_test.go | 211 ++++++++++++++++++++++ tools/common.go | 15 ++ tools/ls.go | 258 +++++++++++++++++++++++++++ 8 files changed, 814 insertions(+), 257 deletions(-) create mode 100644 tool_test.go create mode 100644 tools/common.go create mode 100644 tools/ls.go diff --git a/agent.go b/agent.go index fb84e79b9752f3d5eafecd267a0e7a85e8da9650..6fe6af5fbd458b49535022dba423060df2012157 100644 --- a/agent.go +++ b/agent.go @@ -8,8 +8,6 @@ import ( "maps" "slices" "sync" - - "github.com/charmbracelet/crush/internal/llm/tools" ) type StepResult struct { @@ -100,7 +98,7 @@ type PrepareStepResult struct { type ToolCallRepairOptions struct { OriginalToolCall ToolCallContent ValidationError error - AvailableTools []tools.BaseTool + AvailableTools []AgentTool SystemPrompt string Messages []Message } @@ -123,7 +121,7 @@ type AgentSettings struct { providerOptions ProviderOptions // TODO: add support for provider tools - tools []tools.BaseTool + tools []AgentTool maxRetries *int model LanguageModel @@ -548,13 +546,13 @@ func toResponseMessages(content []Content) []Message { return messages } -func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent)) ([]ToolResultContent, error) { +func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent)) ([]ToolResultContent, error) { if len(toolCalls) == 0 { return nil, nil } // Create a map for quick tool lookup - toolMap := make(map[string]tools.BaseTool) + toolMap := make(map[string]AgentTool) for _, tool := range allTools { toolMap[tool.Info().Name] = tool } @@ -604,7 +602,7 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too } // Execute the tool - result, err := tool.Run(ctx, tools.ToolCall{ + result, err := tool.Run(ctx, ToolCall{ ID: call.ToolCallID, Name: call.ToolName, Input: call.Input, @@ -616,6 +614,7 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too Result: ToolResultOutputContentError{ Error: err, }, + ClientMetadata: result.Metadata, ProviderExecuted: false, } if toolResultCallback != nil { @@ -632,6 +631,7 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too Result: ToolResultOutputContentError{ Error: errors.New(result.Content), }, + ClientMetadata: result.Metadata, ProviderExecuted: false, } @@ -645,6 +645,7 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too Result: ToolResultOutputContentText{ Text: result.Content, }, + ClientMetadata: result.Metadata, ProviderExecuted: false, } if toolResultCallback != nil { @@ -821,7 +822,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, return agentResult, nil } -func (a *agent) prepareTools(tools []tools.BaseTool, activeTools []string, disableAllTools bool) []Tool { +func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool { var preparedTools []Tool // If explicitly disabling all tools, return no tools @@ -850,7 +851,7 @@ func (a *agent) prepareTools(tools []tools.BaseTool, activeTools []string, disab } // validateAndRepairToolCall validates a tool call and attempts repair if validation fails -func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []tools.BaseTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent { +func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent { if err := a.validateToolCall(toolCall, availableTools); err == nil { return toolCall } else { @@ -878,8 +879,8 @@ func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCall } // validateToolCall validates a tool call against available tools and their schemas -func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []tools.BaseTool) error { - var tool tools.BaseTool +func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error { + var tool AgentTool for _, t := range availableTools { if t.Info().Name == toolCall.ToolName { tool = t @@ -966,7 +967,7 @@ func WithFrequencyPenalty(penalty float64) agentOption { } } -func WithTools(tools ...tools.BaseTool) agentOption { +func WithTools(tools ...AgentTool) agentOption { return func(s *AgentSettings) { s.tools = append(s.tools, tools...) } diff --git a/agent_stream_test.go b/agent_stream_test.go index 3ddae7431a26f6c721fdd11308ecfb31c83c0b68..7b4aec4d4943560c8903e7d7aad9983444093d6c 100644 --- a/agent_stream_test.go +++ b/agent_stream_test.go @@ -6,7 +6,6 @@ import ( "fmt" "testing" - "github.com/charmbracelet/crush/internal/llm/tools" "github.com/stretchr/testify/require" ) @@ -14,8 +13,8 @@ import ( type EchoTool struct{} // Info returns the tool information -func (e *EchoTool) Info() tools.ToolInfo { - return tools.ToolInfo{ +func (e *EchoTool) Info() ToolInfo { + return ToolInfo{ Name: "echo", Description: "Echo back the provided message", Parameters: map[string]any{ @@ -29,20 +28,20 @@ func (e *EchoTool) Info() tools.ToolInfo { } // Run executes the echo tool -func (e *EchoTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) { +func (e *EchoTool) Run(ctx context.Context, params ToolCall) (ToolResponse, error) { var input struct { Message string `json:"message"` } if err := json.Unmarshal([]byte(params.Input), &input); err != nil { - return tools.NewTextErrorResponse("Invalid input: " + err.Error()), nil + return NewTextErrorResponse("Invalid input: " + err.Error()), nil } if input.Message == "" { - return tools.NewTextErrorResponse("Message cannot be empty"), nil + return NewTextErrorResponse("Message cannot be empty"), nil } - return tools.NewTextResponse("Echo: " + input.Message), nil + return NewTextResponse("Echo: " + input.Message), nil } // TestStreamingAgentCallbacks tests that all streaming callbacks are called correctly diff --git a/agent_test.go b/agent_test.go index 75b4f8948acadc694bc35a142cfa08994a0bf35e..41632913a7043e2985cd19f3f1a45b4257365691 100644 --- a/agent_test.go +++ b/agent_test.go @@ -7,7 +7,6 @@ import ( "fmt" "testing" - "github.com/charmbracelet/crush/internal/llm/tools" "github.com/stretchr/testify/require" ) @@ -17,11 +16,11 @@ type mockTool struct { description string parameters map[string]any required []string - executeFunc func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) + executeFunc func(ctx context.Context, call ToolCall) (ToolResponse, error) } -func (m *mockTool) Info() tools.ToolInfo { - return tools.ToolInfo{ +func (m *mockTool) Info() ToolInfo { + return ToolInfo{ Name: m.name, Description: m.description, Parameters: m.parameters, @@ -29,11 +28,11 @@ func (m *mockTool) Info() tools.ToolInfo { } } -func (m *mockTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) { +func (m *mockTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { if m.executeFunc != nil { return m.executeFunc(ctx, call) } - return tools.ToolResponse{Content: "mock result", IsError: false}, nil + return ToolResponse{Content: "mock result", IsError: false}, nil } // Mock language model for testing @@ -78,22 +77,20 @@ func (m *mockLanguageModel) Model() string { func TestAgent_Generate_ResultContent_AllTypes(t *testing.T) { t.Parallel() - tool1 := &mockTool{ - name: "tool1", - description: "Test tool", - parameters: map[string]any{ - "value": map[string]any{"type": "string"}, - }, - required: []string{"value"}, - executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) { - var input map[string]any - err := json.Unmarshal([]byte(call.Input), &input) - require.NoError(t, err) - require.Equal(t, "value", input["value"]) - return tools.ToolResponse{Content: "result1", IsError: false}, nil - }, + // Create a type-safe tool using the new API + type TestInput struct { + Value string `json:"value" description:"Test value"` } + tool1 := NewTypedToolFunc( + "tool1", + "Test tool", + func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) { + require.Equal(t, "value", input.Value) + return ToolResponse{Content: "result1", IsError: false}, nil + }, + ) + model := &mockLanguageModel{ generateFunc: func(ctx context.Context, call Call) (*Response, error) { return &Response{ @@ -213,24 +210,31 @@ func TestAgent_Generate_ResultText(t *testing.T) { func TestAgent_Generate_ResultToolCalls(t *testing.T) { t.Parallel() - tool1 := &mockTool{ - name: "tool1", - description: "Test tool 1", - parameters: map[string]any{ - "value": map[string]any{"type": "string"}, - }, - required: []string{"value"}, + // Create type-safe tools using the new API + type Tool1Input struct { + Value string `json:"value" description:"Test value"` } - tool2 := &mockTool{ - name: "tool2", - description: "Test tool 2", - parameters: map[string]any{ - "somethingElse": map[string]any{"type": "string"}, - }, - required: []string{"somethingElse"}, + type Tool2Input struct { + SomethingElse string `json:"somethingElse" description:"Another test value"` } + tool1 := NewTypedToolFunc( + "tool1", + "Test tool 1", + func(ctx context.Context, input Tool1Input, _ ToolCall) (ToolResponse, error) { + return ToolResponse{Content: "result1", IsError: false}, nil + }, + ) + + tool2 := NewTypedToolFunc( + "tool2", + "Test tool 2", + func(ctx context.Context, input Tool2Input, _ ToolCall) (ToolResponse, error) { + return ToolResponse{Content: "result2", IsError: false}, nil + }, + ) + model := &mockLanguageModel{ generateFunc: func(ctx context.Context, call Call) (*Response, error) { // Verify tools are passed correctly @@ -291,22 +295,20 @@ func TestAgent_Generate_ResultToolCalls(t *testing.T) { func TestAgent_Generate_ResultToolResults(t *testing.T) { t.Parallel() - tool1 := &mockTool{ - name: "tool1", - description: "Test tool", - parameters: map[string]any{ - "value": map[string]any{"type": "string"}, - }, - required: []string{"value"}, - executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) { - var input map[string]any - err := json.Unmarshal([]byte(call.Input), &input) - require.NoError(t, err) - require.Equal(t, "value", input["value"]) - return tools.ToolResponse{Content: "result1", IsError: false}, nil - }, + // Create type-safe tool using the new API + type TestInput struct { + Value string `json:"value" description:"Test value"` } + tool1 := NewTypedToolFunc( + "tool1", + "Test tool", + func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) { + require.Equal(t, "value", input.Value) + return ToolResponse{Content: "result1", IsError: false}, nil + }, + ) + model := &mockLanguageModel{ generateFunc: func(ctx context.Context, call Call) (*Response, error) { // Verify tools and tool choice @@ -366,22 +368,20 @@ func TestAgent_Generate_ResultToolResults(t *testing.T) { func TestAgent_Generate_MultipleSteps(t *testing.T) { t.Parallel() - tool1 := &mockTool{ - name: "tool1", - description: "Test tool", - parameters: map[string]any{ - "value": map[string]any{"type": "string"}, - }, - required: []string{"value"}, - executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) { - var input map[string]any - err := json.Unmarshal([]byte(call.Input), &input) - require.NoError(t, err) - require.Equal(t, "value", input["value"]) - return tools.ToolResponse{Content: "result1", IsError: false}, nil - }, + // Create type-safe tool using the new API + type TestInput struct { + Value string `json:"value" description:"Test value"` } + tool1 := NewTypedToolFunc( + "tool1", + "Test tool", + func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) { + require.Equal(t, "value", input.Value) + return ToolResponse{Content: "result1", IsError: false}, nil + }, + ) + callCount := 0 model := &mockLanguageModel{ generateFunc: func(ctx context.Context, call Call) (*Response, error) { @@ -1289,8 +1289,8 @@ func TestToolCallRepair(t *testing.T) { "value": map[string]any{"type": "string"}, }, required: []string{"value"}, - executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) { - return tools.ToolResponse{Content: "success", IsError: false}, nil + executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) { + return ToolResponse{Content: "success", IsError: false}, nil }, } @@ -1379,8 +1379,8 @@ func TestToolCallRepair(t *testing.T) { "value": map[string]any{"type": "string"}, }, required: []string{"value"}, - executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) { - return tools.ToolResponse{Content: "repaired_success", IsError: false}, nil + executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) { + return ToolResponse{Content: "repaired_success", IsError: false}, nil }, } diff --git a/content.go b/content.go index 3dbd77a8323f31a4f8ce92a6ac04ddb75a3fda72..eec2ec539dedc5c55e762ceff8e11721320a0749 100644 --- a/content.go +++ b/content.go @@ -334,7 +334,8 @@ type ToolResultContent struct { // Name of the tool that generated this result. ToolName string `json:"tool_name"` // Result of the tool call. This is a JSON-serializable object. - Result ToolResultOutputContent `json:"result"` + Result ToolResultOutputContent `json:"result"` + ClientMetadata string `json:"client_metadata"` // Metadata from the client that executed the tool // Whether the tool result was generated by the provider. // If this flag is set to true, the tool result was generated by the provider. // If this flag is not set or is false, the tool result was generated by the client. diff --git a/tool.go b/tool.go index b0c7c518c0852cf78f8d3687a284c49a552ff220..8b4d379cc1c2ff81bffc667012a436bb3a514ea6 100644 --- a/tool.go +++ b/tool.go @@ -1,20 +1,13 @@ -// WIP NEED TO REVISIT package ai import ( "context" "encoding/json" "fmt" + "reflect" + "strings" ) -// AgentTool represents a function that can be called by a language model. -type AgentTool interface { - Name() string - Description() string - InputSchema() Schema - Execute(ctx context.Context, input json.RawMessage) (json.RawMessage, error) -} - // Schema represents a JSON schema for tool input validation. type Schema struct { Type string `json:"type"` @@ -30,205 +23,284 @@ type Schema struct { MaxLength *int `json:"maxLength,omitempty"` } -// BasicTool provides a basic implementation of the Tool interface -// -// Example usage: -// -// calculator := &tools.BasicTool{ -// ToolName: "calculate", -// ToolDescription: "Evaluates mathematical expressions", -// ToolInputSchema: tools.Schema{ -// Type: "object", -// Properties: map[string]*tools.Schema{ -// "expression": { -// Type: "string", -// Description: "Mathematical expression to evaluate", -// }, -// }, -// Required: []string{"expression"}, -// }, -// ExecuteFunc: func(ctx context.Context, input json.RawMessage) (json.RawMessage, error) { -// var req struct { -// Expression string `json:"expression"` -// } -// if err := json.Unmarshal(input, &req); err != nil { -// return nil, err -// } -// result := evaluateExpression(req.Expression) -// return json.Marshal(map[string]any{"result": result}) -// }, -// } -type BasicTool struct { - ToolName string - ToolDescription string - ToolInputSchema Schema - ExecuteFunc func(context.Context, json.RawMessage) (json.RawMessage, error) -} - -// Name returns the tool name. -func (t *BasicTool) Name() string { - return t.ToolName -} - -// Description returns the tool description. -func (t *BasicTool) Description() string { - return t.ToolDescription -} - -// InputSchema returns the tool input schema. -func (t *BasicTool) InputSchema() Schema { - return t.ToolInputSchema -} - -// Execute executes the tool with the given input. -func (t *BasicTool) Execute(ctx context.Context, input json.RawMessage) (json.RawMessage, error) { - if t.ExecuteFunc == nil { - return nil, fmt.Errorf("tool %s has no execute function", t.ToolName) - } - return t.ExecuteFunc(ctx, input) +// ToolInfo represents tool metadata, matching the existing pattern. +type ToolInfo struct { + Name string + Description string + Parameters map[string]any + Required []string } -// ToolBuilder provides a fluent interface for building tools. -type ToolBuilder struct { - tool *BasicTool +// ToolCall represents a tool invocation, matching the existing pattern. +type ToolCall struct { + ID string `json:"id"` + Name string `json:"name"` + Input string `json:"input"` } -// NewTool creates a new tool builder. -func NewTool(name string) *ToolBuilder { - return &ToolBuilder{ - tool: &BasicTool{ - ToolName: name, - }, +// ToolResponse represents the response from a tool execution, matching the existing pattern. +type ToolResponse struct { + Type string `json:"type"` + Content string `json:"content"` + Metadata string `json:"metadata,omitempty"` + IsError bool `json:"is_error"` +} + +// NewTextResponse creates a text response. +func NewTextResponse(content string) ToolResponse { + return ToolResponse{ + Type: "text", + Content: content, } } -// Description sets the tool description. -func (b *ToolBuilder) Description(desc string) *ToolBuilder { - b.tool.ToolDescription = desc - return b +// NewTextErrorResponse creates an error response. +func NewTextErrorResponse(content string) ToolResponse { + return ToolResponse{ + Type: "text", + Content: content, + IsError: true, + } } -// InputSchema sets the tool input schema. -func (b *ToolBuilder) InputSchema(schema Schema) *ToolBuilder { - b.tool.ToolInputSchema = schema - return b +// WithResponseMetadata adds metadata to a response. +func WithResponseMetadata(response ToolResponse, metadata any) ToolResponse { + if metadata != nil { + metadataBytes, err := json.Marshal(metadata) + if err != nil { + return response + } + response.Metadata = string(metadataBytes) + } + return response } -// Execute sets the tool execution function. -func (b *ToolBuilder) Execute(fn func(context.Context, json.RawMessage) (json.RawMessage, error)) *ToolBuilder { - b.tool.ExecuteFunc = fn - return b +// AgentTool represents a tool that can be called by a language model. +// This matches the existing BaseTool interface pattern. +type AgentTool interface { + Info() ToolInfo + Run(ctx context.Context, params ToolCall) (ToolResponse, error) } -// Build creates the final tool. -func (b *ToolBuilder) Build() AgentTool { - return b.tool +// NewTypedToolFunc creates a typed tool from a function with automatic schema generation. +// This is the recommended way to create tools. +func NewTypedToolFunc[TInput any]( + name string, + description string, + fn func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error), +) AgentTool { + var input TInput + schema := generateSchema(reflect.TypeOf(input)) + + return &funcToolWrapper[TInput]{ + name: name, + description: description, + fn: fn, + schema: schema, + } } -// SchemaBuilder provides a fluent interface for building JSON schemas. -type SchemaBuilder struct { - schema Schema +// funcToolWrapper wraps a function to implement the AgentTool interface. +type funcToolWrapper[TInput any] struct { + name string + description string + fn func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error) + schema Schema } -// NewSchema creates a new schema builder. -func NewSchema(schemaType string) *SchemaBuilder { - return &SchemaBuilder{ - schema: Schema{ - Type: schemaType, - }, +func (w *funcToolWrapper[TInput]) Info() ToolInfo { + return ToolInfo{ + Name: w.name, + Description: w.description, + Parameters: schemaToParameters(w.schema), + Required: w.schema.Required, } } -// Object creates a schema builder for an object type. -func Object() *SchemaBuilder { - return NewSchema("object") -} +func (w *funcToolWrapper[TInput]) Run(ctx context.Context, params ToolCall) (ToolResponse, error) { + var input TInput + if err := json.Unmarshal([]byte(params.Input), &input); err != nil { + return NewTextErrorResponse(fmt.Sprintf("invalid parameters: %s", err)), nil + } -// String creates a schema builder for a string type. -func String() *SchemaBuilder { - return NewSchema("string") + return w.fn(ctx, input, params) } -// Number creates a schema builder for a number type. -func Number() *SchemaBuilder { - return NewSchema("number") -} +// schemaToParameters converts a Schema to the parameters map format expected by ToolInfo. +func schemaToParameters(schema Schema) map[string]any { + if schema.Type != "object" || schema.Properties == nil { + return map[string]any{} + } -// Array creates a schema builder for an array type. -func Array() *SchemaBuilder { - return NewSchema("array") -} + params := make(map[string]any) + for name, propSchema := range schema.Properties { + param := map[string]any{ + "type": propSchema.Type, + } -// Description sets the schema description. -func (b *SchemaBuilder) Description(desc string) *SchemaBuilder { - b.schema.Description = desc - return b -} + if propSchema.Description != "" { + param["description"] = propSchema.Description + } -// Properties sets the schema properties. -func (b *SchemaBuilder) Properties(props map[string]*Schema) *SchemaBuilder { - b.schema.Properties = props - return b -} + if len(propSchema.Enum) > 0 { + param["enum"] = propSchema.Enum + } -// Property adds a property to the schema. -func (b *SchemaBuilder) Property(name string, schema *Schema) *SchemaBuilder { - if b.schema.Properties == nil { - b.schema.Properties = make(map[string]*Schema) - } - b.schema.Properties[name] = schema - return b -} + if propSchema.Format != "" { + param["format"] = propSchema.Format + } -// Required marks fields as required. -func (b *SchemaBuilder) Required(fields ...string) *SchemaBuilder { - b.schema.Required = append(b.schema.Required, fields...) - return b -} + if propSchema.Minimum != nil { + param["minimum"] = *propSchema.Minimum + } -// Items sets the schema for array items. -func (b *SchemaBuilder) Items(schema *Schema) *SchemaBuilder { - b.schema.Items = schema - return b -} + if propSchema.Maximum != nil { + param["maximum"] = *propSchema.Maximum + } -// Enum sets allowed values for the schema. -func (b *SchemaBuilder) Enum(values ...any) *SchemaBuilder { - b.schema.Enum = values - return b -} + if propSchema.MinLength != nil { + param["minLength"] = *propSchema.MinLength + } -// Format sets the string format. -func (b *SchemaBuilder) Format(format string) *SchemaBuilder { - b.schema.Format = format - return b -} + if propSchema.MaxLength != nil { + param["maxLength"] = *propSchema.MaxLength + } -// Min sets the minimum value. -func (b *SchemaBuilder) Min(minimum float64) *SchemaBuilder { - b.schema.Minimum = &minimum - return b -} + if propSchema.Items != nil { + param["items"] = schemaToParameters(*propSchema.Items) + } -// Max sets the maximum value. -func (b *SchemaBuilder) Max(maximum float64) *SchemaBuilder { - b.schema.Maximum = &maximum - return b + params[name] = param + } + + return params } -// MinLength sets the minimum string length. -func (b *SchemaBuilder) MinLength(minimum int) *SchemaBuilder { - b.schema.MinLength = &minimum - return b +// generateSchema automatically generates a JSON schema from a Go type. +func generateSchema(t reflect.Type) Schema { + return generateSchemaRecursive(t, make(map[reflect.Type]bool)) } -// MaxLength sets the maximum string length. -func (b *SchemaBuilder) MaxLength(maximum int) *SchemaBuilder { - b.schema.MaxLength = &maximum - return b +func generateSchemaRecursive(t reflect.Type, visited map[reflect.Type]bool) Schema { + // Handle pointers + if t.Kind() == reflect.Pointer { + t = t.Elem() + } + + // Prevent infinite recursion + if visited[t] { + return Schema{Type: "object"} + } + visited[t] = true + defer delete(visited, t) + + switch t.Kind() { + case reflect.String: + return Schema{Type: "string"} + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return Schema{Type: "integer"} + case reflect.Float32, reflect.Float64: + return Schema{Type: "number"} + case reflect.Bool: + return Schema{Type: "boolean"} + case reflect.Slice, reflect.Array: + itemSchema := generateSchemaRecursive(t.Elem(), visited) + return Schema{ + Type: "array", + Items: &itemSchema, + } + case reflect.Map: + if t.Key().Kind() == reflect.String { + valueSchema := generateSchemaRecursive(t.Elem(), visited) + return Schema{ + Type: "object", + Properties: map[string]*Schema{ + "*": &valueSchema, + }, + } + } + return Schema{Type: "object"} + case reflect.Struct: + schema := Schema{ + Type: "object", + Properties: make(map[string]*Schema), + } + + for i := range t.NumField() { + field := t.Field(i) + + // Skip unexported fields + if !field.IsExported() { + continue + } + + jsonTag := field.Tag.Get("json") + if jsonTag == "-" { + continue + } + + fieldName := field.Name + required := true + + // Parse JSON tag + if jsonTag != "" { + parts := strings.Split(jsonTag, ",") + if parts[0] != "" { + fieldName = parts[0] + } + + // Check for omitempty + for _, part := range parts[1:] { + if part == "omitempty" { + required = false + break + } + } + } else { + // Convert field name to snake_case for JSON + fieldName = toSnakeCase(fieldName) + } + + fieldSchema := generateSchemaRecursive(field.Type, visited) + + // Add description from struct tag if available + if desc := field.Tag.Get("description"); desc != "" { + fieldSchema.Description = desc + } + + // Add enum values from struct tag if available + if enumTag := field.Tag.Get("enum"); enumTag != "" { + enumValues := strings.Split(enumTag, ",") + fieldSchema.Enum = make([]any, len(enumValues)) + for i, v := range enumValues { + fieldSchema.Enum[i] = strings.TrimSpace(v) + } + } + + schema.Properties[fieldName] = &fieldSchema + + if required { + schema.Required = append(schema.Required, fieldName) + } + } + + return schema + case reflect.Interface: + return Schema{Type: "object"} + default: + return Schema{Type: "object"} + } } -// Build creates the final schema. -func (b *SchemaBuilder) Build() *Schema { - return &b.schema +// toSnakeCase converts PascalCase to snake_case. +func toSnakeCase(s string) string { + var result strings.Builder + for i, r := range s { + if i > 0 && r >= 'A' && r <= 'Z' { + result.WriteByte('_') + } + result.WriteRune(r) + } + return strings.ToLower(result.String()) } diff --git a/tool_test.go b/tool_test.go new file mode 100644 index 0000000000000000000000000000000000000000..ea1ab76673eb018d416175f716b0b301c63889c8 --- /dev/null +++ b/tool_test.go @@ -0,0 +1,211 @@ +package ai + +import ( + "context" + "fmt" + "reflect" + "strings" + "testing" +) + +// Example of a simple typed tool using the function approach +type CalculatorInput struct { + Expression string `json:"expression" description:"Mathematical expression to evaluate"` +} + +func TestTypedToolFuncExample(t *testing.T) { + // Create a typed tool using the function API + tool := NewTypedToolFunc( + "calculator", + "Evaluates simple mathematical expressions", + func(ctx context.Context, input CalculatorInput, _ ToolCall) (ToolResponse, error) { + if input.Expression == "2+2" { + return NewTextResponse("4"), nil + } + return NewTextErrorResponse("unsupported expression"), nil + }, + ) + + // Check the tool info + info := tool.Info() + if info.Name != "calculator" { + t.Errorf("Expected tool name 'calculator', got %s", info.Name) + } + if len(info.Required) != 1 || info.Required[0] != "expression" { + t.Errorf("Expected required field 'expression', got %v", info.Required) + } + + // Test execution + call := ToolCall{ + ID: "test-1", + Name: "calculator", + Input: `{"expression": "2+2"}`, + } + + result, err := tool.Run(context.Background(), call) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if result.Content != "4" { + t.Errorf("Expected result '4', got %s", result.Content) + } + if result.IsError { + t.Errorf("Expected successful result, got error") + } +} + +func TestEnumToolExample(t *testing.T) { + type WeatherInput struct { + Location string `json:"location" description:"City name"` + Units string `json:"units" enum:"celsius,fahrenheit" description:"Temperature units"` + } + + // Create a weather tool with enum support + tool := NewTypedToolFunc( + "weather", + "Gets current weather for a location", + func(ctx context.Context, input WeatherInput, _ ToolCall) (ToolResponse, error) { + temp := "22°C" + if input.Units == "fahrenheit" { + temp = "72°F" + } + return NewTextResponse(fmt.Sprintf("Weather in %s: %s, sunny", input.Location, temp)), nil + }, + ) + + // Check that the schema includes enum values + info := tool.Info() + unitsParam, ok := info.Parameters["units"].(map[string]any) + if !ok { + t.Fatal("Expected units parameter to exist") + } + enumValues, ok := unitsParam["enum"].([]any) + if !ok || len(enumValues) != 2 { + t.Errorf("Expected 2 enum values, got %v", enumValues) + } + + // Test execution with enum value + call := ToolCall{ + ID: "test-2", + Name: "weather", + Input: `{"location": "San Francisco", "units": "fahrenheit"}`, + } + + result, err := tool.Run(context.Background(), call) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if !strings.Contains(result.Content, "San Francisco") { + t.Errorf("Expected result to contain 'San Francisco', got %s", result.Content) + } + if !strings.Contains(result.Content, "72°F") { + t.Errorf("Expected result to contain '72°F', got %s", result.Content) + } +} + +func TestEnumSupport(t *testing.T) { + // Test enum via struct tags + type WeatherInput struct { + Location string `json:"location" description:"City name"` + Units string `json:"units" enum:"celsius,fahrenheit,kelvin" description:"Temperature units"` + Format string `json:"format,omitempty" enum:"json,xml,text"` + } + + schema := generateSchema(reflect.TypeOf(WeatherInput{})) + + if schema.Type != "object" { + t.Errorf("Expected schema type 'object', got %s", schema.Type) + } + + // Check units field has enum values + unitsSchema := schema.Properties["units"] + if unitsSchema == nil { + t.Fatal("Expected units property to exist") + } + if len(unitsSchema.Enum) != 3 { + t.Errorf("Expected 3 enum values for units, got %d", len(unitsSchema.Enum)) + } + expectedUnits := []string{"celsius", "fahrenheit", "kelvin"} + for i, expected := range expectedUnits { + if unitsSchema.Enum[i] != expected { + t.Errorf("Expected enum value %s, got %v", expected, unitsSchema.Enum[i]) + } + } + + // Check required fields (format should not be required due to omitempty) + expectedRequired := []string{"location", "units"} + if len(schema.Required) != len(expectedRequired) { + t.Errorf("Expected %d required fields, got %d", len(expectedRequired), len(schema.Required)) + } +} + +func TestSchemaToParameters(t *testing.T) { + schema := Schema{ + Type: "object", + Properties: map[string]*Schema{ + "name": { + Type: "string", + Description: "The name field", + }, + "age": { + Type: "integer", + Minimum: func() *float64 { v := 0.0; return &v }(), + Maximum: func() *float64 { v := 120.0; return &v }(), + }, + "tags": { + Type: "array", + Items: &Schema{ + Type: "string", + }, + }, + "priority": { + Type: "string", + Enum: []any{"low", "medium", "high"}, + }, + }, + Required: []string{"name"}, + } + + params := schemaToParameters(schema) + + // Check name parameter + nameParam, ok := params["name"].(map[string]any) + if !ok { + t.Fatal("Expected name parameter to exist") + } + if nameParam["type"] != "string" { + t.Errorf("Expected name type 'string', got %v", nameParam["type"]) + } + if nameParam["description"] != "The name field" { + t.Errorf("Expected name description 'The name field', got %v", nameParam["description"]) + } + + // Check age parameter with min/max + ageParam, ok := params["age"].(map[string]any) + if !ok { + t.Fatal("Expected age parameter to exist") + } + if ageParam["type"] != "integer" { + t.Errorf("Expected age type 'integer', got %v", ageParam["type"]) + } + if ageParam["minimum"] != 0.0 { + t.Errorf("Expected age minimum 0, got %v", ageParam["minimum"]) + } + if ageParam["maximum"] != 120.0 { + t.Errorf("Expected age maximum 120, got %v", ageParam["maximum"]) + } + + // Check priority parameter with enum + priorityParam, ok := params["priority"].(map[string]any) + if !ok { + t.Fatal("Expected priority parameter to exist") + } + if priorityParam["type"] != "string" { + t.Errorf("Expected priority type 'string', got %v", priorityParam["type"]) + } + enumValues, ok := priorityParam["enum"].([]any) + if !ok || len(enumValues) != 3 { + t.Errorf("Expected 3 enum values, got %v", enumValues) + } +} + diff --git a/tools/common.go b/tools/common.go new file mode 100644 index 0000000000000000000000000000000000000000..177469b4ac2c62f1e1d522e061c643f8d47b57ab --- /dev/null +++ b/tools/common.go @@ -0,0 +1,15 @@ +package tools + +import "errors" + +type Permission struct { + ToolCallID string + ToolName string + Path string + Action string + Description string + Params any +} +type PermissionAsk = func(ask Permission) bool + +var ErrorPermissionDenied = errors.New("permission denied") diff --git a/tools/ls.go b/tools/ls.go new file mode 100644 index 0000000000000000000000000000000000000000..ebea64fa450c69e4e008d597104d8541f842516d --- /dev/null +++ b/tools/ls.go @@ -0,0 +1,258 @@ +package tools + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/charmbracelet/crush/internal/ai" + "github.com/charmbracelet/crush/internal/fsext" +) + +const ( + MaxLSFiles = 1000 + LSToolName = "ls" +) + +type LSParams struct { + Path string `json:"path" description:"The path to the directory to list (defaults to current working directory)"` + Ignore []string `json:"ignore,omitempty" description:"List of glob patterns to ignore"` +} + +type LSPermissionsParams struct { + Path string `json:"path"` + Ignore []string `json:"ignore"` +} + +type TreeNode struct { + Name string `json:"name"` + Path string `json:"path"` + Type string `json:"type"` + Children []*TreeNode `json:"children,omitempty"` +} + +type LSResponseMetadata struct { + NumberOfFiles int `json:"number_of_files"` + Truncated bool `json:"truncated"` +} + +func NewLSTool(permissionAsk PermissionAsk, workingDir string) ai.AgentTool { + return ai.NewTypedToolFunc( + LSToolName, + `Directory listing tool that shows files and subdirectories in a tree structure, helping you explore and understand the project organization. + +WHEN TO USE THIS TOOL: +- Use when you need to explore the structure of a directory +- Helpful for understanding the organization of a project +- Good first step when getting familiar with a new codebase + +HOW TO USE: +- Provide a path to list (defaults to current working directory) +- Optionally specify glob patterns to ignore +- Results are displayed in a tree structure + +FEATURES: +- Displays a hierarchical view of files and directories +- Automatically skips hidden files/directories (starting with '.') +- Skips common system directories like __pycache__ +- Can filter out files matching specific patterns + +LIMITATIONS: +- Results are limited to 1000 files +- Very large directories will be truncated +- Does not show file sizes or permissions +- Cannot recursively list all directories in a large project + +WINDOWS NOTES: +- Hidden file detection uses Unix convention (files starting with '.') +- Windows-specific hidden files (with hidden attribute) are not automatically skipped +- Common Windows directories like System32, Program Files are not in default ignore list +- Path separators are handled automatically (both / and \ work) + +TIPS: +- Use Glob tool for finding files by name patterns instead of browsing +- Use Grep tool for searching file contents +- Combine with other tools for more effective exploration`, + func(ctx context.Context, params LSParams, call ai.ToolCall) (ai.ToolResponse, error) { + searchPath := params.Path + if searchPath == "" { + searchPath = workingDir + } + + var err error + searchPath, err = fsext.Expand(searchPath) + if err != nil { + return ai.ToolResponse{}, fmt.Errorf("error expanding path: %w", err) + } + + if !filepath.IsAbs(searchPath) { + searchPath = filepath.Join(workingDir, searchPath) + } + + // Check if directory is outside working directory and request permission if needed + absWorkingDir, err := filepath.Abs(workingDir) + if err != nil { + return ai.ToolResponse{}, fmt.Errorf("error resolving working directory: %w", err) + } + + absSearchPath, err := filepath.Abs(searchPath) + if err != nil { + return ai.ToolResponse{}, fmt.Errorf("error resolving search path: %w", err) + } + + relPath, err := filepath.Rel(absWorkingDir, absSearchPath) + if err != nil || strings.HasPrefix(relPath, "..") { + granted := permissionAsk(Permission{ + ToolCallID: call.ID, + ToolName: LSToolName, + Path: absSearchPath, + Action: "list", + Description: fmt.Sprintf("List directory outside working directory: %s", absSearchPath), + Params: LSPermissionsParams(params), + }) + + if !granted { + return ai.ToolResponse{}, ErrorPermissionDenied + } + } + output, err := ListDirectoryTree(searchPath, params.Ignore) + if err != nil { + return ai.ToolResponse{}, err + } + + // Get file count for metadata + files, truncated, err := fsext.ListDirectory(searchPath, params.Ignore, MaxLSFiles) + if err != nil { + return ai.ToolResponse{}, fmt.Errorf("error listing directory for metadata: %w", err) + } + + return ai.WithResponseMetadata( + ai.NewTextResponse(output), + LSResponseMetadata{ + NumberOfFiles: len(files), + Truncated: truncated, + }, + ), nil + }, + ) +} + +func ListDirectoryTree(searchPath string, ignore []string) (string, error) { + if _, err := os.Stat(searchPath); os.IsNotExist(err) { + return "", fmt.Errorf("path does not exist: %s", searchPath) + } + + files, truncated, err := fsext.ListDirectory(searchPath, ignore, MaxLSFiles) + if err != nil { + return "", fmt.Errorf("error listing directory: %w", err) + } + + tree := createFileTree(files, searchPath) + output := printTree(tree, searchPath) + + if truncated { + output = fmt.Sprintf("There are more than %d files in the directory. Use a more specific path or use the Glob tool to find specific files. The first %d files and directories are included below:\n\n%s", MaxLSFiles, MaxLSFiles, output) + } + + return output, nil +} + +func createFileTree(sortedPaths []string, rootPath string) []*TreeNode { + root := []*TreeNode{} + pathMap := make(map[string]*TreeNode) + + for _, path := range sortedPaths { + relativePath := strings.TrimPrefix(path, rootPath) + parts := strings.Split(relativePath, string(filepath.Separator)) + currentPath := "" + var parentPath string + + var cleanParts []string + for _, part := range parts { + if part != "" { + cleanParts = append(cleanParts, part) + } + } + parts = cleanParts + + if len(parts) == 0 { + continue + } + + for i, part := range parts { + if currentPath == "" { + currentPath = part + } else { + currentPath = filepath.Join(currentPath, part) + } + + if _, exists := pathMap[currentPath]; exists { + parentPath = currentPath + continue + } + + isLastPart := i == len(parts)-1 + isDir := !isLastPart || strings.HasSuffix(relativePath, string(filepath.Separator)) + nodeType := "file" + if isDir { + nodeType = "directory" + } + newNode := &TreeNode{ + Name: part, + Path: currentPath, + Type: nodeType, + Children: []*TreeNode{}, + } + + pathMap[currentPath] = newNode + + if i > 0 && parentPath != "" { + if parent, ok := pathMap[parentPath]; ok { + parent.Children = append(parent.Children, newNode) + } + } else { + root = append(root, newNode) + } + + parentPath = currentPath + } + } + + return root +} + +func printTree(tree []*TreeNode, rootPath string) string { + var result strings.Builder + + result.WriteString("- ") + result.WriteString(rootPath) + if rootPath[len(rootPath)-1] != '/' { + result.WriteByte(filepath.Separator) + } + result.WriteByte('\n') + + for _, node := range tree { + printNode(&result, node, 1) + } + + return result.String() +} + +func printNode(builder *strings.Builder, node *TreeNode, level int) { + indent := strings.Repeat(" ", level) + + nodeName := node.Name + if node.Type == "directory" { + nodeName = nodeName + string(filepath.Separator) + } + + fmt.Fprintf(builder, "%s- %s\n", indent, nodeName) + + if node.Type == "directory" && len(node.Children) > 0 { + for _, child := range node.Children { + printNode(builder, child, level+1) + } + } +}