chore: implement correct banned commands

Kujtim Hoxha created

Change summary

internal/llm/tools/bash.go      | 236 ++++++++++++++++++++++++++++++++++
internal/llm/tools/bash_test.go | 165 ++++++++++++++++++++++++
2 files changed, 395 insertions(+), 6 deletions(-)

Detailed changes

internal/llm/tools/bash.go 🔗

@@ -10,6 +10,7 @@ import (
 
 	"github.com/charmbracelet/crush/internal/permission"
 	"github.com/charmbracelet/crush/internal/shell"
+	"mvdan.cc/sh/v3/syntax"
 )
 
 type BashParams struct {
@@ -40,10 +41,235 @@ const (
 	BashNoOutput    = "no output"
 )
 
+func containsBannedCommand(node syntax.Node) bool {
+	if node == nil {
+		return false
+	}
+
+	switch n := node.(type) {
+	case *syntax.CallExpr:
+		if len(n.Args) > 0 {
+			cmdName := getWordValue(n.Args[0])
+			for _, banned := range bannedCommands {
+				if strings.EqualFold(cmdName, banned) {
+					return true
+				}
+			}
+		}
+		for _, arg := range n.Args {
+			if containsBannedCommand(arg) {
+				return true
+			}
+		}
+	case *syntax.Word:
+		if checkWordForBannedCommands(n) {
+			return true
+		}
+		for _, part := range n.Parts {
+			if containsBannedCommand(part) {
+				return true
+			}
+		}
+	case *syntax.CmdSubst:
+		for _, stmt := range n.Stmts {
+			if containsBannedCommand(stmt) {
+				return true
+			}
+		}
+	case *syntax.Subshell:
+		for _, stmt := range n.Stmts {
+			if containsBannedCommand(stmt) {
+				return true
+			}
+		}
+	case *syntax.Stmt:
+		if containsBannedCommand(n.Cmd) {
+			return true
+		}
+		for _, redir := range n.Redirs {
+			if containsBannedCommand(redir) {
+				return true
+			}
+		}
+	case *syntax.BinaryCmd:
+		return containsBannedCommand(n.X) || containsBannedCommand(n.Y)
+	case *syntax.Block:
+		for _, stmt := range n.Stmts {
+			if containsBannedCommand(stmt) {
+				return true
+			}
+		}
+	case *syntax.IfClause:
+		for _, stmt := range n.Cond {
+			if containsBannedCommand(stmt) {
+				return true
+			}
+		}
+		for _, stmt := range n.Then {
+			if containsBannedCommand(stmt) {
+				return true
+			}
+		}
+		if n.Else != nil && containsBannedCommand(n.Else) {
+			return true
+		}
+	case *syntax.WhileClause:
+		for _, stmt := range n.Cond {
+			if containsBannedCommand(stmt) {
+				return true
+			}
+		}
+		for _, stmt := range n.Do {
+			if containsBannedCommand(stmt) {
+				return true
+			}
+		}
+	case *syntax.ForClause:
+		for _, stmt := range n.Do {
+			if containsBannedCommand(stmt) {
+				return true
+			}
+		}
+		if containsBannedCommand(n.Loop) {
+			return true
+		}
+	case *syntax.CaseClause:
+		for _, item := range n.Items {
+			for _, stmt := range item.Stmts {
+				if containsBannedCommand(stmt) {
+					return true
+				}
+			}
+		}
+	case *syntax.FuncDecl:
+		return containsBannedCommand(n.Body)
+	case *syntax.ArithmExp:
+		return containsBannedCommand(n.X)
+	case *syntax.Redirect:
+		return containsBannedCommand(n.Word)
+	}
+	return false
+}
+
+func checkWordForBannedCommands(word *syntax.Word) bool {
+	if word == nil {
+		return false
+	}
+
+	for _, part := range word.Parts {
+		switch p := part.(type) {
+		case *syntax.SglQuoted:
+			if checkQuotedStringForBannedCommands(p.Value) {
+				return true
+			}
+		case *syntax.DblQuoted:
+			var content strings.Builder
+			for _, qpart := range p.Parts {
+				if lit, ok := qpart.(*syntax.Lit); ok {
+					content.WriteString(lit.Value)
+				}
+			}
+			if checkQuotedStringForBannedCommands(content.String()) {
+				return true
+			}
+		}
+	}
+	return false
+}
+
+func checkQuotedStringForBannedCommands(content string) bool {
+	parser := syntax.NewParser()
+	file, err := parser.Parse(strings.NewReader(content), "")
+	if err != nil {
+		return false
+	}
+
+	if len(file.Stmts) == 0 {
+		return false
+	}
+
+	// Simple heuristic: if it looks like prose rather than commands, don't flag it
+	if len(file.Stmts) == 1 {
+		stmt := file.Stmts[0]
+		if callExpr, ok := stmt.Cmd.(*syntax.CallExpr); ok {
+			if len(callExpr.Args) > 2 {
+				allText := true
+				for i, arg := range callExpr.Args {
+					if i == 0 {
+						continue
+					}
+					argStr := getWordValue(arg)
+					if strings.HasPrefix(argStr, "-") {
+						allText = false
+						break
+					}
+				}
+				if allText {
+					return false
+				}
+			}
+		}
+	}
+
+	for _, stmt := range file.Stmts {
+		if containsBannedCommand(stmt) {
+			return true
+		}
+	}
+	return false
+}
+
+func getWordValue(word *syntax.Word) string {
+	if word == nil || len(word.Parts) == 0 {
+		return ""
+	}
+
+	var result strings.Builder
+	for _, part := range word.Parts {
+		switch p := part.(type) {
+		case *syntax.Lit:
+			result.WriteString(p.Value)
+		case *syntax.SglQuoted:
+			result.WriteString(p.Value)
+		case *syntax.DblQuoted:
+			for _, qpart := range p.Parts {
+				if lit, ok := qpart.(*syntax.Lit); ok {
+					result.WriteString(lit.Value)
+				}
+			}
+		}
+	}
+	return result.String()
+}
+
+func validateCommand(command string) error {
+	parser := syntax.NewParser()
+	file, err := parser.Parse(strings.NewReader(command), "")
+	if err != nil {
+		parts := strings.Fields(command)
+		if len(parts) > 0 {
+			baseCmd := parts[0]
+			for _, banned := range bannedCommands {
+				if strings.EqualFold(baseCmd, banned) {
+					return fmt.Errorf("command '%s' is not allowed", baseCmd)
+				}
+			}
+		}
+		return nil
+	}
+
+	for _, stmt := range file.Stmts {
+		if containsBannedCommand(stmt) {
+			return fmt.Errorf("command contains banned operations")
+		}
+	}
+	return nil
+}
+
 var bannedCommands = []string{
 	"alias", "curl", "curlie", "wget", "axel", "aria2c",
 	"nc", "telnet", "lynx", "w3m", "links", "httpie", "xh",
-	"http-prompt", "chrome", "firefox", "safari",
+	"http-prompt", "chrome", "firefox", "safari", "sudo",
 }
 
 // getSafeReadOnlyCommands returns platform-appropriate safe commands
@@ -289,11 +515,9 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
 		return NewTextErrorResponse("missing command"), nil
 	}
 
