diff --git a/cmd/root.go b/cmd/root.go index 3ed809328f9670be781a8651737f5321ada8e340..c4e99985ac8e77f5b6eefac215181201c1508324 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -50,7 +50,7 @@ var rootCmd = &cobra.Command{ go func() { // Set this up once - agent.GetMcpTools(ctx) + agent.GetMcpTools(ctx, app.Permissions) for msg := range ch { tui.Send(msg) } diff --git a/internal/app/services.go b/internal/app/services.go index 60838cccaf548f0c8a252acab34ee8baedb5a6f8..668da9a1d454a5ec68b8535248487f0b42fb3491 100644 --- a/internal/app/services.go +++ b/internal/app/services.go @@ -41,7 +41,7 @@ func New(ctx context.Context, conn *sql.DB) *App { Context: ctx, Sessions: sessions, Messages: messages, - Permissions: permission.Default, + Permissions: permission.NewPermissionService(), Logger: log, LSPClients: make(map[string]*lsp.Client), } diff --git a/internal/llm/agent/coder.go b/internal/llm/agent/coder.go index b47289c334be8eb6b8784fde42d6519b71502474..8ff3c61aa8ba49b3f8f0447049ba23b33a056459 100644 --- a/internal/llm/agent/coder.go +++ b/internal/llm/agent/coder.go @@ -44,7 +44,7 @@ func NewCoderAgent(app *app.App) (Agent, error) { return nil, err } - otherTools := GetMcpTools(app.Context) + otherTools := GetMcpTools(app.Context, app.Permissions) if len(app.LSPClients) > 0 { otherTools = append(otherTools, tools.NewDiagnosticsTool(app.LSPClients)) } @@ -53,15 +53,15 @@ func NewCoderAgent(app *app.App) (Agent, error) { App: app, tools: append( []tools.BaseTool{ - tools.NewBashTool(), - tools.NewEditTool(app.LSPClients), - tools.NewFetchTool(), + tools.NewBashTool(app.Permissions), + tools.NewEditTool(app.LSPClients, app.Permissions), + tools.NewFetchTool(app.Permissions), tools.NewGlobTool(), tools.NewGrepTool(), tools.NewLsTool(), tools.NewSourcegraphTool(), tools.NewViewTool(app.LSPClients), - tools.NewWriteTool(app.LSPClients), + tools.NewWriteTool(app.LSPClients, app.Permissions), }, otherTools..., ), model: model, diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index ec78e0a86cf36cdeebd4d1538512147e846466f3..64b5f639bcc5bd9ca87000861b62b0d69628d38b 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -16,9 +16,10 @@ import ( ) type mcpTool struct { - mcpName string - tool mcp.Tool - mcpConfig config.MCPServer + mcpName string + tool mcp.Tool + mcpConfig config.MCPServer + permissions permission.Service } type MCPClient interface { @@ -80,7 +81,7 @@ func runTool(ctx context.Context, c MCPClient, toolName string, input string) (t func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) { permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input) - p := permission.Default.Request( + p := b.permissions.Request( permission.CreatePermissionRequest{ Path: config.WorkingDirectory(), ToolName: b.Info().Name, @@ -118,17 +119,18 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes return tools.NewTextErrorResponse("invalid mcp type"), nil } -func NewMcpTool(name string, tool mcp.Tool, mcpConfig config.MCPServer) tools.BaseTool { +func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPServer) tools.BaseTool { return &mcpTool{ - mcpName: name, - tool: tool, - mcpConfig: mcpConfig, + mcpName: name, + tool: tool, + mcpConfig: mcpConfig, + permissions: permissions, } } var mcpTools []tools.BaseTool -func getTools(ctx context.Context, name string, m config.MCPServer, c MCPClient) []tools.BaseTool { +func getTools(ctx context.Context, name string, m config.MCPServer, permissions permission.Service, c MCPClient) []tools.BaseTool { var stdioTools []tools.BaseTool initRequest := mcp.InitializeRequest{} initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION @@ -149,13 +151,13 @@ func getTools(ctx context.Context, name string, m config.MCPServer, c MCPClient) return stdioTools } for _, t := range tools.Tools { - stdioTools = append(stdioTools, NewMcpTool(name, t, m)) + stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m)) } defer c.Close() return stdioTools } -func GetMcpTools(ctx context.Context) []tools.BaseTool { +func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.BaseTool { if len(mcpTools) > 0 { return mcpTools } @@ -172,7 +174,7 @@ func GetMcpTools(ctx context.Context) []tools.BaseTool { continue } - mcpTools = append(mcpTools, getTools(ctx, name, m, c)...) + mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...) case config.MCPSse: c, err := client.NewSSEMCPClient( m.URL, @@ -182,7 +184,7 @@ func GetMcpTools(ctx context.Context) []tools.BaseTool { log.Printf("error creating mcp client: %s", err) continue } - mcpTools = append(mcpTools, getTools(ctx, name, m, c)...) + mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...) } } diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index a78c03215c3dee7381e84970855f93b5c29147d2..4e80ae60a3e4de34da6a800f339be22ab2439785 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -11,16 +11,6 @@ import ( "github.com/kujtimiihoxha/termai/internal/permission" ) -type bashTool struct{} - -const ( - BashToolName = "bash" - - DefaultTimeout = 1 * 60 * 1000 // 1 minutes in milliseconds - MaxTimeout = 10 * 60 * 1000 // 10 minutes in milliseconds - MaxOutputLength = 30000 -) - type BashParams struct { Command string `json:"command"` Timeout int `json:"timeout"` @@ -31,180 +21,36 @@ type BashPermissionsParams struct { Timeout int `json:"timeout"` } -var BannedCommands = []string{ +type bashTool struct { + permissions permission.Service +} + +const ( + BashToolName = "bash" + + DefaultTimeout = 1 * 60 * 1000 // 1 minutes in milliseconds + MaxTimeout = 10 * 60 * 1000 // 10 minutes in milliseconds + MaxOutputLength = 30000 +) + +var bannedCommands = []string{ "alias", "curl", "curlie", "wget", "axel", "aria2c", "nc", "telnet", "lynx", "w3m", "links", "httpie", "xh", "http-prompt", "chrome", "firefox", "safari", } -var SafeReadOnlyCommands = []string{ - // Basic shell commands +var safeReadOnlyCommands = []string{ "ls", "echo", "pwd", "date", "cal", "uptime", "whoami", "id", "groups", "env", "printenv", "set", "unset", "which", "type", "whereis", "whatis", "uname", "hostname", "df", "du", "free", "top", "ps", "kill", "killall", "nice", "nohup", "time", "timeout", - - // Git read-only commands + "git status", "git log", "git diff", "git show", "git branch", "git tag", "git remote", "git ls-files", "git ls-remote", "git rev-parse", "git config --get", "git config --list", "git describe", "git blame", "git grep", "git shortlog", - - // Go commands - "go version", "go list", "go env", "go doc", "go vet", "go fmt", "go mod", "go test", "go build", "go run", "go install", "go clean", - - // Node.js commands - "node", "npm", "npx", "yarn", "pnpm", - - // Python commands - "python", "python3", "pip", "pip3", "pytest", "pylint", "mypy", "black", "isort", "flake8", "ruff", - - // Docker commands - "docker ps", "docker images", "docker volume", "docker network", "docker info", "docker version", - "docker-compose ps", "docker-compose config", - - // Kubernetes commands - "kubectl get", "kubectl describe", "kubectl logs", "kubectl version", "kubectl config", - - // Rust commands - "cargo", "rustc", "rustup", - - // Java commands - "java", "javac", "mvn", "gradle", - - // Misc development tools - "make", "cmake", "bazel", "terraform plan", "terraform validate", "ansible", -} - -func (b *bashTool) Info() ToolInfo { - return ToolInfo{ - Name: BashToolName, - Description: bashDescription(), - Parameters: map[string]any{ - "command": map[string]any{ - "type": "string", - "description": "The command to execute", - }, - "timeout": map[string]any{ - "type": "number", - "desription": "Optional timeout in milliseconds (max 600000)", - }, - }, - Required: []string{"command"}, - } -} - -// Handle implements Tool. -func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { - var params BashParams - if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { - return NewTextErrorResponse("invalid parameters"), nil - } - - if params.Timeout > MaxTimeout { - params.Timeout = MaxTimeout - } else if params.Timeout <= 0 { - params.Timeout = DefaultTimeout - } - - if params.Command == "" { - return NewTextErrorResponse("missing command"), nil - } - - // Check for banned commands (first word only) - baseCmd := strings.Fields(params.Command)[0] - for _, banned := range BannedCommands { - if strings.EqualFold(baseCmd, banned) { - return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil - } - } - - // Check for safe commands (can be multi-word) - isSafeReadOnly := false - cmdLower := strings.ToLower(params.Command) - - for _, safe := range SafeReadOnlyCommands { - // Check if command starts with the safe command pattern - if strings.HasPrefix(cmdLower, strings.ToLower(safe)) { - // Make sure it's either an exact match or followed by a space or flag - if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' { - isSafeReadOnly = true - break - } - } - } - if !isSafeReadOnly { - p := permission.Default.Request( - permission.CreatePermissionRequest{ - Path: config.WorkingDirectory(), - ToolName: BashToolName, - Action: "execute", - Description: fmt.Sprintf("Execute command: %s", params.Command), - Params: BashPermissionsParams{ - Command: params.Command, - }, - }, - ) - if !p { - return NewTextErrorResponse("permission denied"), nil - } - } - shell := shell.GetPersistentShell(config.WorkingDirectory()) - stdout, stderr, exitCode, interrupted, err := shell.Exec(ctx, params.Command, params.Timeout) - if err != nil { - return NewTextErrorResponse(fmt.Sprintf("error executing command: %s", err)), nil - } - - stdout = truncateOutput(stdout) - stderr = truncateOutput(stderr) - - errorMessage := stderr - if interrupted { - if errorMessage != "" { - errorMessage += "\n" - } - errorMessage += "Command was aborted before completion" - } else if exitCode != 0 { - if errorMessage != "" { - errorMessage += "\n" - } - errorMessage += fmt.Sprintf("Exit code %d", exitCode) - } - - hasBothOutputs := stdout != "" && stderr != "" - - if hasBothOutputs { - stdout += "\n" - } - - if errorMessage != "" { - stdout += "\n" + errorMessage - } - - if stdout == "" { - return NewTextResponse("no output"), nil - } - return NewTextResponse(stdout), nil -} - -func truncateOutput(content string) string { - if len(content) <= MaxOutputLength { - return content - } - - halfLength := MaxOutputLength / 2 - start := content[:halfLength] - end := content[len(content)-halfLength:] - - truncatedLinesCount := countLines(content[halfLength : len(content)-halfLength]) - return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end) -} -func countLines(s string) int { - if s == "" { - return 0 - } - return len(strings.Split(s, "\n")) + "go version", "go list", "go env", "go doc", "go vet", "go fmt", "go mod", "go test", "go build", "go run", "go install", "go clean", } func bashDescription() string { - bannedCommandsStr := strings.Join(BannedCommands, ", ") + bannedCommandsStr := strings.Join(bannedCommands, ", ") return fmt.Sprintf(`Executes a given bash command in a persistent shell session with optional timeout, ensuring proper handling and security measures. Before executing the command, please follow these steps: @@ -352,6 +198,134 @@ Important: - Never update git config`, bannedCommandsStr, MaxOutputLength) } -func NewBashTool() BaseTool { - return &bashTool{} +func NewBashTool(permission permission.Service) BaseTool { + return &bashTool{ + permissions: permission, + } +} + +func (b *bashTool) Info() ToolInfo { + return ToolInfo{ + Name: BashToolName, + Description: bashDescription(), + Parameters: map[string]any{ + "command": map[string]any{ + "type": "string", + "description": "The command to execute", + }, + "timeout": map[string]any{ + "type": "number", + "description": "Optional timeout in milliseconds (max 600000)", + }, + }, + Required: []string{"command"}, + } +} + +func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { + var params BashParams + if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { + return NewTextErrorResponse("invalid parameters"), nil + } + + if params.Timeout > MaxTimeout { + params.Timeout = MaxTimeout + } else if params.Timeout <= 0 { + params.Timeout = DefaultTimeout + } + + if params.Command == "" { + return NewTextErrorResponse("missing command"), nil + } + + baseCmd := strings.Fields(params.Command)[0] + for _, banned := range bannedCommands { + if strings.EqualFold(baseCmd, banned) { + return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil + } + } + + isSafeReadOnly := false + cmdLower := strings.ToLower(params.Command) + + for _, safe := range safeReadOnlyCommands { + if strings.HasPrefix(cmdLower, strings.ToLower(safe)) { + if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' { + isSafeReadOnly = true + break + } + } + } + if !isSafeReadOnly { + p := b.permissions.Request( + permission.CreatePermissionRequest{ + Path: config.WorkingDirectory(), + ToolName: BashToolName, + Action: "execute", + Description: fmt.Sprintf("Execute command: %s", params.Command), + Params: BashPermissionsParams{ + Command: params.Command, + }, + }, + ) + if !p { + return NewTextErrorResponse("permission denied"), nil + } + } + shell := shell.GetPersistentShell(config.WorkingDirectory()) + stdout, stderr, exitCode, interrupted, err := shell.Exec(ctx, params.Command, params.Timeout) + if err != nil { + return NewTextErrorResponse(fmt.Sprintf("error executing command: %s", err)), nil + } + + stdout = truncateOutput(stdout) + stderr = truncateOutput(stderr) + + errorMessage := stderr + if interrupted { + if errorMessage != "" { + errorMessage += "\n" + } + errorMessage += "Command was aborted before completion" + } else if exitCode != 0 { + if errorMessage != "" { + errorMessage += "\n" + } + errorMessage += fmt.Sprintf("Exit code %d", exitCode) + } + + hasBothOutputs := stdout != "" && stderr != "" + + if hasBothOutputs { + stdout += "\n" + } + + if errorMessage != "" { + stdout += "\n" + errorMessage + } + + if stdout == "" { + return NewTextResponse("no output"), nil + } + return NewTextResponse(stdout), nil +} + +func truncateOutput(content string) string { + if len(content) <= MaxOutputLength { + return content + } + + halfLength := MaxOutputLength / 2 + start := content[:halfLength] + end := content[len(content)-halfLength:] + + truncatedLinesCount := countLines(content[halfLength : len(content)-halfLength]) + return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end) +} + +func countLines(s string) int { + if s == "" { + return 0 + } + return len(strings.Split(s, "\n")) } diff --git a/internal/llm/tools/bash_test.go b/internal/llm/tools/bash_test.go index 9eadc227ce12edcca692b68ea412050afaaa06fd..97be3683aa243f9c9b74fc84503e973959ac9e8a 100644 --- a/internal/llm/tools/bash_test.go +++ b/internal/llm/tools/bash_test.go @@ -15,7 +15,7 @@ import ( ) func TestBashTool_Info(t *testing.T) { - tool := NewBashTool() + tool := NewBashTool(newMockPermissionService(true)) info := tool.Info() assert.Equal(t, BashToolName, info.Name) @@ -26,13 +26,6 @@ func TestBashTool_Info(t *testing.T) { } func TestBashTool_Run(t *testing.T) { - // Setup a mock permission handler that always allows - origPermission := permission.Default - defer func() { - permission.Default = origPermission - }() - permission.Default = newMockPermissionService(true) - // Save original working directory origWd, err := os.Getwd() require.NoError(t, err) @@ -41,8 +34,7 @@ func TestBashTool_Run(t *testing.T) { }() t.Run("executes command successfully", func(t *testing.T) { - permission.Default = newMockPermissionService(true) - tool := NewBashTool() + tool := NewBashTool(newMockPermissionService(true)) params := BashParams{ Command: "echo 'Hello World'", } @@ -61,9 +53,7 @@ func TestBashTool_Run(t *testing.T) { }) t.Run("handles invalid parameters", func(t *testing.T) { - permission.Default = newMockPermissionService(true) - - tool := NewBashTool() + tool := NewBashTool(newMockPermissionService(true)) call := ToolCall{ Name: BashToolName, Input: "invalid json", @@ -75,9 +65,7 @@ func TestBashTool_Run(t *testing.T) { }) t.Run("handles missing command", func(t *testing.T) { - permission.Default = newMockPermissionService(true) - - tool := NewBashTool() + tool := NewBashTool(newMockPermissionService(true)) params := BashParams{ Command: "", } @@ -96,11 +84,9 @@ func TestBashTool_Run(t *testing.T) { }) t.Run("handles banned commands", func(t *testing.T) { - permission.Default = newMockPermissionService(true) + tool := NewBashTool(newMockPermissionService(true)) - tool := NewBashTool() - - for _, bannedCmd := range BannedCommands { + for _, bannedCmd := range bannedCommands { params := BashParams{ Command: bannedCmd + " arg1 arg2", } @@ -120,17 +106,11 @@ func TestBashTool_Run(t *testing.T) { }) t.Run("handles multi-word safe commands without permission check", func(t *testing.T) { - permission.Default = newMockPermissionService(false) - - tool := NewBashTool() + tool := NewBashTool(newMockPermissionService(false)) // Test with multi-word safe commands multiWordCommands := []string{ - "git status", - "git log -n 5", - "docker ps", - "go test ./...", - "kubectl get pods", + "go env", } for _, cmd := range multiWordCommands { @@ -148,15 +128,13 @@ func TestBashTool_Run(t *testing.T) { response, err := tool.Run(context.Background(), call) require.NoError(t, err) - assert.NotContains(t, response.Content, "permission denied", + assert.NotContains(t, response.Content, "permission denied", "Command %s should be allowed without permission", cmd) } }) t.Run("handles permission denied", func(t *testing.T) { - permission.Default = newMockPermissionService(false) - - tool := NewBashTool() + tool := NewBashTool(newMockPermissionService(false)) // Test with a command that requires permission params := BashParams{ @@ -177,8 +155,7 @@ func TestBashTool_Run(t *testing.T) { }) t.Run("handles command timeout", func(t *testing.T) { - permission.Default = newMockPermissionService(true) - tool := NewBashTool() + tool := NewBashTool(newMockPermissionService(true)) params := BashParams{ Command: "sleep 2", Timeout: 100, // 100ms timeout @@ -198,8 +175,7 @@ func TestBashTool_Run(t *testing.T) { }) t.Run("handles command with stderr output", func(t *testing.T) { - permission.Default = newMockPermissionService(true) - tool := NewBashTool() + tool := NewBashTool(newMockPermissionService(true)) params := BashParams{ Command: "echo 'error message' >&2", } @@ -218,8 +194,7 @@ func TestBashTool_Run(t *testing.T) { }) t.Run("handles command with both stdout and stderr", func(t *testing.T) { - permission.Default = newMockPermissionService(true) - tool := NewBashTool() + tool := NewBashTool(newMockPermissionService(true)) params := BashParams{ Command: "echo 'stdout message' && echo 'stderr message' >&2", } @@ -239,8 +214,7 @@ func TestBashTool_Run(t *testing.T) { }) t.Run("handles context cancellation", func(t *testing.T) { - permission.Default = newMockPermissionService(true) - tool := NewBashTool() + tool := NewBashTool(newMockPermissionService(true)) params := BashParams{ Command: "sleep 5", } @@ -267,8 +241,7 @@ func TestBashTool_Run(t *testing.T) { }) t.Run("respects max timeout", func(t *testing.T) { - permission.Default = newMockPermissionService(true) - tool := NewBashTool() + tool := NewBashTool(newMockPermissionService(true)) params := BashParams{ Command: "echo 'test'", Timeout: MaxTimeout + 1000, // Exceeds max timeout @@ -288,8 +261,7 @@ func TestBashTool_Run(t *testing.T) { }) t.Run("uses default timeout for zero or negative timeout", func(t *testing.T) { - permission.Default = newMockPermissionService(true) - tool := NewBashTool() + tool := NewBashTool(newMockPermissionService(true)) params := BashParams{ Command: "echo 'test'", Timeout: -100, // Negative timeout @@ -397,4 +369,3 @@ func newMockPermissionService(allow bool) permission.Service { allow: allow, } } - diff --git a/internal/llm/tools/diagnostics.go b/internal/llm/tools/diagnostics.go index ce76ae12fec2e74a0e3140d0eeb88357cbc34de7..1bb02098e131ba1c2f193954de7246cca0a2fd7c 100644 --- a/internal/llm/tools/diagnostics.go +++ b/internal/llm/tools/diagnostics.go @@ -13,22 +13,48 @@ import ( "github.com/kujtimiihoxha/termai/internal/lsp/protocol" ) +type DiagnosticsParams struct { + FilePath string `json:"file_path"` +} type diagnosticsTool struct { lspClients map[string]*lsp.Client } const ( - DiagnosticsToolName = "diagnostics" + DiagnosticsToolName = "diagnostics" + diagnosticsDescription = `Get diagnostics for a file and/or project. +WHEN TO USE THIS TOOL: +- Use when you need to check for errors or warnings in your code +- Helpful for debugging and ensuring code quality +- Good for getting a quick overview of issues in a file or project +HOW TO USE: +- Provide a path to a file to get diagnostics for that file +- Leave the path empty to get diagnostics for the entire project +- Results are displayed in a structured format with severity levels +FEATURES: +- Displays errors, warnings, and hints +- Groups diagnostics by severity +- Provides detailed information about each diagnostic +LIMITATIONS: +- Results are limited to the diagnostics provided by the LSP clients +- May not cover all possible issues in the code +- Does not provide suggestions for fixing issues +TIPS: +- Use in conjunction with other tools for a comprehensive code review +- Combine with the LSP client for real-time diagnostics +` ) -type DiagnosticsParams struct { - FilePath string `json:"file_path"` +func NewDiagnosticsTool(lspClients map[string]*lsp.Client) BaseTool { + return &diagnosticsTool{ + lspClients, + } } func (b *diagnosticsTool) Info() ToolInfo { return ToolInfo{ Name: DiagnosticsToolName, - Description: "Get diagnostics for a file and/or project.", + Description: diagnosticsDescription, Parameters: map[string]any{ "file_path": map[string]any{ "type": "string", @@ -63,31 +89,24 @@ func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, func notifyLspOpenFile(ctx context.Context, filePath string, lsps map[string]*lsp.Client) { for _, client := range lsps { - // Open the file err := client.OpenFile(ctx, filePath) if err != nil { - // If there's an error opening the file, continue to the next client continue } } } -// waitForLspDiagnostics opens a file in LSP clients and waits for diagnostics to be published func waitForLspDiagnostics(ctx context.Context, filePath string, lsps map[string]*lsp.Client) { if len(lsps) == 0 { return } - // Create a channel to receive diagnostic notifications diagChan := make(chan struct{}, 1) - // Register a temporary diagnostic handler for each client for _, client := range lsps { - // Store the original diagnostics map to detect changes originalDiags := make(map[protocol.DocumentUri][]protocol.Diagnostic) maps.Copy(originalDiags, client.GetDiagnostics()) - // Create a notification handler that will signal when diagnostics are received handler := func(params json.RawMessage) { lsp.HandleDiagnostics(client, params) var diagParams protocol.PublishDiagnosticsParams @@ -95,28 +114,22 @@ func waitForLspDiagnostics(ctx context.Context, filePath string, lsps map[string return } - // If this is for our file or we've received any new diagnostics, signal completion if diagParams.URI.Path() == filePath || hasDiagnosticsChanged(client.GetDiagnostics(), originalDiags) { select { case diagChan <- struct{}{}: - // Signal sent default: - // Channel already has a value, no need to send again } } } - // Register our temporary handler client.RegisterNotificationHandler("textDocument/publishDiagnostics", handler) - // Notify change if the file is already open if client.IsFileOpen(filePath) { err := client.NotifyChange(ctx, filePath) if err != nil { continue } } else { - // Open the file if it's not already open err := client.OpenFile(ctx, filePath) if err != nil { continue @@ -124,22 +137,13 @@ func waitForLspDiagnostics(ctx context.Context, filePath string, lsps map[string } } - // Wait for diagnostics with a reasonable timeout select { case <-diagChan: - // Diagnostics received case <-time.After(5 * time.Second): - // Timeout after 5 seconds - this is a fallback in case no diagnostics are published case <-ctx.Done(): - // Context cancelled } - - // Note: We're not unregistering our handler because the Client.RegisterNotificationHandler - // replaces any existing handler, and we'll be replaced by the original handler when - // the LSP client is reinitialized or when a new handler is registered. } -// hasDiagnosticsChanged checks if there are any new diagnostics compared to the original set func hasDiagnosticsChanged(current, original map[protocol.DocumentUri][]protocol.Diagnostic) bool { for uri, diags := range current { origDiags, exists := original[uri] @@ -154,9 +158,7 @@ func appendDiagnostics(filePath string, lsps map[string]*lsp.Client) string { fileDiagnostics := []string{} projectDiagnostics := []string{} - // Enhanced format function that includes more diagnostic information formatDiagnostic := func(pth string, diagnostic protocol.Diagnostic, source string) string { - // Base components severity := "Info" switch diagnostic.Severity { case protocol.SeverityError: @@ -167,10 +169,8 @@ func appendDiagnostics(filePath string, lsps map[string]*lsp.Client) string { severity = "Hint" } - // Location information location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1) - // Source information (LSP name) sourceInfo := "" if diagnostic.Source != "" { sourceInfo = diagnostic.Source @@ -178,13 +178,11 @@ func appendDiagnostics(filePath string, lsps map[string]*lsp.Client) string { sourceInfo = source } - // Code information codeInfo := "" if diagnostic.Code != nil { codeInfo = fmt.Sprintf("[%v]", diagnostic.Code) } - // Tags information tagsInfo := "" if len(diagnostic.Tags) > 0 { tags := []string{} @@ -201,7 +199,6 @@ func appendDiagnostics(filePath string, lsps map[string]*lsp.Client) string { } } - // Assemble the full diagnostic message return fmt.Sprintf("%s: %s [%s]%s%s %s", severity, location, @@ -217,7 +214,6 @@ func appendDiagnostics(filePath string, lsps map[string]*lsp.Client) string { for location, diags := range diagnostics { isCurrentFile := location.Path() == filePath - // Group diagnostics by severity for better organization for _, diag := range diags { formattedDiag := formatDiagnostic(location.Path(), diag, lspName) @@ -231,7 +227,6 @@ func appendDiagnostics(filePath string, lsps map[string]*lsp.Client) string { } } - // Sort diagnostics by severity (errors first) and then by location sort.Slice(fileDiagnostics, func(i, j int) bool { iIsError := strings.HasPrefix(fileDiagnostics[i], "Error") jIsError := strings.HasPrefix(fileDiagnostics[j], "Error") @@ -274,7 +269,6 @@ func appendDiagnostics(filePath string, lsps map[string]*lsp.Client) string { output += "\n\n" } - // Add summary counts if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 { fileErrors := countSeverity(fileDiagnostics, "Error") fileWarnings := countSeverity(fileDiagnostics, "Warn") @@ -290,7 +284,6 @@ func appendDiagnostics(filePath string, lsps map[string]*lsp.Client) string { return output } -// Helper function to count diagnostics by severity func countSeverity(diagnostics []string, severity string) int { count := 0 for _, diag := range diagnostics { @@ -300,9 +293,3 @@ func countSeverity(diagnostics []string, severity string) int { } return count } - -func NewDiagnosticsTool(lspClients map[string]*lsp.Client) BaseTool { - return &diagnosticsTool{ - lspClients, - } -} diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index f158401b8376c6e9a6cc475d6198c0501806c777..32e2034e451063f0f5e8e33328793c0f44bca4d1 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -15,14 +15,6 @@ import ( "github.com/sergi/go-diff/diffmatchpatch" ) -type editTool struct { - lspClients map[string]*lsp.Client -} - -const ( - EditToolName = "edit" -) - type EditParams struct { FilePath string `json:"file_path"` OldString string `json:"old_string"` @@ -36,10 +28,73 @@ type EditPermissionsParams struct { Diff string `json:"diff"` } +type editTool struct { + lspClients map[string]*lsp.Client + permissions permission.Service +} + +const ( + EditToolName = "edit" + editDescription = `Edits files by replacing text, creating new files, or deleting content. For moving or renaming files, use the Bash tool with the 'mv' command instead. For larger file edits, use the FileWrite tool to overwrite files. + +Before using this tool: + +1. Use the FileRead tool to understand the file's contents and context + +2. Verify the directory path is correct (only applicable when creating new files): + - Use the LS tool to verify the parent directory exists and is the correct location + +To make a file edit, provide the following: +1. file_path: The absolute path to the file to modify (must be absolute, not relative) +2. old_string: The text to replace (must be unique within the file, and must match the file contents exactly, including all whitespace and indentation) +3. new_string: The edited text to replace the old_string + +Special cases: +- To create a new file: provide file_path and new_string, leave old_string empty +- To delete content: provide file_path and old_string, leave new_string empty + +The tool will replace ONE occurrence of old_string with new_string in the specified file. + +CRITICAL REQUIREMENTS FOR USING THIS TOOL: + +1. UNIQUENESS: The old_string MUST uniquely identify the specific instance you want to change. This means: + - Include AT LEAST 3-5 lines of context BEFORE the change point + - Include AT LEAST 3-5 lines of context AFTER the change point + - Include all whitespace, indentation, and surrounding code exactly as it appears in the file + +2. SINGLE INSTANCE: This tool can only change ONE instance at a time. If you need to change multiple instances: + - Make separate calls to this tool for each instance + - Each call must uniquely identify its specific instance using extensive context + +3. VERIFICATION: Before using this tool: + - Check how many instances of the target text exist in the file + - If multiple instances exist, gather enough context to uniquely identify each one + - Plan separate tool calls for each instance + +WARNING: If you do not follow these requirements: + - The tool will fail if old_string matches multiple locations + - The tool will fail if old_string doesn't match exactly (including whitespace) + - You may change the wrong instance if you don't include enough context + +When making edits: + - Ensure the edit results in idiomatic, correct code + - Do not leave the code in a broken state + - Always use absolute file paths (starting with /) + +Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.` +) + +func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Service) BaseTool { + return &editTool{ + lspClients: lspClients, + permissions: permissions, + } +} + func (e *editTool) Info() ToolInfo { return ToolInfo{ Name: EditToolName, - Description: editDescription(), + Description: editDescription, Parameters: map[string]any{ "file_path": map[string]any{ "type": "string", @@ -58,7 +113,6 @@ func (e *editTool) Info() ToolInfo { } } -// Run implements Tool. func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { var params EditParams if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { @@ -75,7 +129,7 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) } if params.OldString == "" { - result, err := createNewFile(params.FilePath, params.NewString) + result, err := e.createNewFile(params.FilePath, params.NewString) if err != nil { return NewTextErrorResponse(fmt.Sprintf("error creating file: %s", err)), nil } @@ -83,26 +137,25 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) } if params.NewString == "" { - result, err := deleteContent(params.FilePath, params.OldString) + result, err := e.deleteContent(params.FilePath, params.OldString) if err != nil { return NewTextErrorResponse(fmt.Sprintf("error deleting content: %s", err)), nil } return NewTextResponse(result), nil } - result, err := replaceContent(params.FilePath, params.OldString, params.NewString) + result, err := e.replaceContent(params.FilePath, params.OldString, params.NewString) if err != nil { return NewTextErrorResponse(fmt.Sprintf("error replacing content: %s", err)), nil } - // Wait for LSP diagnostics after editing the file waitForLspDiagnostics(ctx, params.FilePath, e.lspClients) result = fmt.Sprintf("\n%s\n\n", result) result += appendDiagnostics(params.FilePath, e.lspClients) return NewTextResponse(result), nil } -func createNewFile(filePath, content string) (string, error) { +func (e *editTool) createNewFile(filePath, content string) (string, error) { fileInfo, err := os.Stat(filePath) if err == nil { if fileInfo.IsDir() { @@ -118,7 +171,7 @@ func createNewFile(filePath, content string) (string, error) { return "", fmt.Errorf("failed to create parent directories: %w", err) } - p := permission.Default.Request( + p := e.permissions.Request( permission.CreatePermissionRequest{ Path: filepath.Dir(filePath), ToolName: EditToolName, @@ -147,7 +200,7 @@ func createNewFile(filePath, content string) (string, error) { return "File created: " + filePath, nil } -func deleteContent(filePath, oldString string) (string, error) { +func (e *editTool) deleteContent(filePath, oldString string) (string, error) { fileInfo, err := os.Stat(filePath) if err != nil { if os.IsNotExist(err) { @@ -190,7 +243,7 @@ func deleteContent(filePath, oldString string) (string, error) { newContent := oldContent[:index] + oldContent[index+len(oldString):] - p := permission.Default.Request( + p := e.permissions.Request( permission.CreatePermissionRequest{ Path: filepath.Dir(filePath), ToolName: EditToolName, @@ -219,7 +272,7 @@ func deleteContent(filePath, oldString string) (string, error) { return "Content deleted from file: " + filePath, nil } -func replaceContent(filePath, oldString, newString string) (string, error) { +func (e *editTool) replaceContent(filePath, oldString, newString string) (string, error) { fileInfo, err := os.Stat(filePath) if err != nil { if os.IsNotExist(err) { @@ -268,7 +321,7 @@ func replaceContent(filePath, oldString, newString string) (string, error) { diff := GenerateDiff(oldContent[startIndex:oldEndIndex], newContent[startIndex:newEndIndex]) - p := permission.Default.Request( + p := e.permissions.Request( permission.CreatePermissionRequest{ Path: filepath.Dir(filePath), ToolName: EditToolName, @@ -305,7 +358,6 @@ func GenerateDiff(oldContent, newContent string) string { diffs = dmp.DiffCleanupSemantic(diffs) buff := strings.Builder{} - // Add a header to make the diff more readable buff.WriteString("Changes:\n") for _, diff := range diffs { @@ -327,10 +379,8 @@ func GenerateDiff(oldContent, newContent string) string { _, _ = buff.WriteString("- " + line + "\n") } case diffmatchpatch.DiffEqual: - // Only show a small context for unchanged text lines := strings.Split(text, "\n") if len(lines) > 3 { - // Show only first and last line of context with a separator if lines[0] != "" { _, _ = buff.WriteString(" " + lines[0] + "\n") } @@ -339,7 +389,6 @@ func GenerateDiff(oldContent, newContent string) string { _, _ = buff.WriteString(" " + lines[len(lines)-1] + "\n") } } else { - // Show all lines for small contexts for _, line := range lines { if line == "" { continue @@ -351,59 +400,3 @@ func GenerateDiff(oldContent, newContent string) string { } return buff.String() } - -func editDescription() string { - return `Edits files by replacing text, creating new files, or deleting content. For moving or renaming files, use the Bash tool with the 'mv' command instead. For larger file edits, use the FileWrite tool to overwrite files. - -Before using this tool: - -1. Use the FileRead tool to understand the file's contents and context - -2. Verify the directory path is correct (only applicable when creating new files): - - Use the LS tool to verify the parent directory exists and is the correct location - -To make a file edit, provide the following: -1. file_path: The absolute path to the file to modify (must be absolute, not relative) -2. old_string: The text to replace (must be unique within the file, and must match the file contents exactly, including all whitespace and indentation) -3. new_string: The edited text to replace the old_string - -Special cases: -- To create a new file: provide file_path and new_string, leave old_string empty -- To delete content: provide file_path and old_string, leave new_string empty - -The tool will replace ONE occurrence of old_string with new_string in the specified file. - -CRITICAL REQUIREMENTS FOR USING THIS TOOL: - -1. UNIQUENESS: The old_string MUST uniquely identify the specific instance you want to change. This means: - - Include AT LEAST 3-5 lines of context BEFORE the change point - - Include AT LEAST 3-5 lines of context AFTER the change point - - Include all whitespace, indentation, and surrounding code exactly as it appears in the file - -2. SINGLE INSTANCE: This tool can only change ONE instance at a time. If you need to change multiple instances: - - Make separate calls to this tool for each instance - - Each call must uniquely identify its specific instance using extensive context - -3. VERIFICATION: Before using this tool: - - Check how many instances of the target text exist in the file - - If multiple instances exist, gather enough context to uniquely identify each one - - Plan separate tool calls for each instance - -WARNING: If you do not follow these requirements: - - The tool will fail if old_string matches multiple locations - - The tool will fail if old_string doesn't match exactly (including whitespace) - - You may change the wrong instance if you don't include enough context - -When making edits: - - Ensure the edit results in idiomatic, correct code - - Do not leave the code in a broken state - - Always use absolute file paths (starting with /) - -Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.` -} - -func NewEditTool(lspClients map[string]*lsp.Client) BaseTool { - return &editTool{ - lspClients, - } -} diff --git a/internal/llm/tools/edit_test.go b/internal/llm/tools/edit_test.go new file mode 100644 index 0000000000000000000000000000000000000000..dbc6e488f822378432ed41e6cd3d3909651bc3e9 --- /dev/null +++ b/internal/llm/tools/edit_test.go @@ -0,0 +1,509 @@ +package tools + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "testing" + "time" + + "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEditTool_Info(t *testing.T) { + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + info := tool.Info() + + assert.Equal(t, EditToolName, info.Name) + assert.NotEmpty(t, info.Description) + assert.Contains(t, info.Parameters, "file_path") + assert.Contains(t, info.Parameters, "old_string") + assert.Contains(t, info.Parameters, "new_string") + assert.Contains(t, info.Required, "file_path") + assert.Contains(t, info.Required, "old_string") + assert.Contains(t, info.Required, "new_string") +} + +func TestEditTool_Run(t *testing.T) { + // Create a temporary directory for testing + tempDir, err := os.MkdirTemp("", "edit_tool_test") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + t.Run("creates a new file successfully", func(t *testing.T) { + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + + filePath := filepath.Join(tempDir, "new_file.txt") + content := "This is a test content" + + params := EditParams{ + FilePath: filePath, + OldString: "", + NewString: content, + } + + paramsJSON, err := json.Marshal(params) + require.NoError(t, err) + + call := ToolCall{ + Name: EditToolName, + Input: string(paramsJSON), + } + + response, err := tool.Run(context.Background(), call) + require.NoError(t, err) + assert.Contains(t, response.Content, "File created") + + // Verify file was created with correct content + fileContent, err := os.ReadFile(filePath) + require.NoError(t, err) + assert.Equal(t, content, string(fileContent)) + }) + + t.Run("creates file with nested directories", func(t *testing.T) { + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + + filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt") + content := "Content in nested directory" + + params := EditParams{ + FilePath: filePath, + OldString: "", + NewString: content, + } + + paramsJSON, err := json.Marshal(params) + require.NoError(t, err) + + call := ToolCall{ + Name: EditToolName, + Input: string(paramsJSON), + } + + response, err := tool.Run(context.Background(), call) + require.NoError(t, err) + assert.Contains(t, response.Content, "File created") + + // Verify file was created with correct content + fileContent, err := os.ReadFile(filePath) + require.NoError(t, err) + assert.Equal(t, content, string(fileContent)) + }) + + t.Run("fails to create file that already exists", func(t *testing.T) { + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + + // Create a file first + filePath := filepath.Join(tempDir, "existing_file.txt") + initialContent := "Initial content" + err := os.WriteFile(filePath, []byte(initialContent), 0o644) + require.NoError(t, err) + + // Try to create the same file + params := EditParams{ + FilePath: filePath, + OldString: "", + NewString: "New content", + } + + paramsJSON, err := json.Marshal(params) + require.NoError(t, err) + + call := ToolCall{ + Name: EditToolName, + Input: string(paramsJSON), + } + + response, err := tool.Run(context.Background(), call) + require.NoError(t, err) + assert.Contains(t, response.Content, "file already exists") + }) + + t.Run("fails to create file when path is a directory", func(t *testing.T) { + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + + // Create a directory + dirPath := filepath.Join(tempDir, "test_dir") + err := os.Mkdir(dirPath, 0o755) + require.NoError(t, err) + + // Try to create a file with the same path as the directory + params := EditParams{ + FilePath: dirPath, + OldString: "", + NewString: "Some content", + } + + paramsJSON, err := json.Marshal(params) + require.NoError(t, err) + + call := ToolCall{ + Name: EditToolName, + Input: string(paramsJSON), + } + + response, err := tool.Run(context.Background(), call) + require.NoError(t, err) + assert.Contains(t, response.Content, "path is a directory") + }) + + t.Run("replaces content successfully", func(t *testing.T) { + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + + // Create a file first + filePath := filepath.Join(tempDir, "replace_content.txt") + initialContent := "Line 1\nLine 2\nLine 3\nLine 4\nLine 5" + err := os.WriteFile(filePath, []byte(initialContent), 0o644) + require.NoError(t, err) + + // Record the file read to avoid modification time check failure + recordFileRead(filePath) + + // Replace content + oldString := "Line 2\nLine 3" + newString := "Line 2 modified\nLine 3 modified" + params := EditParams{ + FilePath: filePath, + OldString: oldString, + NewString: newString, + } + + paramsJSON, err := json.Marshal(params) + require.NoError(t, err) + + call := ToolCall{ + Name: EditToolName, + Input: string(paramsJSON), + } + + response, err := tool.Run(context.Background(), call) + require.NoError(t, err) + assert.Contains(t, response.Content, "Content replaced") + + // Verify file was updated with correct content + expectedContent := "Line 1\nLine 2 modified\nLine 3 modified\nLine 4\nLine 5" + fileContent, err := os.ReadFile(filePath) + require.NoError(t, err) + assert.Equal(t, expectedContent, string(fileContent)) + }) + + t.Run("deletes content successfully", func(t *testing.T) { + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + + // Create a file first + filePath := filepath.Join(tempDir, "delete_content.txt") + initialContent := "Line 1\nLine 2\nLine 3\nLine 4\nLine 5" + err := os.WriteFile(filePath, []byte(initialContent), 0o644) + require.NoError(t, err) + + // Record the file read to avoid modification time check failure + recordFileRead(filePath) + + // Delete content + oldString := "Line 2\nLine 3\n" + params := EditParams{ + FilePath: filePath, + OldString: oldString, + NewString: "", + } + + paramsJSON, err := json.Marshal(params) + require.NoError(t, err) + + call := ToolCall{ + Name: EditToolName, + Input: string(paramsJSON), + } + + response, err := tool.Run(context.Background(), call) + require.NoError(t, err) + assert.Contains(t, response.Content, "Content deleted") + + // Verify file was updated with correct content + expectedContent := "Line 1\nLine 4\nLine 5" + fileContent, err := os.ReadFile(filePath) + require.NoError(t, err) + assert.Equal(t, expectedContent, string(fileContent)) + }) + + t.Run("handles invalid parameters", func(t *testing.T) { + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + + call := ToolCall{ + Name: EditToolName, + Input: "invalid json", + } + + response, err := tool.Run(context.Background(), call) + require.NoError(t, err) + assert.Contains(t, response.Content, "invalid parameters") + }) + + t.Run("handles missing file_path", func(t *testing.T) { + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + + params := EditParams{ + FilePath: "", + OldString: "old", + NewString: "new", + } + + paramsJSON, err := json.Marshal(params) + require.NoError(t, err) + + call := ToolCall{ + Name: EditToolName, + Input: string(paramsJSON), + } + + response, err := tool.Run(context.Background(), call) + require.NoError(t, err) + assert.Contains(t, response.Content, "file_path is required") + }) + + t.Run("handles file not found", func(t *testing.T) { + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + + filePath := filepath.Join(tempDir, "non_existent_file.txt") + params := EditParams{ + FilePath: filePath, + OldString: "old content", + NewString: "new content", + } + + paramsJSON, err := json.Marshal(params) + require.NoError(t, err) + + call := ToolCall{ + Name: EditToolName, + Input: string(paramsJSON), + } + + response, err := tool.Run(context.Background(), call) + require.NoError(t, err) + assert.Contains(t, response.Content, "file not found") + }) + + t.Run("handles old_string not found in file", func(t *testing.T) { + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + + // Create a file first + filePath := filepath.Join(tempDir, "content_not_found.txt") + initialContent := "Line 1\nLine 2\nLine 3" + err := os.WriteFile(filePath, []byte(initialContent), 0o644) + require.NoError(t, err) + + // Record the file read to avoid modification time check failure + recordFileRead(filePath) + + // Try to replace content that doesn't exist + params := EditParams{ + FilePath: filePath, + OldString: "This content does not exist", + NewString: "new content", + } + + paramsJSON, err := json.Marshal(params) + require.NoError(t, err) + + call := ToolCall{ + Name: EditToolName, + Input: string(paramsJSON), + } + + response, err := tool.Run(context.Background(), call) + require.NoError(t, err) + assert.Contains(t, response.Content, "old_string not found in file") + }) + + t.Run("handles multiple occurrences of old_string", func(t *testing.T) { + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + + // Create a file with duplicate content + filePath := filepath.Join(tempDir, "duplicate_content.txt") + initialContent := "Line 1\nDuplicate\nLine 3\nDuplicate\nLine 5" + err := os.WriteFile(filePath, []byte(initialContent), 0o644) + require.NoError(t, err) + + // Record the file read to avoid modification time check failure + recordFileRead(filePath) + + // Try to replace content that appears multiple times + params := EditParams{ + FilePath: filePath, + OldString: "Duplicate", + NewString: "Replaced", + } + + paramsJSON, err := json.Marshal(params) + require.NoError(t, err) + + call := ToolCall{ + Name: EditToolName, + Input: string(paramsJSON), + } + + response, err := tool.Run(context.Background(), call) + require.NoError(t, err) + assert.Contains(t, response.Content, "appears multiple times") + }) + + t.Run("handles file modified since last read", func(t *testing.T) { + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + + // Create a file + filePath := filepath.Join(tempDir, "modified_file.txt") + initialContent := "Initial content" + err := os.WriteFile(filePath, []byte(initialContent), 0o644) + require.NoError(t, err) + + // Record an old read time + fileRecordMutex.Lock() + fileRecords[filePath] = fileRecord{ + path: filePath, + readTime: time.Now().Add(-1 * time.Hour), + } + fileRecordMutex.Unlock() + + // Try to update the file + params := EditParams{ + FilePath: filePath, + OldString: "Initial", + NewString: "Updated", + } + + paramsJSON, err := json.Marshal(params) + require.NoError(t, err) + + call := ToolCall{ + Name: EditToolName, + Input: string(paramsJSON), + } + + response, err := tool.Run(context.Background(), call) + require.NoError(t, err) + assert.Contains(t, response.Content, "has been modified since it was last read") + + // Verify file was not modified + fileContent, err := os.ReadFile(filePath) + require.NoError(t, err) + assert.Equal(t, initialContent, string(fileContent)) + }) + + t.Run("handles file not read before editing", func(t *testing.T) { + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + + // Create a file + filePath := filepath.Join(tempDir, "not_read_file.txt") + initialContent := "Initial content" + err := os.WriteFile(filePath, []byte(initialContent), 0o644) + require.NoError(t, err) + + // Try to update the file without reading it first + params := EditParams{ + FilePath: filePath, + OldString: "Initial", + NewString: "Updated", + } + + paramsJSON, err := json.Marshal(params) + require.NoError(t, err) + + call := ToolCall{ + Name: EditToolName, + Input: string(paramsJSON), + } + + response, err := tool.Run(context.Background(), call) + require.NoError(t, err) + assert.Contains(t, response.Content, "you must read the file before editing it") + }) + + t.Run("handles permission denied", func(t *testing.T) { + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(false)) + + // Create a file + filePath := filepath.Join(tempDir, "permission_denied.txt") + initialContent := "Initial content" + err := os.WriteFile(filePath, []byte(initialContent), 0o644) + require.NoError(t, err) + + // Record the file read to avoid modification time check failure + recordFileRead(filePath) + + // Try to update the file + params := EditParams{ + FilePath: filePath, + OldString: "Initial", + NewString: "Updated", + } + + paramsJSON, err := json.Marshal(params) + require.NoError(t, err) + + call := ToolCall{ + Name: EditToolName, + Input: string(paramsJSON), + } + + response, err := tool.Run(context.Background(), call) + require.NoError(t, err) + assert.Contains(t, response.Content, "permission denied") + + // Verify file was not modified + fileContent, err := os.ReadFile(filePath) + require.NoError(t, err) + assert.Equal(t, initialContent, string(fileContent)) + }) +} + +func TestGenerateDiff(t *testing.T) { + testCases := []struct { + name string + oldContent string + newContent string + expectedDiff string + }{ + { + name: "add content", + oldContent: "Line 1\nLine 2\n", + newContent: "Line 1\nLine 2\nLine 3\n", + expectedDiff: "Changes:\n Line 1\n Line 2\n+ Line 3\n", + }, + { + name: "remove content", + oldContent: "Line 1\nLine 2\nLine 3\n", + newContent: "Line 1\nLine 3\n", + expectedDiff: "Changes:\n Line 1\n- Line 2\n Line 3\n", + }, + { + name: "replace content", + oldContent: "Line 1\nLine 2\nLine 3\n", + newContent: "Line 1\nModified Line\nLine 3\n", + expectedDiff: "Changes:\n Line 1\n- Line 2\n+ Modified Line\n Line 3\n", + }, + { + name: "empty to content", + oldContent: "", + newContent: "Line 1\nLine 2\n", + expectedDiff: "Changes:\n+ Line 1\n+ Line 2\n", + }, + { + name: "content to empty", + oldContent: "Line 1\nLine 2\n", + newContent: "", + expectedDiff: "Changes:\n- Line 1\n- Line 2\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + diff := GenerateDiff(tc.oldContent, tc.newContent) + assert.Contains(t, diff, tc.expectedDiff) + }) + } +} + diff --git a/internal/llm/tools/fetch.go b/internal/llm/tools/fetch.go index 0a852626c40d1d80f015d7cb42d3c213ca9238b0..5ea0c7633dadf578341d0e5374d773f4f6c7882b 100644 --- a/internal/llm/tools/fetch.go +++ b/internal/llm/tools/fetch.go @@ -15,6 +15,23 @@ import ( "github.com/kujtimiihoxha/termai/internal/permission" ) +type FetchParams struct { + URL string `json:"url"` + Format string `json:"format"` + Timeout int `json:"timeout,omitempty"` +} + +type FetchPermissionsParams struct { + URL string `json:"url"` + Format string `json:"format"` + Timeout int `json:"timeout,omitempty"` +} + +type fetchTool struct { + client *http.Client + permissions permission.Service +} + const ( FetchToolName = "fetch" fetchToolDescription = `Fetches content from a URL and returns it in the specified format. @@ -48,27 +65,12 @@ TIPS: - Set appropriate timeouts for potentially slow websites` ) -type FetchParams struct { - URL string `json:"url"` - Format string `json:"format"` - Timeout int `json:"timeout,omitempty"` -} - -type FetchPermissionsParams struct { - URL string `json:"url"` - Format string `json:"format"` - Timeout int `json:"timeout,omitempty"` -} - -type fetchTool struct { - client *http.Client -} - -func NewFetchTool() BaseTool { +func NewFetchTool(permissions permission.Service) BaseTool { return &fetchTool{ client: &http.Client{ Timeout: 30 * time.Second, }, + permissions: permissions, } } @@ -113,7 +115,7 @@ func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error return NewTextErrorResponse("URL must start with http:// or https://"), nil } - p := permission.Default.Request( + p := t.permissions.Request( permission.CreatePermissionRequest{ Path: config.WorkingDirectory(), ToolName: FetchToolName, @@ -220,4 +222,3 @@ func convertHTMLToMarkdown(html string) (string, error) { return markdown, nil } - diff --git a/internal/llm/tools/glob.go b/internal/llm/tools/glob.go index 05464094ba804c986a1973a5bc5c83bf043aa496..4de7971e63e088b1de953beb9bad28c5e68bb2d4 100644 --- a/internal/llm/tools/glob.go +++ b/internal/llm/tools/glob.go @@ -15,10 +15,42 @@ import ( "github.com/kujtimiihoxha/termai/internal/config" ) -type globTool struct{} - const ( - GlobToolName = "glob" + GlobToolName = "glob" + globDescription = `Fast file pattern matching tool that finds files by name and pattern, returning matching paths sorted by modification time (newest first). + +WHEN TO USE THIS TOOL: +- Use when you need to find files by name patterns or extensions +- Great for finding specific file types across a directory structure +- Useful for discovering files that match certain naming conventions + +HOW TO USE: +- Provide a glob pattern to match against file paths +- Optionally specify a starting directory (defaults to current working directory) +- Results are sorted with most recently modified files first + +GLOB PATTERN SYNTAX: +- '*' matches any sequence of non-separator characters +- '**' matches any sequence of characters, including separators +- '?' matches any single non-separator character +- '[...]' matches any character in the brackets +- '[!...]' matches any character not in the brackets + +COMMON PATTERN EXAMPLES: +- '*.js' - Find all JavaScript files in the current directory +- '**/*.js' - Find all JavaScript files in any subdirectory +- 'src/**/*.{ts,tsx}' - Find all TypeScript files in the src directory +- '*.{html,css,js}' - Find all HTML, CSS, and JS files + +LIMITATIONS: +- Results are limited to 100 files (newest first) +- Does not search file contents (use Grep tool for that) +- Hidden files (starting with '.') are skipped + +TIPS: +- For the most useful results, combine with the Grep tool: first find files with Glob, then search their contents with Grep +- When doing iterative exploration that may require multiple rounds of searching, consider using the Agent tool instead +- Always check if results are truncated and refine your search pattern if needed` ) type fileInfo struct { @@ -31,10 +63,16 @@ type GlobParams struct { Path string `json:"path"` } +type globTool struct{} + +func NewGlobTool() BaseTool { + return &globTool{} +} + func (g *globTool) Info() ToolInfo { return ToolInfo{ Name: GlobToolName, - Description: globDescription(), + Description: globDescription, Parameters: map[string]any{ "pattern": map[string]any{ "type": "string", @@ -49,7 +87,6 @@ func (g *globTool) Info() ToolInfo { } } -// Run implements Tool. func (g *globTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { var params GlobParams if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { @@ -60,7 +97,6 @@ func (g *globTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return NewTextErrorResponse("pattern is required"), nil } - // If path is empty, use current working directory searchPath := params.Path if searchPath == "" { searchPath = config.WorkingDirectory() @@ -71,7 +107,6 @@ func (g *globTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return NewTextErrorResponse(fmt.Sprintf("error performing glob search: %s", err)), nil } - // Format the output for the assistant var output string if len(files) == 0 { output = "No files found" @@ -86,28 +121,20 @@ func (g *globTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) } func globFiles(pattern, searchPath string, limit int) ([]string, bool, error) { - // Make sure pattern starts with the search path if not absolute if !strings.HasPrefix(pattern, "/") && !strings.HasPrefix(pattern, searchPath) { - // If searchPath doesn't end with a slash, add one before appending the pattern if !strings.HasSuffix(searchPath, "/") { searchPath += "/" } pattern = searchPath + pattern } - // Open the filesystem for walking fsys := os.DirFS("/") - // Convert the absolute pattern to a relative one for the DirFS - // DirFS uses the root directory ("/") so we should strip leading "/" relPattern := strings.TrimPrefix(pattern, "/") - // Collect matching files var matches []fileInfo - // Use doublestar to walk the filesystem and find matches err := doublestar.GlobWalk(fsys, relPattern, func(path string, d fs.DirEntry) error { - // Skip directories from results if d.IsDir() { return nil } @@ -115,20 +142,17 @@ func globFiles(pattern, searchPath string, limit int) ([]string, bool, error) { return nil } - // Get file info for modification time info, err := d.Info() if err != nil { return nil // Skip files we can't access } - // Add to matches absPath := "/" + path // Restore absolute path matches = append(matches, fileInfo{ path: absPath, modTime: info.ModTime(), }) - // Check limit if len(matches) >= limit*2 { // Collect more than needed for sorting return fs.SkipAll } @@ -139,18 +163,15 @@ func globFiles(pattern, searchPath string, limit int) ([]string, bool, error) { return nil, false, fmt.Errorf("glob walk error: %w", err) } - // Sort files by modification time (newest first) sort.Slice(matches, func(i, j int) bool { return matches[i].modTime.After(matches[j].modTime) }) - // Check if we need to truncate the results truncated := len(matches) > limit if truncated { matches = matches[:limit] } - // Extract just the paths results := make([]string, len(matches)) for i, m := range matches { results[i] = m.path @@ -163,44 +184,3 @@ func skipHidden(path string) bool { base := filepath.Base(path) return base != "." && strings.HasPrefix(base, ".") } - -func globDescription() string { - return `Fast file pattern matching tool that finds files by name and pattern, returning matching paths sorted by modification time (newest first). - -WHEN TO USE THIS TOOL: -- Use when you need to find files by name patterns or extensions -- Great for finding specific file types across a directory structure -- Useful for discovering files that match certain naming conventions - -HOW TO USE: -- Provide a glob pattern to match against file paths -- Optionally specify a starting directory (defaults to current working directory) -- Results are sorted with most recently modified files first - -GLOB PATTERN SYNTAX: -- '*' matches any sequence of non-separator characters -- '**' matches any sequence of characters, including separators -- '?' matches any single non-separator character -- '[...]' matches any character in the brackets -- '[!...]' matches any character not in the brackets - -COMMON PATTERN EXAMPLES: -- '*.js' - Find all JavaScript files in the current directory -- '**/*.js' - Find all JavaScript files in any subdirectory -- 'src/**/*.{ts,tsx}' - Find all TypeScript files in the src directory -- '*.{html,css,js}' - Find all HTML, CSS, and JS files - -LIMITATIONS: -- Results are limited to 100 files (newest first) -- Does not search file contents (use Grep tool for that) -- Hidden files (starting with '.') are skipped - -TIPS: -- For the most useful results, combine with the Grep tool: first find files with Glob, then search their contents with Grep -- When doing iterative exploration that may require multiple rounds of searching, consider using the Agent tool instead -- Always check if results are truncated and refine your search pattern if needed` -} - -func NewGlobTool() BaseTool { - return &globTool{} -} diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go index af58dacf01468f46ee2aa6f110758790b7011afb..f349e83709dc4ec5f6b3bddfb3851b2d01b6a441 100644 --- a/internal/llm/tools/grep.go +++ b/internal/llm/tools/grep.go @@ -16,12 +16,6 @@ import ( "github.com/kujtimiihoxha/termai/internal/config" ) -type grepTool struct{} - -const ( - GrepToolName = "grep" -) - type GrepParams struct { Pattern string `json:"pattern"` Path string `json:"path"` @@ -33,10 +27,54 @@ type grepMatch struct { modTime time.Time } +type grepTool struct{} + +const ( + GrepToolName = "grep" + grepDescription = `Fast content search tool that finds files containing specific text or patterns, returning matching file paths sorted by modification time (newest first). + +WHEN TO USE THIS TOOL: +- Use when you need to find files containing specific text or patterns +- Great for searching code bases for function names, variable declarations, or error messages +- Useful for finding all files that use a particular API or pattern + +HOW TO USE: +- Provide a regex pattern to search for within file contents +- Optionally specify a starting directory (defaults to current working directory) +- Optionally provide an include pattern to filter which files to search +- Results are sorted with most recently modified files first + +REGEX PATTERN SYNTAX: +- Supports standard regular expression syntax +- 'function' searches for the literal text "function" +- 'log\..*Error' finds text starting with "log." and ending with "Error" +- 'import\s+.*\s+from' finds import statements in JavaScript/TypeScript + +COMMON INCLUDE PATTERN EXAMPLES: +- '*.js' - Only search JavaScript files +- '*.{ts,tsx}' - Only search TypeScript files +- '*.go' - Only search Go files + +LIMITATIONS: +- Results are limited to 100 files (newest first) +- Performance depends on the number of files being searched +- Very large binary files may be skipped +- Hidden files (starting with '.') are skipped + +TIPS: +- For faster, more targeted searches, first use Glob to find relevant files, then use Grep +- When doing iterative exploration that may require multiple rounds of searching, consider using the Agent tool instead +- Always check if results are truncated and refine your search pattern if needed` +) + +func NewGrepTool() BaseTool { + return &grepTool{} +} + func (g *grepTool) Info() ToolInfo { return ToolInfo{ Name: GrepToolName, - Description: grepDescription(), + Description: grepDescription, Parameters: map[string]any{ "pattern": map[string]any{ "type": "string", @@ -55,7 +93,6 @@ func (g *grepTool) Info() ToolInfo { } } -// Run implements Tool. func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { var params GrepParams if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { @@ -66,7 +103,6 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return NewTextErrorResponse("pattern is required"), nil } - // If path is empty, use current working directory searchPath := params.Path if searchPath == "" { searchPath = config.WorkingDirectory() @@ -77,7 +113,6 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return NewTextErrorResponse(fmt.Sprintf("error searching files: %s", err)), nil } - // Format the output for the assistant var output string if len(matches) == 0 { output = "No files found" @@ -103,28 +138,23 @@ func pluralize(count int) string { } func searchFiles(pattern, rootPath, include string, limit int) ([]string, bool, error) { - // First try using ripgrep if available for better performance matches, err := searchWithRipgrep(pattern, rootPath, include) if err != nil { - // Fall back to manual regex search if ripgrep is not available matches, err = searchFilesWithRegex(pattern, rootPath, include) if err != nil { return nil, false, err } } - // Sort files by modification time (newest first) sort.Slice(matches, func(i, j int) bool { return matches[i].modTime.After(matches[j].modTime) }) - // Check if we need to truncate the results truncated := len(matches) > limit if truncated { matches = matches[:limit] } - // Extract just the paths results := make([]string, len(matches)) for i, m := range matches { results[i] = m.path @@ -149,7 +179,6 @@ func searchWithRipgrep(pattern, path, include string) ([]grepMatch, error) { output, err := cmd.Output() if err != nil { if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 { - // Exit code 1 means no matches, which isn't an error for our purposes return []grepMatch{}, nil } return nil, err @@ -203,17 +232,14 @@ func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error return nil // Skip directories } - // Skip hidden files if skipHidden(path) { return nil } - // Check include pattern if provided if includePattern != nil && !includePattern.MatchString(path) { return nil } - // Check file contents for the pattern match, err := fileContainsPattern(path, regex) if err != nil { return nil // Skip files we can't read @@ -225,7 +251,6 @@ func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error modTime: info.ModTime(), }) - // Check if we've hit the limit (collect double for sorting) if len(matches) >= 200 { return filepath.SkipAll } @@ -270,44 +295,3 @@ func globToRegex(glob string) string { return regexPattern } - -func grepDescription() string { - return `Fast content search tool that finds files containing specific text or patterns, returning matching file paths sorted by modification time (newest first). - -WHEN TO USE THIS TOOL: -- Use when you need to find files containing specific text or patterns -- Great for searching code bases for function names, variable declarations, or error messages -- Useful for finding all files that use a particular API or pattern - -HOW TO USE: -- Provide a regex pattern to search for within file contents -- Optionally specify a starting directory (defaults to current working directory) -- Optionally provide an include pattern to filter which files to search -- Results are sorted with most recently modified files first - -REGEX PATTERN SYNTAX: -- Supports standard regular expression syntax -- 'function' searches for the literal text "function" -- 'log\..*Error' finds text starting with "log." and ending with "Error" -- 'import\s+.*\s+from' finds import statements in JavaScript/TypeScript - -COMMON INCLUDE PATTERN EXAMPLES: -- '*.js' - Only search JavaScript files -- '*.{ts,tsx}' - Only search TypeScript files -- '*.go' - Only search Go files - -LIMITATIONS: -- Results are limited to 100 files (newest first) -- Performance depends on the number of files being searched -- Very large binary files may be skipped -- Hidden files (starting with '.') are skipped - -TIPS: -- For faster, more targeted searches, first use Glob to find relevant files, then use Grep -- When doing iterative exploration that may require multiple rounds of searching, consider using the Agent tool instead -- Always check if results are truncated and refine your search pattern if needed` -} - -func NewGrepTool() BaseTool { - return &grepTool{} -} diff --git a/internal/llm/tools/ls.go b/internal/llm/tools/ls.go index a3e0035622a1a926661e2a0bf39bc723e5e1d6a1..59e8dcd21141a33491aaac96a50f78607a4293b0 100644 --- a/internal/llm/tools/ls.go +++ b/internal/llm/tools/ls.go @@ -11,13 +11,6 @@ import ( "github.com/kujtimiihoxha/termai/internal/config" ) -type lsTool struct{} - -const ( - LSToolName = "ls" - MaxLSFiles = 1000 -) - type LSParams struct { Path string `json:"path"` Ignore []string `json:"ignore"` @@ -30,10 +23,49 @@ type TreeNode struct { Children []*TreeNode `json:"children,omitempty"` } +type lsTool struct{} + +const ( + LSToolName = "ls" + MaxLSFiles = 1000 + lsDescription = `Directory listing tool that shows files and subdirectories in a tree structure, helping you explore and understand the project organization. + +WHEN TO USE THIS TOOL: +- Use when you need to explore the structure of a directory +- Helpful for understanding the organization of a project +- Good first step when getting familiar with a new codebase + +HOW TO USE: +- Provide a path to list (defaults to current working directory) +- Optionally specify glob patterns to ignore +- Results are displayed in a tree structure + +FEATURES: +- Displays a hierarchical view of files and directories +- Automatically skips hidden files/directories (starting with '.') +- Skips common system directories like __pycache__ +- Can filter out files matching specific patterns + +LIMITATIONS: +- Results are limited to 1000 files +- Very large directories will be truncated +- Does not show file sizes or permissions +- Cannot recursively list all directories in a large project + +TIPS: +- Use Glob tool for finding files by name patterns instead of browsing +- Use Grep tool for searching file contents +- Combine with other tools for more effective exploration` +) + +func NewLsTool() BaseTool { + return &lsTool{} +} + func (l *lsTool) Info() ToolInfo { return ToolInfo{ Name: LSToolName, - Description: lsDescription(), + Description: lsDescription, Parameters: map[string]any{ "path": map[string]any{ "type": "string", @@ -51,25 +83,21 @@ func (l *lsTool) Info() ToolInfo { } } -// Run implements Tool. func (l *lsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { var params LSParams if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil } - // If path is empty, use current working directory searchPath := params.Path if searchPath == "" { searchPath = config.WorkingDirectory() } - // Ensure the path is absolute if !filepath.IsAbs(searchPath) { searchPath = filepath.Join(config.WorkingDirectory(), searchPath) } - // Check if the path exists if _, err := os.Stat(searchPath); os.IsNotExist(err) { return NewTextErrorResponse(fmt.Sprintf("path does not exist: %s", searchPath)), nil } @@ -129,12 +157,10 @@ func listDirectory(initialPath string, ignorePatterns []string, limit int) ([]st func shouldSkip(path string, ignorePatterns []string) bool { base := filepath.Base(path) - // Skip hidden files and directories if base != "." && strings.HasPrefix(base, ".") { return true } - // Skip common directories and files commonIgnored := []string{ "__pycache__", "node_modules", @@ -156,32 +182,26 @@ func shouldSkip(path string, ignorePatterns []string) bool { "*.exe", } - // Skip __pycache__ directories if strings.Contains(path, filepath.Join("__pycache__", "")) { return true } - // Check against common ignored patterns for _, ignored := range commonIgnored { if strings.HasSuffix(ignored, "/") { - // Directory pattern if strings.Contains(path, filepath.Join(ignored[:len(ignored)-1], "")) { return true } } else if strings.HasPrefix(ignored, "*.") { - // File extension pattern if strings.HasSuffix(base, ignored[1:]) { return true } } else { - // Exact match if base == ignored { return true } } } - // Check against ignore patterns for _, pattern := range ignorePatterns { matched, err := filepath.Match(pattern, base) if err == nil && matched { @@ -283,38 +303,3 @@ func printNode(builder *strings.Builder, node *TreeNode, level int) { } } } - -func lsDescription() string { - return `Directory listing tool that shows files and subdirectories in a tree structure, helping you explore and understand the project organization. - -WHEN TO USE THIS TOOL: -- Use when you need to explore the structure of a directory -- Helpful for understanding the organization of a project -- Good first step when getting familiar with a new codebase - -HOW TO USE: -- Provide a path to list (defaults to current working directory) -- Optionally specify glob patterns to ignore -- Results are displayed in a tree structure - -FEATURES: -- Displays a hierarchical view of files and directories -- Automatically skips hidden files/directories (starting with '.') -- Skips common system directories like __pycache__ -- Can filter out files matching specific patterns - -LIMITATIONS: -- Results are limited to 1000 files -- Very large directories will be truncated -- Does not show file sizes or permissions -- Cannot recursively list all directories in a large project - -TIPS: -- Use Glob tool for finding files by name patterns instead of browsing -- Use Grep tool for searching file contents -- Combine with other tools for more effective exploration` -} - -func NewLsTool() BaseTool { - return &lsTool{} -} diff --git a/internal/llm/tools/sourcegraph.go b/internal/llm/tools/sourcegraph.go index f20ce8a62517e1a7bceae4b0cebc42766efde31d..d23a6af6ddd6eb77a6ccb25893fbeadc68bee736 100644 --- a/internal/llm/tools/sourcegraph.go +++ b/internal/llm/tools/sourcegraph.go @@ -11,6 +11,17 @@ import ( "time" ) +type SourcegraphParams struct { + Query string `json:"query"` + Count int `json:"count,omitempty"` + ContextWindow int `json:"context_window,omitempty"` + Timeout int `json:"timeout,omitempty"` +} + +type sourcegraphTool struct { + client *http.Client +} + const ( SourcegraphToolName = "sourcegraph" sourcegraphToolDescription = `Search code across public repositories using Sourcegraph's GraphQL API. @@ -110,17 +121,6 @@ TIPS: - For more details on query syntax, visit: https://docs.sourcegraph.com/code_search/queries` ) -type SourcegraphParams struct { - Query string `json:"query"` - Count int `json:"count,omitempty"` - ContextWindow int `json:"context_window,omitempty"` - Timeout int `json:"timeout,omitempty"` -} - -type sourcegraphTool struct { - client *http.Client -} - func NewSourcegraphTool() BaseTool { return &sourcegraphTool{ client: &http.Client{ @@ -165,7 +165,6 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, return NewTextErrorResponse("Query parameter is required"), nil } - // Set default count if not specified if params.Count <= 0 { params.Count = 10 } else if params.Count > 20 { @@ -186,8 +185,6 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, } } - // GraphQL query for Sourcegraph search - // Create a properly escaped JSON structure type graphqlRequest struct { Query string `json:"query"` Variables struct { @@ -200,14 +197,12 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, } request.Variables.Query = params.Query - // Marshal to JSON to ensure proper escaping graphqlQueryBytes, err := json.Marshal(request) if err != nil { return NewTextErrorResponse("Failed to create GraphQL request: " + err.Error()), nil } graphqlQuery := string(graphqlQueryBytes) - // Create request to Sourcegraph API req, err := http.NewRequestWithContext( ctx, "POST", @@ -228,7 +223,6 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - // log the error response body, _ := io.ReadAll(resp.Body) if len(body) > 0 { return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d, response: %s", resp.StatusCode, string(body))), nil @@ -241,13 +235,11 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, return NewTextErrorResponse("Failed to read response body: " + err.Error()), nil } - // Parse the GraphQL response var result map[string]any if err = json.Unmarshal(body, &result); err != nil { return NewTextErrorResponse("Failed to parse response: " + err.Error()), nil } - // Format the results in a readable way formattedResults, err := formatSourcegraphResults(result, params.ContextWindow) if err != nil { return NewTextErrorResponse("Failed to format results: " + err.Error()), nil @@ -259,7 +251,6 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, func formatSourcegraphResults(result map[string]any, contextWindow int) (string, error) { var buffer strings.Builder - // Check for errors in the GraphQL response if errors, ok := result["errors"].([]any); ok && len(errors) > 0 { buffer.WriteString("## Sourcegraph API Error\n\n") for _, err := range errors { @@ -272,7 +263,6 @@ func formatSourcegraphResults(result map[string]any, contextWindow int) (string, return buffer.String(), nil } - // Extract data from the response data, ok := result["data"].(map[string]any) if !ok { return "", fmt.Errorf("invalid response format: missing data field") @@ -288,7 +278,6 @@ func formatSourcegraphResults(result map[string]any, contextWindow int) (string, return "", fmt.Errorf("invalid response format: missing results field") } - // Write search metadata matchCount, _ := searchResults["matchCount"].(float64) resultCount, _ := searchResults["resultCount"].(float64) limitHit, _ := searchResults["limitHit"].(bool) @@ -302,33 +291,28 @@ func formatSourcegraphResults(result map[string]any, contextWindow int) (string, buffer.WriteString("\n") - // Process results results, ok := searchResults["results"].([]any) if !ok || len(results) == 0 { buffer.WriteString("No results found. Try a different query.\n") return buffer.String(), nil } - // Limit to 10 results maxResults := 10 if len(results) > maxResults { results = results[:maxResults] } - // Process each result for i, res := range results { fileMatch, ok := res.(map[string]any) if !ok { continue } - // Skip non-FileMatch results typeName, _ := fileMatch["__typename"].(string) if typeName != "FileMatch" { continue } - // Extract repository and file information repo, _ := fileMatch["repository"].(map[string]any) file, _ := fileMatch["file"].(map[string]any) lineMatches, _ := fileMatch["lineMatches"].([]any) @@ -348,7 +332,6 @@ func formatSourcegraphResults(result map[string]any, contextWindow int) (string, buffer.WriteString(fmt.Sprintf("URL: %s\n\n", fileURL)) } - // Show line matches with context if len(lineMatches) > 0 { for _, lm := range lineMatches { lineMatch, ok := lm.(map[string]any) @@ -359,13 +342,11 @@ func formatSourcegraphResults(result map[string]any, contextWindow int) (string, lineNumber, _ := lineMatch["lineNumber"].(float64) preview, _ := lineMatch["preview"].(string) - // Extract context from file content if available if fileContent != "" { lines := strings.Split(fileContent, "\n") buffer.WriteString("```\n") - // Display context before the match (up to 10 lines) startLine := max(1, int(lineNumber)-contextWindow) for j := startLine - 1; j < int(lineNumber)-1 && j < len(lines); j++ { @@ -374,10 +355,8 @@ func formatSourcegraphResults(result map[string]any, contextWindow int) (string, } } - // Display the matching line (highlighted) buffer.WriteString(fmt.Sprintf("%d| %s\n", int(lineNumber), preview)) - // Display context after the match (up to 10 lines) endLine := int(lineNumber) + contextWindow for j := int(lineNumber); j < endLine && j < len(lines); j++ { @@ -388,7 +367,6 @@ func formatSourcegraphResults(result map[string]any, contextWindow int) (string, buffer.WriteString("```\n\n") } else { - // If file content is not available, just show the preview buffer.WriteString("```\n") buffer.WriteString(fmt.Sprintf("%d| %s\n", int(lineNumber), preview)) buffer.WriteString("```\n\n") diff --git a/internal/llm/tools/sourcegraph_test.go b/internal/llm/tools/sourcegraph_test.go index 5657ccd7eacaa8adac61c2e88e5153c8a23383d4..89829aefcb1bffd399f7cc4f738a9c32cb7b6049 100644 --- a/internal/llm/tools/sourcegraph_test.go +++ b/internal/llm/tools/sourcegraph_test.go @@ -5,7 +5,6 @@ import ( "encoding/json" "testing" - "github.com/kujtimiihoxha/termai/internal/permission" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -23,13 +22,6 @@ func TestSourcegraphTool_Info(t *testing.T) { } func TestSourcegraphTool_Run(t *testing.T) { - // Setup a mock permission handler that always allows - origPermission := permission.Default - defer func() { - permission.Default = origPermission - }() - permission.Default = newMockPermissionService(true) - t.Run("handles missing query parameter", func(t *testing.T) { tool := NewSourcegraphTool() params := SourcegraphParams{ @@ -61,27 +53,6 @@ func TestSourcegraphTool_Run(t *testing.T) { assert.Contains(t, response.Content, "Failed to parse sourcegraph parameters") }) - t.Run("handles permission denied", func(t *testing.T) { - permission.Default = newMockPermissionService(false) - - tool := NewSourcegraphTool() - params := SourcegraphParams{ - Query: "test query", - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: SourcegraphToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "Permission denied") - }) - t.Run("normalizes count parameter", func(t *testing.T) { // Test cases for count normalization testCases := []struct { diff --git a/internal/llm/tools/view.go b/internal/llm/tools/view.go index 743cef6f42ca9b69d10465cd8353a7d4d2d9f554..a687be015353afdbdc1e84fe1eaf384a72aa350e 100644 --- a/internal/llm/tools/view.go +++ b/internal/llm/tools/view.go @@ -14,6 +14,12 @@ import ( "github.com/kujtimiihoxha/termai/internal/lsp" ) +type ViewParams struct { + FilePath string `json:"file_path"` + Offset int `json:"offset"` + Limit int `json:"limit"` +} + type viewTool struct { lspClients map[string]*lsp.Client } @@ -23,18 +29,48 @@ const ( MaxReadSize = 250 * 1024 DefaultReadLimit = 2000 MaxLineLength = 2000 + viewDescription = `File viewing tool that reads and displays the contents of files with line numbers, allowing you to examine code, logs, or text data. + +WHEN TO USE THIS TOOL: +- Use when you need to read the contents of a specific file +- Helpful for examining source code, configuration files, or log files +- Perfect for looking at text-based file formats + +HOW TO USE: +- Provide the path to the file you want to view +- Optionally specify an offset to start reading from a specific line +- Optionally specify a limit to control how many lines are read + +FEATURES: +- Displays file contents with line numbers for easy reference +- Can read from any position in a file using the offset parameter +- Handles large files by limiting the number of lines read +- Automatically truncates very long lines for better display +- Suggests similar file names when the requested file isn't found + +LIMITATIONS: +- Maximum file size is 250KB +- Default reading limit is 2000 lines +- Lines longer than 2000 characters are truncated +- Cannot display binary files or images +- Images can be identified but not displayed + +TIPS: +- Use with Glob tool to first find files you want to view +- For code exploration, first use Grep to find relevant files, then View to examine them +- When viewing large files, use the offset parameter to read specific sections` ) -type ViewParams struct { - FilePath string `json:"file_path"` - Offset int `json:"offset"` - Limit int `json:"limit"` +func NewViewTool(lspClients map[string]*lsp.Client) BaseTool { + return &viewTool{ + lspClients, + } } func (v *viewTool) Info() ToolInfo { return ToolInfo{ Name: ViewToolName, - Description: viewDescription(), + Description: viewDescription, Parameters: map[string]any{ "file_path": map[string]any{ "type": "string", @@ -262,42 +298,3 @@ func (s *LineScanner) Text() string { func (s *LineScanner) Err() error { return s.scanner.Err() } - -func viewDescription() string { - return `File viewing tool that reads and displays the contents of files with line numbers, allowing you to examine code, logs, or text data. - -WHEN TO USE THIS TOOL: -- Use when you need to read the contents of a specific file -- Helpful for examining source code, configuration files, or log files -- Perfect for looking at text-based file formats - -HOW TO USE: -- Provide the path to the file you want to view -- Optionally specify an offset to start reading from a specific line -- Optionally specify a limit to control how many lines are read - -FEATURES: -- Displays file contents with line numbers for easy reference -- Can read from any position in a file using the offset parameter -- Handles large files by limiting the number of lines read -- Automatically truncates very long lines for better display -- Suggests similar file names when the requested file isn't found - -LIMITATIONS: -- Maximum file size is 250KB -- Default reading limit is 2000 lines -- Lines longer than 2000 characters are truncated -- Cannot display binary files or images -- Images can be identified but not displayed - -TIPS: -- Use with Glob tool to first find files you want to view -- For code exploration, first use Grep to find relevant files, then View to examine them -- When viewing large files, use the offset parameter to read specific sections` -} - -func NewViewTool(lspClients map[string]*lsp.Client) BaseTool { - return &viewTool{ - lspClients, - } -} diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index 86c9be37ecb50fcf01ee862c641a2d0f9c70cd29..7b698d2d8dcc5de7bbd32d26c9e94749fd5ed996 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -13,14 +13,6 @@ import ( "github.com/kujtimiihoxha/termai/internal/permission" ) -type writeTool struct { - lspClients map[string]*lsp.Client -} - -const ( - WriteToolName = "write" -) - type WriteParams struct { FilePath string `json:"file_path"` Content string `json:"content"` @@ -31,10 +23,54 @@ type WritePermissionsParams struct { Content string `json:"content"` } +type writeTool struct { + lspClients map[string]*lsp.Client + permissions permission.Service +} + +const ( + WriteToolName = "write" + writeDescription = `File writing tool that creates or updates files in the filesystem, allowing you to save or modify text content. + +WHEN TO USE THIS TOOL: +- Use when you need to create a new file +- Helpful for updating existing files with modified content +- Perfect for saving generated code, configurations, or text data + +HOW TO USE: +- Provide the path to the file you want to write +- Include the content to be written to the file +- The tool will create any necessary parent directories + +FEATURES: +- Can create new files or overwrite existing ones +- Creates parent directories automatically if they don't exist +- Checks if the file has been modified since last read for safety +- Avoids unnecessary writes when content hasn't changed + +LIMITATIONS: +- You should read a file before writing to it to avoid conflicts +- Cannot append to files (rewrites the entire file) + + +TIPS: +- Use the View tool first to examine existing files before modifying them +- Use the LS tool to verify the correct location when creating new files +- Combine with Glob and Grep tools to find and modify multiple files +- Always include descriptive comments when making changes to existing code` +) + +func NewWriteTool(lspClients map[string]*lsp.Client, permissions permission.Service) BaseTool { + return &writeTool{ + lspClients: lspClients, + permissions: permissions, + } +} + func (w *writeTool) Info() ToolInfo { return ToolInfo{ Name: WriteToolName, - Description: writeDescription(), + Description: writeDescription, Parameters: map[string]any{ "file_path": map[string]any{ "type": "string", @@ -49,7 +85,6 @@ func (w *writeTool) Info() ToolInfo { } } -// Run implements Tool. func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { var params WriteParams if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { @@ -64,20 +99,17 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error return NewTextErrorResponse("content is required"), nil } - // Handle relative paths filePath := params.FilePath if !filepath.IsAbs(filePath) { filePath = filepath.Join(config.WorkingDirectory(), filePath) } - // Check if file exists and is a directory fileInfo, err := os.Stat(filePath) if err == nil { if fileInfo.IsDir() { return NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil } - // Check if file was modified since last read modTime := fileInfo.ModTime() lastRead := getLastReadTime(filePath) if modTime.After(lastRead) { @@ -85,7 +117,6 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))), nil } - // Optional: Get old content for diff oldContent, readErr := os.ReadFile(filePath) if readErr == nil && string(oldContent) == params.Content { return NewTextErrorResponse(fmt.Sprintf("File %s already contains the exact content. No changes made.", filePath)), nil @@ -94,13 +125,11 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error return NewTextErrorResponse(fmt.Sprintf("Failed to access file: %s", err)), nil } - // Create parent directories if needed dir := filepath.Dir(filePath) if err = os.MkdirAll(dir, 0o755); err != nil { return NewTextErrorResponse(fmt.Sprintf("Failed to create parent directories: %s", err)), nil } - // Get old content for diff if file exists oldContent := "" if fileInfo != nil && !fileInfo.IsDir() { oldBytes, readErr := os.ReadFile(filePath) @@ -108,8 +137,8 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error oldContent = string(oldBytes) } } - - p := permission.Default.Request( + + p := w.permissions.Request( permission.CreatePermissionRequest{ Path: filePath, ToolName: WriteToolName, @@ -125,16 +154,13 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error return NewTextErrorResponse(fmt.Sprintf("Permission denied to create file: %s", filePath)), nil } - // Write the file err = os.WriteFile(filePath, []byte(params.Content), 0o644) if err != nil { return NewTextErrorResponse(fmt.Sprintf("Failed to write file: %s", err)), nil } - // Record the file write recordFileWrite(filePath) recordFileRead(filePath) - // Wait for LSP diagnostics after writing the file waitForLspDiagnostics(ctx, filePath, w.lspClients) result := fmt.Sprintf("File successfully written: %s", filePath) @@ -142,40 +168,3 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error result += appendDiagnostics(filePath, w.lspClients) return NewTextResponse(result), nil } - -func writeDescription() string { - return `File writing tool that creates or updates files in the filesystem, allowing you to save or modify text content. - -WHEN TO USE THIS TOOL: -- Use when you need to create a new file -- Helpful for updating existing files with modified content -- Perfect for saving generated code, configurations, or text data - -HOW TO USE: -- Provide the path to the file you want to write -- Include the content to be written to the file -- The tool will create any necessary parent directories - -FEATURES: -- Can create new files or overwrite existing ones -- Creates parent directories automatically if they don't exist -- Checks if the file has been modified since last read for safety -- Avoids unnecessary writes when content hasn't changed - -LIMITATIONS: -- You should read a file before writing to it to avoid conflicts -- Cannot append to files (rewrites the entire file) - - -TIPS: -- Use the View tool first to examine existing files before modifying them -- Use the LS tool to verify the correct location when creating new files -- Combine with Glob and Grep tools to find and modify multiple files -- Always include descriptive comments when making changes to existing code` -} - -func NewWriteTool(lspClients map[string]*lsp.Client) BaseTool { - return &writeTool{ - lspClients, - } -} diff --git a/internal/llm/tools/write_test.go b/internal/llm/tools/write_test.go index 893a48b620597bb52e832cb35330798d309f9cc2..50dafc14f0f8a42b1a848a3f5c1d9b530a59e196 100644 --- a/internal/llm/tools/write_test.go +++ b/internal/llm/tools/write_test.go @@ -9,13 +9,12 @@ import ( "time" "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/permission" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestWriteTool_Info(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) info := tool.Info() assert.Equal(t, WriteToolName, info.Name) @@ -27,21 +26,13 @@ func TestWriteTool_Info(t *testing.T) { } func TestWriteTool_Run(t *testing.T) { - // Setup a mock permission handler that always allows - origPermission := permission.Default - defer func() { - permission.Default = origPermission - }() - permission.Default = newMockPermissionService(true) - // Create a temporary directory for testing tempDir, err := os.MkdirTemp("", "write_tool_test") require.NoError(t, err) defer os.RemoveAll(tempDir) t.Run("creates a new file successfully", func(t *testing.T) { - permission.Default = newMockPermissionService(true) - tool := NewWriteTool(make(map[string]*lsp.Client)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) filePath := filepath.Join(tempDir, "new_file.txt") content := "This is a test content" @@ -70,8 +61,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("creates file with nested directories", func(t *testing.T) { - permission.Default = newMockPermissionService(true) - tool := NewWriteTool(make(map[string]*lsp.Client)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt") content := "Content in nested directory" @@ -100,8 +90,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("updates existing file", func(t *testing.T) { - permission.Default = newMockPermissionService(true) - tool := NewWriteTool(make(map[string]*lsp.Client)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) // Create a file first filePath := filepath.Join(tempDir, "existing_file.txt") @@ -138,8 +127,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles invalid parameters", func(t *testing.T) { - permission.Default = newMockPermissionService(true) - tool := NewWriteTool(make(map[string]*lsp.Client)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) call := ToolCall{ Name: WriteToolName, @@ -152,8 +140,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles missing file_path", func(t *testing.T) { - permission.Default = newMockPermissionService(true) - tool := NewWriteTool(make(map[string]*lsp.Client)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) params := WriteParams{ FilePath: "", @@ -174,8 +161,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles missing content", func(t *testing.T) { - permission.Default = newMockPermissionService(true) - tool := NewWriteTool(make(map[string]*lsp.Client)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) params := WriteParams{ FilePath: filepath.Join(tempDir, "file.txt"), @@ -196,8 +182,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles writing to a directory path", func(t *testing.T) { - permission.Default = newMockPermissionService(true) - tool := NewWriteTool(make(map[string]*lsp.Client)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) // Create a directory dirPath := filepath.Join(tempDir, "test_dir") @@ -223,8 +208,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles permission denied", func(t *testing.T) { - permission.Default = newMockPermissionService(false) - tool := NewWriteTool(make(map[string]*lsp.Client)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(false)) filePath := filepath.Join(tempDir, "permission_denied.txt") params := WriteParams{ @@ -250,8 +234,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("detects file modified since last read", func(t *testing.T) { - permission.Default = newMockPermissionService(true) - tool := NewWriteTool(make(map[string]*lsp.Client)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) // Create a file filePath := filepath.Join(tempDir, "modified_file.txt") @@ -292,8 +275,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("skips writing when content is identical", func(t *testing.T) { - permission.Default = newMockPermissionService(true) - tool := NewWriteTool(make(map[string]*lsp.Client)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) // Create a file filePath := filepath.Join(tempDir, "identical_content.txt") @@ -323,4 +305,3 @@ func TestWriteTool_Run(t *testing.T) { assert.Contains(t, response.Content, "already contains the exact content") }) } - diff --git a/internal/permission/permission.go b/internal/permission/permission.go index fbdef0e96ed4d169e5e9c67f7bcfdaa8f1f102dc..ebf3fe0925b719afbe94f058ac6195611ed17341 100644 --- a/internal/permission/permission.go +++ b/internal/permission/permission.go @@ -100,5 +100,3 @@ func NewPermissionService() Service { sessionPermissions: make([]PermissionRequest, 0), } } - -var Default Service = NewPermissionService() diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 8785b5ab230faaf9224e9632c8ac37ec0fc2b744..15d97a11a024191007cbe064c7a2d0c580677c1e 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -80,11 +80,11 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case dialog.PermissionResponseMsg: switch msg.Action { case dialog.PermissionAllow: - permission.Default.Grant(msg.Permission) + a.app.Permissions.Grant(msg.Permission) case dialog.PermissionAllowForSession: - permission.Default.GrantPersistant(msg.Permission) + a.app.Permissions.GrantPersistant(msg.Permission) case dialog.PermissionDeny: - permission.Default.Deny(msg.Permission) + a.app.Permissions.Deny(msg.Permission) } case vimtea.EditorModeMsg: a.editorMode = msg.Mode