From 64199736677a8ed303291c763975d484ac8f7e99 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 4 Apr 2025 15:41:25 +0200 Subject: [PATCH] Enhance bash tool security and improve permission dialog UI MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Expand safe command list with common dev tools (git, go, node, python, etc.) - Improve multi-word command detection for better security checks - Add scrollable viewport to permission dialog for better diff viewing - Fix command batching in TUI update to properly handle multiple commands 🤖 Generated with termai Co-Authored-By: termai --- internal/llm/tools/bash.go | 47 ++++++++++- internal/llm/tools/bash_test.go | 37 ++++++--- internal/tui/components/dialog/permission.go | 87 ++++++++++++++++---- internal/tui/tui.go | 15 ++-- 4 files changed, 149 insertions(+), 37 deletions(-) diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index 1481bdd88d90748c2419fd80c8d34782ab978b79..a78c03215c3dee7381e84970855f93b5c29147d2 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -38,8 +38,38 @@ var BannedCommands = []string{ } var SafeReadOnlyCommands = []string{ + // Basic shell commands "ls", "echo", "pwd", "date", "cal", "uptime", "whoami", "id", "groups", "env", "printenv", "set", "unset", "which", "type", "whereis", - "whatis", //... + "whatis", "uname", "hostname", "df", "du", "free", "top", "ps", "kill", "killall", "nice", "nohup", "time", "timeout", + + // Git read-only commands + "git status", "git log", "git diff", "git show", "git branch", "git tag", "git remote", "git ls-files", "git ls-remote", + "git rev-parse", "git config --get", "git config --list", "git describe", "git blame", "git grep", "git shortlog", + + // Go commands + "go version", "go list", "go env", "go doc", "go vet", "go fmt", "go mod", "go test", "go build", "go run", "go install", "go clean", + + // Node.js commands + "node", "npm", "npx", "yarn", "pnpm", + + // Python commands + "python", "python3", "pip", "pip3", "pytest", "pylint", "mypy", "black", "isort", "flake8", "ruff", + + // Docker commands + "docker ps", "docker images", "docker volume", "docker network", "docker info", "docker version", + "docker-compose ps", "docker-compose config", + + // Kubernetes commands + "kubectl get", "kubectl describe", "kubectl logs", "kubectl version", "kubectl config", + + // Rust commands + "cargo", "rustc", "rustup", + + // Java commands + "java", "javac", "mvn", "gradle", + + // Misc development tools + "make", "cmake", "bazel", "terraform plan", "terraform validate", "ansible", } func (b *bashTool) Info() ToolInfo { @@ -77,17 +107,26 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return NewTextErrorResponse("missing command"), nil } + // Check for banned commands (first word only) baseCmd := strings.Fields(params.Command)[0] for _, banned := range BannedCommands { if strings.EqualFold(baseCmd, banned) { return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil } } + + // Check for safe commands (can be multi-word) isSafeReadOnly := false + cmdLower := strings.ToLower(params.Command) + for _, safe := range SafeReadOnlyCommands { - if strings.EqualFold(baseCmd, safe) { - isSafeReadOnly = true - break + // Check if command starts with the safe command pattern + if strings.HasPrefix(cmdLower, strings.ToLower(safe)) { + // Make sure it's either an exact match or followed by a space or flag + if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' { + isSafeReadOnly = true + break + } } } if !isSafeReadOnly { diff --git a/internal/llm/tools/bash_test.go b/internal/llm/tools/bash_test.go index b7b5c5ee5b479e6693e2930afd7e5bae220e9c83..9eadc227ce12edcca692b68ea412050afaaa06fd 100644 --- a/internal/llm/tools/bash_test.go +++ b/internal/llm/tools/bash_test.go @@ -119,27 +119,38 @@ func TestBashTool_Run(t *testing.T) { } }) - t.Run("handles safe read-only commands without permission check", func(t *testing.T) { + t.Run("handles multi-word safe commands without permission check", func(t *testing.T) { permission.Default = newMockPermissionService(false) tool := NewBashTool() - // Test with a safe read-only command - params := BashParams{ - Command: "echo 'test'", + // Test with multi-word safe commands + multiWordCommands := []string{ + "git status", + "git log -n 5", + "docker ps", + "go test ./...", + "kubectl get pods", } - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) + for _, cmd := range multiWordCommands { + params := BashParams{ + Command: cmd, + } - call := ToolCall{ - Name: BashToolName, - Input: string(paramsJSON), - } + paramsJSON, err := json.Marshal(params) + require.NoError(t, err) - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Equal(t, "test\n", response.Content) + call := ToolCall{ + Name: BashToolName, + Input: string(paramsJSON), + } + + response, err := tool.Run(context.Background(), call) + require.NoError(t, err) + assert.NotContains(t, response.Content, "permission denied", + "Command %s should be allowed without permission", cmd) + } }) t.Run("handles permission denied", func(t *testing.T) { diff --git a/internal/tui/components/dialog/permission.go b/internal/tui/components/dialog/permission.go index 085744dda0cee646b43f6f522741213c9b860960..17b3fba078151abae873349650b720c8126c9835 100644 --- a/internal/tui/components/dialog/permission.go +++ b/internal/tui/components/dialog/permission.go @@ -92,16 +92,7 @@ func formatDiff(diffText string) string { } // Join all formatted lines - content := strings.Join(formattedLines, "\n") - - // Create a bordered box for the content - contentStyle := lipgloss.NewStyle(). - MarginTop(1). - Padding(0, 1). - Border(lipgloss.RoundedBorder()). - BorderForeground(styles.Flamingo) - - return contentStyle.Render(content) + return strings.Join(formattedLines, "\n") } func (p *permissionDialogCmp) Init() tea.Cmd { @@ -241,12 +232,46 @@ func (p *permissionDialogCmp) render() string { headerParts = append(headerParts, keyStyle.Render("Update")) // Recreate header content with the updated headerParts headerContent = lipgloss.NewStyle().Padding(0, 1).Render(lipgloss.JoinVertical(lipgloss.Left, headerParts...)) - // Format the diff with colors instead of using markdown code block + + // Format the diff with colors formattedDiff := formatDiff(pr.Diff) + + // Set up viewport for the diff content + p.contentViewPort.Width = p.width - 2 - 2 + + // Calculate content height dynamically based on window size + maxContentHeight := p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1 + p.contentViewPort.Height = maxContentHeight + p.contentViewPort.SetContent(formattedDiff) + + // Style the viewport + var contentBorder lipgloss.Border + var borderColor lipgloss.TerminalColor + + if p.isViewportFocus { + contentBorder = lipgloss.DoubleBorder() + borderColor = styles.Blue + } else { + contentBorder = lipgloss.RoundedBorder() + borderColor = styles.Flamingo + } + + contentStyle := lipgloss.NewStyle(). + MarginTop(1). + Padding(0, 1). + Border(contentBorder). + BorderForeground(borderColor) + + if p.isViewportFocus { + contentStyle = contentStyle.BorderBackground(styles.Surface0) + } + + contentFinal := contentStyle.Render(p.contentViewPort.View()) + return lipgloss.JoinVertical( lipgloss.Top, headerContent, - formattedDiff, + contentFinal, form, ) @@ -255,12 +280,46 @@ func (p *permissionDialogCmp) render() string { headerParts = append(headerParts, keyStyle.Render("Content")) // Recreate header content with the updated headerParts headerContent = lipgloss.NewStyle().Padding(0, 1).Render(lipgloss.JoinVertical(lipgloss.Left, headerParts...)) - // Format the diff with colors instead of using markdown code block + + // Format the diff with colors formattedDiff := formatDiff(pr.Content) + + // Set up viewport for the content + p.contentViewPort.Width = p.width - 2 - 2 + + // Calculate content height dynamically based on window size + maxContentHeight := p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1 + p.contentViewPort.Height = maxContentHeight + p.contentViewPort.SetContent(formattedDiff) + + // Style the viewport + var contentBorder lipgloss.Border + var borderColor lipgloss.TerminalColor + + if p.isViewportFocus { + contentBorder = lipgloss.DoubleBorder() + borderColor = styles.Blue + } else { + contentBorder = lipgloss.RoundedBorder() + borderColor = styles.Flamingo + } + + contentStyle := lipgloss.NewStyle(). + MarginTop(1). + Padding(0, 1). + Border(contentBorder). + BorderForeground(borderColor) + + if p.isViewportFocus { + contentStyle = contentStyle.BorderBackground(styles.Surface0) + } + + contentFinal := contentStyle.Render(p.contentViewPort.View()) + return lipgloss.JoinVertical( lipgloss.Top, headerContent, - formattedDiff, + contentFinal, form, ) diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 97f876a2d945ae6af2fdd8de50403155c022cd30..8785b5ab230faaf9224e9632c8ac37ec0fc2b744 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -123,8 +123,6 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.status, _ = a.status.Update(msg) case util.ErrorMsg: a.status, _ = a.status.Update(msg) - case util.ClearStatusMsg: - a.status, _ = a.status.Update(msg) case tea.KeyMsg: if a.editorMode == vimtea.ModeNormal { switch { @@ -163,16 +161,21 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } } } + + var cmds []tea.Cmd + s, cmd := a.status.Update(msg) + a.status = s + cmds = append(cmds, cmd) if a.dialogVisible { d, cmd := a.dialog.Update(msg) a.dialog = d.(core.DialogCmp) - return a, cmd + cmds = append(cmds, cmd) + return a, tea.Batch(cmds...) } - s, _ := a.status.Update(msg) - a.status = s p, cmd := a.pages[a.currentPage].Update(msg) a.pages[a.currentPage] = p - return a, cmd + cmds = append(cmds, cmd) + return a, tea.Batch(cmds...) } func (a *appModel) ToggleHelp() {