-	baseCmd := strings.Fields(params.Command)[0]
-	for _, banned := range bannedCommands {
-		if strings.EqualFold(baseCmd, banned) {
-			return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil
-		}
+
+	if err := validateCommand(params.Command); err != nil {
+		return NewTextErrorResponse(err.Error()), nil
 	}
 
 	isSafeReadOnly := false

internal/llm/tools/bash_test.go 🔗

@@ -94,3 +94,168 @@ func TestPlatformSpecificSafeCommands(t *testing.T) {
 		}
 	}
 }
+
+func TestValidateCommand(t *testing.T) {
+	tests := []struct {
+		name        string
+		command     string
+		shouldError bool
+	}{
+		// Commands that should be blocked
+		{
+			name:        "direct sudo",
+			command:     "sudo ls",
+			shouldError: true,
+		},
+		{
+			name:        "sudo in script",
+			command:     "bash -c 'sudo ls'",
+			shouldError: true,
+		},
+		{
+			name:        "sudo in command substitution",
+			command:     "$(sudo whoami)",
+			shouldError: true,
+		},
+		{
+			name:        "sudo in echo command substitution",
+			command:     "echo $(sudo id)",
+			shouldError: true,
+		},
+		{
+			name:        "sudo in command chain",
+			command:     "ls && sudo rm file",
+			shouldError: true,
+		},
+		{
+			name:        "sudo in if statement",
+			command:     "if true; then sudo ls; fi",
+			shouldError: true,
+		},
+		{
+			name:        "sudo in for loop",
+			command:     "for i in 1; do sudo echo $i; done",
+			shouldError: true,
+		},
+		{
+			name:        "direct curl",
+			command:     "curl http://example.com",
+			shouldError: true,
+		},
+		{
+			name:        "curl in script",
+			command:     "bash -c 'curl malicious.com'",
+			shouldError: true,
+		},
+		{
+			name:        "wget command",
+			command:     "wget http://example.com",
+			shouldError: true,
+		},
+		{
+			name:        "nc command",
+			command:     "nc -l 8080",
+			shouldError: true,
+		},
+		// Commands that should be allowed
+		{
+			name:        "simple ls",
+			command:     "ls -la",
+			shouldError: false,
+		},
+		{
+			name:        "echo command",
+			command:     "echo hello",
+			shouldError: false,
+		},
+		{
+			name:        "git status",
+			command:     "git status",
+			shouldError: false,
+		},
+		{
+			name:        "go build",
+			command:     "go build",
+			shouldError: false,
+		},
+		{
+			name:        "sudo as literal text",
+			command:     "echo 'sudo is just text here'",
+			shouldError: false,
+		},
+		{
+			name:        "complex allowed command",
+			command:     "find . -name '*.go' | head -10",
+			shouldError: false,
+		},
+		{
+			name:        "command with environment variables",
+			command:     "FOO=bar go test",
+			shouldError: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := validateCommand(tt.command)
+			if tt.shouldError && err == nil {
+				t.Errorf("Expected error for command %q, but got none", tt.command)
+			}
+			if !tt.shouldError && err != nil {
+				t.Errorf("Expected no error for command %q, but got: %v", tt.command, err)
+			}
+		})
+	}
+}
+
+func TestContainsBannedCommand(t *testing.T) {
+	// Test the helper functions directly with some edge cases
+	tests := []struct {
+		name        string
+		command     string
+		shouldError bool
+	}{
+		{
+			name:        "nested command substitution",
+			command:     "echo $(echo $(sudo id))",
+			shouldError: true,
+		},
+		{
+			name:        "subshell with banned command",
+			command:     "(sudo ls)",
+			shouldError: true,
+		},
+		{
+			name:        "case statement with banned command",
+			command:     "case $1 in start) sudo systemctl start service ;; esac",
+			shouldError: true,
+		},
+		{
+			name:        "while loop with banned command",
+			command:     "while true; do sudo echo test; done",
+			shouldError: true,
+		},
+		{
+			name:        "function with banned command",
+			command:     "function test() { sudo ls; }",
+			shouldError: true,
+		},
+		{
+			name:        "complex valid command",
+			command:     "if [ -f file ]; then echo exists; else echo missing; fi",
+			shouldError: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := validateCommand(tt.command)
+			if tt.shouldError && err == nil {
+				t.Errorf("Expected error for command %q, but got none", tt.command)
+			}
+			if !tt.shouldError && err != nil {
+				t.Errorf("Expected no error for command %q, but got: %v", tt.command, err)
+			}
+		})
+	}
+}