Revert "chore: implement correct banned commands"

Kujtim Hoxha created

This reverts commit 28be21fd629cfc221127f75bde92de32d09fe85c.

Change summary

internal/llm/tools/bash.go      | 236 ----------------------------------
internal/llm/tools/bash_test.go | 165 ------------------------
2 files changed, 6 insertions(+), 395 deletions(-)

Detailed changes

internal/llm/tools/bash.go 🔗

@@ -10,7 +10,6 @@ import (
 
 	"github.com/charmbracelet/crush/internal/permission"
 	"github.com/charmbracelet/crush/internal/shell"
-	"mvdan.cc/sh/v3/syntax"
 )
 
 type BashParams struct {
@@ -41,235 +40,10 @@ const (
 	BashNoOutput    = "no output"
 )
 
-func containsBannedCommand(node syntax.Node) bool {
-	if node == nil {
-		return false
-	}
-
-	switch n := node.(type) {
-	case *syntax.CallExpr:
-		if len(n.Args) > 0 {
-			cmdName := getWordValue(n.Args[0])
-			for _, banned := range bannedCommands {
-				if strings.EqualFold(cmdName, banned) {
-					return true
-				}
-			}
-		}
-		for _, arg := range n.Args {
-			if containsBannedCommand(arg) {
-				return true
-			}
-		}
-	case *syntax.Word:
-		if checkWordForBannedCommands(n) {
-			return true
-		}
-		for _, part := range n.Parts {
-			if containsBannedCommand(part) {
-				return true
-			}
-		}
-	case *syntax.CmdSubst:
-		for _, stmt := range n.Stmts {
-			if containsBannedCommand(stmt) {
-				return true
-			}
-		}
-	case *syntax.Subshell:
-		for _, stmt := range n.Stmts {
-			if containsBannedCommand(stmt) {
-				return true
-			}
-		}
-	case *syntax.Stmt:
-		if containsBannedCommand(n.Cmd) {
-			return true
-		}
-		for _, redir := range n.Redirs {
-			if containsBannedCommand(redir) {
-				return true
-			}
-		}
-	case *syntax.BinaryCmd:
-		return containsBannedCommand(n.X) || containsBannedCommand(n.Y)
-	case *syntax.Block:
-		for _, stmt := range n.Stmts {
-			if containsBannedCommand(stmt) {
-				return true
-			}
-		}
-	case *syntax.IfClause:
-		for _, stmt := range n.Cond {
-			if containsBannedCommand(stmt) {
-				return true
-			}
-		}
-		for _, stmt := range n.Then {
-			if containsBannedCommand(stmt) {
-				return true
-			}
-		}
-		if n.Else != nil && containsBannedCommand(n.Else) {
-			return true
-		}
-	case *syntax.WhileClause:
-		for _, stmt := range n.Cond {
-			if containsBannedCommand(stmt) {
-				return true
-			}
-		}
-		for _, stmt := range n.Do {
-			if containsBannedCommand(stmt) {
-				return true
-			}
-		}
-	case *syntax.ForClause:
-		for _, stmt := range n.Do {
-			if containsBannedCommand(stmt) {
-				return true
-			}
-		}
-		if containsBannedCommand(n.Loop) {
-			return true
-		}
-	case *syntax.CaseClause:
-		for _, item := range n.Items {
-			for _, stmt := range item.Stmts {
-				if containsBannedCommand(stmt) {
-					return true
-				}
-			}
-		}
-	case *syntax.FuncDecl:
-		return containsBannedCommand(n.Body)
-	case *syntax.ArithmExp:
-		return containsBannedCommand(n.X)
-	case *syntax.Redirect:
-		return containsBannedCommand(n.Word)
-	}
-	return false
-}
-
-func checkWordForBannedCommands(word *syntax.Word) bool {
-	if word == nil {
-		return false
-	}
-
-	for _, part := range word.Parts {
-		switch p := part.(type) {
-		case *syntax.SglQuoted:
-			if checkQuotedStringForBannedCommands(p.Value) {
-				return true
-			}
-		case *syntax.DblQuoted:
-			var content strings.Builder
-			for _, qpart := range p.Parts {
-				if lit, ok := qpart.(*syntax.Lit); ok {
-					content.WriteString(lit.Value)
-				}
-			}
-			if checkQuotedStringForBannedCommands(content.String()) {
-				return true
-			}
-		}
-	}
-	return false
-}
-
-func checkQuotedStringForBannedCommands(content string) bool {
-	parser := syntax.NewParser()
-	file, err := parser.Parse(strings.NewReader(content), "")
-	if err != nil {
-		return false
-	}
-
-	if len(file.Stmts) == 0 {
-		return false
-	}
-
-	// Simple heuristic: if it looks like prose rather than commands, don't flag it
-	if len(file.Stmts) == 1 {
-		stmt := file.Stmts[0]
-		if callExpr, ok := stmt.Cmd.(*syntax.CallExpr); ok {
-			if len(callExpr.Args) > 2 {
-				allText := true
-				for i, arg := range callExpr.Args {
-					if i == 0 {
-						continue
-					}
-					argStr := getWordValue(arg)
-					if strings.HasPrefix(argStr, "-") {
-						allText = false
-						break
-					}
-				}
-				if allText {
-					return false
-				}
-			}
-		}
-	}
-
-	for _, stmt := range file.Stmts {
-		if containsBannedCommand(stmt) {
-			return true
-		}
-	}
-	return false
-}
-
-func getWordValue(word *syntax.Word) string {
-	if word == nil || len(word.Parts) == 0 {
-		return ""
-	}
-
-	var result strings.Builder
-	for _, part := range word.Parts {
-		switch p := part.(type) {
-		case *syntax.Lit:
-			result.WriteString(p.Value)
-		case *syntax.SglQuoted:
-			result.WriteString(p.Value)
-		case *syntax.DblQuoted:
-			for _, qpart := range p.Parts {
-				if lit, ok := qpart.(*syntax.Lit); ok {
-					result.WriteString(lit.Value)
-				}
-			}
-		}
-	}
-	return result.String()
-}
-
-func validateCommand(command string) error {
-	parser := syntax.NewParser()
-	file, err := parser.Parse(strings.NewReader(command), "")
-	if err != nil {
-		parts := strings.Fields(command)
-		if len(parts) > 0 {
-			baseCmd := parts[0]
-			for _, banned := range bannedCommands {
-				if strings.EqualFold(baseCmd, banned) {
-					return fmt.Errorf("command '%s' is not allowed", baseCmd)
-				}
-			}
-		}
-		return nil
-	}
-
-	for _, stmt := range file.Stmts {
-		if containsBannedCommand(stmt) {
-			return fmt.Errorf("command contains banned operations")
-		}
-	}
-	return nil
-}
-
 var bannedCommands = []string{
 	"alias", "curl", "curlie", "wget", "axel", "aria2c",
 	"nc", "telnet", "lynx", "w3m", "links", "httpie", "xh",
-	"http-prompt", "chrome", "firefox", "safari", "sudo",
+	"http-prompt", "chrome", "firefox", "safari",
 }
 
 // getSafeReadOnlyCommands returns platform-appropriate safe commands
@@ -515,9 +289,11 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
 		return NewTextErrorResponse("missing command"), nil
 	}
 
