From 28be21fd629cfc221127f75bde92de32d09fe85c Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Thu, 10 Jul 2025 23:42:08 +0200 Subject: [PATCH] chore: implement correct banned commands --- internal/llm/tools/bash.go | 236 +++++++++++++++++++++++++++++++- internal/llm/tools/bash_test.go | 165 ++++++++++++++++++++++ 2 files changed, 395 insertions(+), 6 deletions(-) diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index 0a10568a39315f6c4077385b8ca83f6b3e52691c..ea736e81bc05a67d1b42bb5927f537b06f4ada5f 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -10,6 +10,7 @@ import ( "github.com/charmbracelet/crush/internal/permission" "github.com/charmbracelet/crush/internal/shell" + "mvdan.cc/sh/v3/syntax" ) type BashParams struct { @@ -40,10 +41,235 @@ const ( BashNoOutput = "no output" ) +func containsBannedCommand(node syntax.Node) bool { + if node == nil { + return false + } + + switch n := node.(type) { + case *syntax.CallExpr: + if len(n.Args) > 0 { + cmdName := getWordValue(n.Args[0]) + for _, banned := range bannedCommands { + if strings.EqualFold(cmdName, banned) { + return true + } + } + } + for _, arg := range n.Args { + if containsBannedCommand(arg) { + return true + } + } + case *syntax.Word: + if checkWordForBannedCommands(n) { + return true + } + for _, part := range n.Parts { + if containsBannedCommand(part) { + return true + } + } + case *syntax.CmdSubst: + for _, stmt := range n.Stmts { + if containsBannedCommand(stmt) { + return true + } + } + case *syntax.Subshell: + for _, stmt := range n.Stmts { + if containsBannedCommand(stmt) { + return true + } + } + case *syntax.Stmt: + if containsBannedCommand(n.Cmd) { + return true + } + for _, redir := range n.Redirs { + if containsBannedCommand(redir) { + return true + } + } + case *syntax.BinaryCmd: + return containsBannedCommand(n.X) || containsBannedCommand(n.Y) + case *syntax.Block: + for _, stmt := range n.Stmts { + if containsBannedCommand(stmt) { + return true + } + } + case *syntax.IfClause: + for _, stmt := range n.Cond { + if containsBannedCommand(stmt) { + return true + } + } + for _, stmt := range n.Then { + if containsBannedCommand(stmt) { + return true + } + } + if n.Else != nil && containsBannedCommand(n.Else) { + return true + } + case *syntax.WhileClause: + for _, stmt := range n.Cond { + if containsBannedCommand(stmt) { + return true + } + } + for _, stmt := range n.Do { + if containsBannedCommand(stmt) { + return true + } + } + case *syntax.ForClause: + for _, stmt := range n.Do { + if containsBannedCommand(stmt) { + return true + } + } + if containsBannedCommand(n.Loop) { + return true + } + case *syntax.CaseClause: + for _, item := range n.Items { + for _, stmt := range item.Stmts { + if containsBannedCommand(stmt) { + return true + } + } + } + case *syntax.FuncDecl: + return containsBannedCommand(n.Body) + case *syntax.ArithmExp: + return containsBannedCommand(n.X) + case *syntax.Redirect: + return containsBannedCommand(n.Word) + } + return false +} + +func checkWordForBannedCommands(word *syntax.Word) bool { + if word == nil { + return false + } + + for _, part := range word.Parts { + switch p := part.(type) { + case *syntax.SglQuoted: + if checkQuotedStringForBannedCommands(p.Value) { + return true + } + case *syntax.DblQuoted: + var content strings.Builder + for _, qpart := range p.Parts { + if lit, ok := qpart.(*syntax.Lit); ok { + content.WriteString(lit.Value) + } + } + if checkQuotedStringForBannedCommands(content.String()) { + return true + } + } + } + return false +} + +func checkQuotedStringForBannedCommands(content string) bool { + parser := syntax.NewParser() + file, err := parser.Parse(strings.NewReader(content), "") + if err != nil { + return false + } + + if len(file.Stmts) == 0 { + return false + } + + // Simple heuristic: if it looks like prose rather than commands, don't flag it + if len(file.Stmts) == 1 { + stmt := file.Stmts[0] + if callExpr, ok := stmt.Cmd.(*syntax.CallExpr); ok { + if len(callExpr.Args) > 2 { + allText := true + for i, arg := range callExpr.Args { + if i == 0 { + continue + } + argStr := getWordValue(arg) + if strings.HasPrefix(argStr, "-") { + allText = false + break + } + } + if allText { + return false + } + } + } + } + + for _, stmt := range file.Stmts { + if containsBannedCommand(stmt) { + return true + } + } + return false +} + +func getWordValue(word *syntax.Word) string { + if word == nil || len(word.Parts) == 0 { + return "" + } + + var result strings.Builder + for _, part := range word.Parts { + switch p := part.(type) { + case *syntax.Lit: + result.WriteString(p.Value) + case *syntax.SglQuoted: + result.WriteString(p.Value) + case *syntax.DblQuoted: + for _, qpart := range p.Parts { + if lit, ok := qpart.(*syntax.Lit); ok { + result.WriteString(lit.Value) + } + } + } + } + return result.String() +} + +func validateCommand(command string) error { + parser := syntax.NewParser() + file, err := parser.Parse(strings.NewReader(command), "") + if err != nil { + parts := strings.Fields(command) + if len(parts) > 0 { + baseCmd := parts[0] + for _, banned := range bannedCommands { + if strings.EqualFold(baseCmd, banned) { + return fmt.Errorf("command '%s' is not allowed", baseCmd) + } + } + } + return nil + } + + for _, stmt := range file.Stmts { + if containsBannedCommand(stmt) { + return fmt.Errorf("command contains banned operations") + } + } + return nil +} + var bannedCommands = []string{ "alias", "curl", "curlie", "wget", "axel", "aria2c", "nc", "telnet", "lynx", "w3m", "links", "httpie", "xh", - "http-prompt", "chrome", "firefox", "safari", + "http-prompt", "chrome", "firefox", "safari", "sudo", } // getSafeReadOnlyCommands returns platform-appropriate safe commands @@ -289,11 +515,9 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return NewTextErrorResponse("missing command"), nil } - baseCmd := strings.Fields(params.Command)[0] - for _, banned := range bannedCommands { - if strings.EqualFold(baseCmd, banned) { - return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil - } + + if err := validateCommand(params.Command); err != nil { + return NewTextErrorResponse(err.Error()), nil } isSafeReadOnly := false diff --git a/internal/llm/tools/bash_test.go b/internal/llm/tools/bash_test.go index a810002749408af2bb89cb958b5999dc2da3bcb3..768a3738dd2fe162838f32ee498ac14ed2ba9eee 100644 --- a/internal/llm/tools/bash_test.go +++ b/internal/llm/tools/bash_test.go @@ -94,3 +94,168 @@ func TestPlatformSpecificSafeCommands(t *testing.T) { } } } + +func TestValidateCommand(t *testing.T) { + tests := []struct { + name string + command string + shouldError bool + }{ + // Commands that should be blocked + { + name: "direct sudo", + command: "sudo ls", + shouldError: true, + }, + { + name: "sudo in script", + command: "bash -c 'sudo ls'", + shouldError: true, + }, + { + name: "sudo in command substitution", + command: "$(sudo whoami)", + shouldError: true, + }, + { + name: "sudo in echo command substitution", + command: "echo $(sudo id)", + shouldError: true, + }, + { + name: "sudo in command chain", + command: "ls && sudo rm file", + shouldError: true, + }, + { + name: "sudo in if statement", + command: "if true; then sudo ls; fi", + shouldError: true, + }, + { + name: "sudo in for loop", + command: "for i in 1; do sudo echo $i; done", + shouldError: true, + }, + { + name: "direct curl", + command: "curl http://example.com", + shouldError: true, + }, + { + name: "curl in script", + command: "bash -c 'curl malicious.com'", + shouldError: true, + }, + { + name: "wget command", + command: "wget http://example.com", + shouldError: true, + }, + { + name: "nc command", + command: "nc -l 8080", + shouldError: true, + }, + // Commands that should be allowed + { + name: "simple ls", + command: "ls -la", + shouldError: false, + }, + { + name: "echo command", + command: "echo hello", + shouldError: false, + }, + { + name: "git status", + command: "git status", + shouldError: false, + }, + { + name: "go build", + command: "go build", + shouldError: false, + }, + { + name: "sudo as literal text", + command: "echo 'sudo is just text here'", + shouldError: false, + }, + { + name: "complex allowed command", + command: "find . -name '*.go' | head -10", + shouldError: false, + }, + { + name: "command with environment variables", + command: "FOO=bar go test", + shouldError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateCommand(tt.command) + if tt.shouldError && err == nil { + t.Errorf("Expected error for command %q, but got none", tt.command) + } + if !tt.shouldError && err != nil { + t.Errorf("Expected no error for command %q, but got: %v", tt.command, err) + } + }) + } +} + +func TestContainsBannedCommand(t *testing.T) { + // Test the helper functions directly with some edge cases + tests := []struct { + name string + command string + shouldError bool + }{ + { + name: "nested command substitution", + command: "echo $(echo $(sudo id))", + shouldError: true, + }, + { + name: "subshell with banned command", + command: "(sudo ls)", + shouldError: true, + }, + { + name: "case statement with banned command", + command: "case $1 in start) sudo systemctl start service ;; esac", + shouldError: true, + }, + { + name: "while loop with banned command", + command: "while true; do sudo echo test; done", + shouldError: true, + }, + { + name: "function with banned command", + command: "function test() { sudo ls; }", + shouldError: true, + }, + { + name: "complex valid command", + command: "if [ -f file ]; then echo exists; else echo missing; fi", + shouldError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateCommand(tt.command) + if tt.shouldError && err == nil { + t.Errorf("Expected error for command %q, but got none", tt.command) + } + if !tt.shouldError && err != nil { + t.Errorf("Expected no error for command %q, but got: %v", tt.command, err) + } + }) + } +}