diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index ea736e81bc05a67d1b42bb5927f537b06f4ada5f..0a10568a39315f6c4077385b8ca83f6b3e52691c 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -10,7 +10,6 @@ import ( "github.com/charmbracelet/crush/internal/permission" "github.com/charmbracelet/crush/internal/shell" - "mvdan.cc/sh/v3/syntax" ) type BashParams struct { @@ -41,235 +40,10 @@ const ( BashNoOutput = "no output" ) -func containsBannedCommand(node syntax.Node) bool { - if node == nil { - return false - } - - switch n := node.(type) { - case *syntax.CallExpr: - if len(n.Args) > 0 { - cmdName := getWordValue(n.Args[0]) - for _, banned := range bannedCommands { - if strings.EqualFold(cmdName, banned) { - return true - } - } - } - for _, arg := range n.Args { - if containsBannedCommand(arg) { - return true - } - } - case *syntax.Word: - if checkWordForBannedCommands(n) { - return true - } - for _, part := range n.Parts { - if containsBannedCommand(part) { - return true - } - } - case *syntax.CmdSubst: - for _, stmt := range n.Stmts { - if containsBannedCommand(stmt) { - return true - } - } - case *syntax.Subshell: - for _, stmt := range n.Stmts { - if containsBannedCommand(stmt) { - return true - } - } - case *syntax.Stmt: - if containsBannedCommand(n.Cmd) { - return true - } - for _, redir := range n.Redirs { - if containsBannedCommand(redir) { - return true - } - } - case *syntax.BinaryCmd: - return containsBannedCommand(n.X) || containsBannedCommand(n.Y) - case *syntax.Block: - for _, stmt := range n.Stmts { - if containsBannedCommand(stmt) { - return true - } - } - case *syntax.IfClause: - for _, stmt := range n.Cond { - if containsBannedCommand(stmt) { - return true - } - } - for _, stmt := range n.Then { - if containsBannedCommand(stmt) { - return true - } - } - if n.Else != nil && containsBannedCommand(n.Else) { - return true - } - case *syntax.WhileClause: - for _, stmt := range n.Cond { - if containsBannedCommand(stmt) { - return true - } - } - for _, stmt := range n.Do { - if containsBannedCommand(stmt) { - return true - } - } - case *syntax.ForClause: - for _, stmt := range n.Do { - if containsBannedCommand(stmt) { - return true - } - } - if containsBannedCommand(n.Loop) { - return true - } - case *syntax.CaseClause: - for _, item := range n.Items { - for _, stmt := range item.Stmts { - if containsBannedCommand(stmt) { - return true - } - } - } - case *syntax.FuncDecl: - return containsBannedCommand(n.Body) - case *syntax.ArithmExp: - return containsBannedCommand(n.X) - case *syntax.Redirect: - return containsBannedCommand(n.Word) - } - return false -} - -func checkWordForBannedCommands(word *syntax.Word) bool { - if word == nil { - return false - } - - for _, part := range word.Parts { - switch p := part.(type) { - case *syntax.SglQuoted: - if checkQuotedStringForBannedCommands(p.Value) { - return true - } - case *syntax.DblQuoted: - var content strings.Builder - for _, qpart := range p.Parts { - if lit, ok := qpart.(*syntax.Lit); ok { - content.WriteString(lit.Value) - } - } - if checkQuotedStringForBannedCommands(content.String()) { - return true - } - } - } - return false -} - -func checkQuotedStringForBannedCommands(content string) bool { - parser := syntax.NewParser() - file, err := parser.Parse(strings.NewReader(content), "") - if err != nil { - return false - } - - if len(file.Stmts) == 0 { - return false - } - - // Simple heuristic: if it looks like prose rather than commands, don't flag it - if len(file.Stmts) == 1 { - stmt := file.Stmts[0] - if callExpr, ok := stmt.Cmd.(*syntax.CallExpr); ok { - if len(callExpr.Args) > 2 { - allText := true - for i, arg := range callExpr.Args { - if i == 0 { - continue - } - argStr := getWordValue(arg) - if strings.HasPrefix(argStr, "-") { - allText = false - break - } - } - if allText { - return false - } - } - } - } - - for _, stmt := range file.Stmts { - if containsBannedCommand(stmt) { - return true - } - } - return false -} - -func getWordValue(word *syntax.Word) string { - if word == nil || len(word.Parts) == 0 { - return "" - } - - var result strings.Builder - for _, part := range word.Parts { - switch p := part.(type) { - case *syntax.Lit: - result.WriteString(p.Value) - case *syntax.SglQuoted: - result.WriteString(p.Value) - case *syntax.DblQuoted: - for _, qpart := range p.Parts { - if lit, ok := qpart.(*syntax.Lit); ok { - result.WriteString(lit.Value) - } - } - } - } - return result.String() -} - -func validateCommand(command string) error { - parser := syntax.NewParser() - file, err := parser.Parse(strings.NewReader(command), "") - if err != nil { - parts := strings.Fields(command) - if len(parts) > 0 { - baseCmd := parts[0] - for _, banned := range bannedCommands { - if strings.EqualFold(baseCmd, banned) { - return fmt.Errorf("command '%s' is not allowed", baseCmd) - } - } - } - return nil - } - - for _, stmt := range file.Stmts { - if containsBannedCommand(stmt) { - return fmt.Errorf("command contains banned operations") - } - } - return nil -} - var bannedCommands = []string{ "alias", "curl", "curlie", "wget", "axel", "aria2c", "nc", "telnet", "lynx", "w3m", "links", "httpie", "xh", - "http-prompt", "chrome", "firefox", "safari", "sudo", + "http-prompt", "chrome", "firefox", "safari", } // getSafeReadOnlyCommands returns platform-appropriate safe commands @@ -515,9 +289,11 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return NewTextErrorResponse("missing command"), nil } - - if err := validateCommand(params.Command); err != nil { - return NewTextErrorResponse(err.Error()), nil + baseCmd := strings.Fields(params.Command)[0] + for _, banned := range bannedCommands { + if strings.EqualFold(baseCmd, banned) { + return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil + } } isSafeReadOnly := false diff --git a/internal/llm/tools/bash_test.go b/internal/llm/tools/bash_test.go index 768a3738dd2fe162838f32ee498ac14ed2ba9eee..a810002749408af2bb89cb958b5999dc2da3bcb3 100644 --- a/internal/llm/tools/bash_test.go +++ b/internal/llm/tools/bash_test.go @@ -94,168 +94,3 @@ func TestPlatformSpecificSafeCommands(t *testing.T) { } } } - -func TestValidateCommand(t *testing.T) { - tests := []struct { - name string - command string - shouldError bool - }{ - // Commands that should be blocked - { - name: "direct sudo", - command: "sudo ls", - shouldError: true, - }, - { - name: "sudo in script", - command: "bash -c 'sudo ls'", - shouldError: true, - }, - { - name: "sudo in command substitution", - command: "$(sudo whoami)", - shouldError: true, - }, - { - name: "sudo in echo command substitution", - command: "echo $(sudo id)", - shouldError: true, - }, - { - name: "sudo in command chain", - command: "ls && sudo rm file", - shouldError: true, - }, - { - name: "sudo in if statement", - command: "if true; then sudo ls; fi", - shouldError: true, - }, - { - name: "sudo in for loop", - command: "for i in 1; do sudo echo $i; done", - shouldError: true, - }, - { - name: "direct curl", - command: "curl http://example.com", - shouldError: true, - }, - { - name: "curl in script", - command: "bash -c 'curl malicious.com'", - shouldError: true, - }, - { - name: "wget command", - command: "wget http://example.com", - shouldError: true, - }, - { - name: "nc command", - command: "nc -l 8080", - shouldError: true, - }, - // Commands that should be allowed - { - name: "simple ls", - command: "ls -la", - shouldError: false, - }, - { - name: "echo command", - command: "echo hello", - shouldError: false, - }, - { - name: "git status", - command: "git status", - shouldError: false, - }, - { - name: "go build", - command: "go build", - shouldError: false, - }, - { - name: "sudo as literal text", - command: "echo 'sudo is just text here'", - shouldError: false, - }, - { - name: "complex allowed command", - command: "find . -name '*.go' | head -10", - shouldError: false, - }, - { - name: "command with environment variables", - command: "FOO=bar go test", - shouldError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := validateCommand(tt.command) - if tt.shouldError && err == nil { - t.Errorf("Expected error for command %q, but got none", tt.command) - } - if !tt.shouldError && err != nil { - t.Errorf("Expected no error for command %q, but got: %v", tt.command, err) - } - }) - } -} - -func TestContainsBannedCommand(t *testing.T) { - // Test the helper functions directly with some edge cases - tests := []struct { - name string - command string - shouldError bool - }{ - { - name: "nested command substitution", - command: "echo $(echo $(sudo id))", - shouldError: true, - }, - { - name: "subshell with banned command", - command: "(sudo ls)", - shouldError: true, - }, - { - name: "case statement with banned command", - command: "case $1 in start) sudo systemctl start service ;; esac", - shouldError: true, - }, - { - name: "while loop with banned command", - command: "while true; do sudo echo test; done", - shouldError: true, - }, - { - name: "function with banned command", - command: "function test() { sudo ls; }", - shouldError: true, - }, - { - name: "complex valid command", - command: "if [ -f file ]; then echo exists; else echo missing; fi", - shouldError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := validateCommand(tt.command) - if tt.shouldError && err == nil { - t.Errorf("Expected error for command %q, but got none", tt.command) - } - if !tt.shouldError && err != nil { - t.Errorf("Expected no error for command %q, but got: %v", tt.command, err) - } - }) - } -}