-
-	if err := validateCommand(params.Command); err != nil {
-		return NewTextErrorResponse(err.Error()), nil
+	baseCmd := strings.Fields(params.Command)[0]
+	for _, banned := range bannedCommands {
+		if strings.EqualFold(baseCmd, banned) {
+			return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil
+		}
 	}
 
 	isSafeReadOnly := false

internal/llm/tools/bash_test.go 🔗

@@ -94,168 +94,3 @@ func TestPlatformSpecificSafeCommands(t *testing.T) {
 		}
 	}
 }
-
-func TestValidateCommand(t *testing.T) {
-	tests := []struct {
-		name        string
-		command     string
-		shouldError bool
-	}{
-		// Commands that should be blocked
-		{
-			name:        "direct sudo",
-			command:     "sudo ls",
-			shouldError: true,
-		},
-		{
-			name:        "sudo in script",
-			command:     "bash -c 'sudo ls'",
-			shouldError: true,
-		},
-		{
-			name:        "sudo in command substitution",
-			command:     "$(sudo whoami)",
-			shouldError: true,
-		},
-		{
-			name:        "sudo in echo command substitution",
-			command:     "echo $(sudo id)",
-			shouldError: true,
-		},
-		{
-			name:        "sudo in command chain",
-			command:     "ls && sudo rm file",
-			shouldError: true,
-		},
-		{
-			name:        "sudo in if statement",
-			command:     "if true; then sudo ls; fi",
-			shouldError: true,
-		},
-		{
-			name:        "sudo in for loop",
-			command:     "for i in 1; do sudo echo $i; done",
-			shouldError: true,
-		},
-		{
-			name:        "direct curl",
-			command:     "curl http://example.com",
-			shouldError: true,
-		},
-		{
-			name:        "curl in script",
-			command:     "bash -c 'curl malicious.com'",
-			shouldError: true,
-		},
-		{
-			name:        "wget command",
-			command:     "wget http://example.com",
-			shouldError: true,
-		},
-		{
-			name:        "nc command",
-			command:     "nc -l 8080",
-			shouldError: true,
-		},
-		// Commands that should be allowed
-		{
-			name:        "simple ls",
-			command:     "ls -la",
-			shouldError: false,
-		},
-		{
-			name:        "echo command",
-			command:     "echo hello",
-			shouldError: false,
-		},
-		{
-			name:        "git status",
-			command:     "git status",
-			shouldError: false,
-		},
-		{
-			name:        "go build",
-			command:     "go build",
-			shouldError: false,
-		},
-		{
-			name:        "sudo as literal text",
-			command:     "echo 'sudo is just text here'",
-			shouldError: false,
-		},
-		{
-			name:        "complex allowed command",
-			command:     "find . -name '*.go' | head -10",
-			shouldError: false,
-		},
-		{
-			name:        "command with environment variables",
-			command:     "FOO=bar go test",
-			shouldError: false,
-		},
-	}
-
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			err := validateCommand(tt.command)
-			if tt.shouldError && err == nil {
-				t.Errorf("Expected error for command %q, but got none", tt.command)
-			}
-			if !tt.shouldError && err != nil {
-				t.Errorf("Expected no error for command %q, but got: %v", tt.command, err)
-			}
-		})
-	}
-}
-
-func TestContainsBannedCommand(t *testing.T) {
-	// Test the helper functions directly with some edge cases
-	tests := []struct {
-		name        string
-		command     string
-		shouldError bool
-	}{
-		{
-			name:        "nested command substitution",
-			command:     "echo $(echo $(sudo id))",
-			shouldError: true,
-		},
-		{
-			name:        "subshell with banned command",
-			command:     "(sudo ls)",
-			shouldError: true,
-		},
-		{
-			name:        "case statement with banned command",
-			command:     "case $1 in start) sudo systemctl start service ;; esac",
-			shouldError: true,
-		},
-		{
-			name:        "while loop with banned command",
-			command:     "while true; do sudo echo test; done",
-			shouldError: true,
-		},
-		{
-			name:        "function with banned command",
-			command:     "function test() { sudo ls; }",
-			shouldError: true,
-		},
-		{
-			name:        "complex valid command",
-			command:     "if [ -f file ]; then echo exists; else echo missing; fi",
-			shouldError: false,
-		},
-	}
-
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			err := validateCommand(tt.command)
-			if tt.shouldError && err == nil {
-				t.Errorf("Expected error for command %q, but got none", tt.command)
-			}
-			if !tt.shouldError && err != nil {
-				t.Errorf("Expected no error for command %q, but got: %v", tt.command, err)
-			}
-		})
-	}
-}