@@ -10,6 +10,7 @@ import (
"github.com/charmbracelet/crush/internal/permission"
"github.com/charmbracelet/crush/internal/shell"
+ "mvdan.cc/sh/v3/syntax"
)
type BashParams struct {
@@ -40,10 +41,235 @@ const (
BashNoOutput = "no output"
)
+func containsBannedCommand(node syntax.Node) bool {
+ if node == nil {
+ return false
+ }
+
+ switch n := node.(type) {
+ case *syntax.CallExpr:
+ if len(n.Args) > 0 {
+ cmdName := getWordValue(n.Args[0])
+ for _, banned := range bannedCommands {
+ if strings.EqualFold(cmdName, banned) {
+ return true
+ }
+ }
+ }
+ for _, arg := range n.Args {
+ if containsBannedCommand(arg) {
+ return true
+ }
+ }
+ case *syntax.Word:
+ if checkWordForBannedCommands(n) {
+ return true
+ }
+ for _, part := range n.Parts {
+ if containsBannedCommand(part) {
+ return true
+ }
+ }
+ case *syntax.CmdSubst:
+ for _, stmt := range n.Stmts {
+ if containsBannedCommand(stmt) {
+ return true
+ }
+ }
+ case *syntax.Subshell:
+ for _, stmt := range n.Stmts {
+ if containsBannedCommand(stmt) {
+ return true
+ }
+ }
+ case *syntax.Stmt:
+ if containsBannedCommand(n.Cmd) {
+ return true
+ }
+ for _, redir := range n.Redirs {
+ if containsBannedCommand(redir) {
+ return true
+ }
+ }
+ case *syntax.BinaryCmd:
+ return containsBannedCommand(n.X) || containsBannedCommand(n.Y)
+ case *syntax.Block:
+ for _, stmt := range n.Stmts {
+ if containsBannedCommand(stmt) {
+ return true
+ }
+ }
+ case *syntax.IfClause:
+ for _, stmt := range n.Cond {
+ if containsBannedCommand(stmt) {
+ return true
+ }
+ }
+ for _, stmt := range n.Then {
+ if containsBannedCommand(stmt) {
+ return true
+ }
+ }
+ if n.Else != nil && containsBannedCommand(n.Else) {
+ return true
+ }
+ case *syntax.WhileClause:
+ for _, stmt := range n.Cond {
+ if containsBannedCommand(stmt) {
+ return true
+ }
+ }
+ for _, stmt := range n.Do {
+ if containsBannedCommand(stmt) {
+ return true
+ }
+ }
+ case *syntax.ForClause:
+ for _, stmt := range n.Do {
+ if containsBannedCommand(stmt) {
+ return true
+ }
+ }
+ if containsBannedCommand(n.Loop) {
+ return true
+ }
+ case *syntax.CaseClause:
+ for _, item := range n.Items {
+ for _, stmt := range item.Stmts {
+ if containsBannedCommand(stmt) {
+ return true
+ }
+ }
+ }
+ case *syntax.FuncDecl:
+ return containsBannedCommand(n.Body)
+ case *syntax.ArithmExp:
+ return containsBannedCommand(n.X)
+ case *syntax.Redirect:
+ return containsBannedCommand(n.Word)
+ }
+ return false
+}
+
+func checkWordForBannedCommands(word *syntax.Word) bool {
+ if word == nil {
+ return false
+ }
+
+ for _, part := range word.Parts {
+ switch p := part.(type) {
+ case *syntax.SglQuoted:
+ if checkQuotedStringForBannedCommands(p.Value) {
+ return true
+ }
+ case *syntax.DblQuoted:
+ var content strings.Builder
+ for _, qpart := range p.Parts {
+ if lit, ok := qpart.(*syntax.Lit); ok {
+ content.WriteString(lit.Value)
+ }
+ }
+ if checkQuotedStringForBannedCommands(content.String()) {
+ return true
+ }
+ }
+ }
+ return false
+}
+
+func checkQuotedStringForBannedCommands(content string) bool {
+ parser := syntax.NewParser()
+ file, err := parser.Parse(strings.NewReader(content), "")
+ if err != nil {
+ return false
+ }
+
+ if len(file.Stmts) == 0 {
+ return false
+ }
+
+ // Simple heuristic: if it looks like prose rather than commands, don't flag it
+ if len(file.Stmts) == 1 {
+ stmt := file.Stmts[0]
+ if callExpr, ok := stmt.Cmd.(*syntax.CallExpr); ok {
+ if len(callExpr.Args) > 2 {
+ allText := true
+ for i, arg := range callExpr.Args {
+ if i == 0 {
+ continue
+ }
+ argStr := getWordValue(arg)
+ if strings.HasPrefix(argStr, "-") {
+ allText = false
+ break
+ }
+ }
+ if allText {
+ return false
+ }
+ }
+ }
+ }
+
+ for _, stmt := range file.Stmts {
+ if containsBannedCommand(stmt) {
+ return true
+ }
+ }
+ return false
+}
+
+func getWordValue(word *syntax.Word) string {
+ if word == nil || len(word.Parts) == 0 {
+ return ""
+ }
+
+ var result strings.Builder
+ for _, part := range word.Parts {
+ switch p := part.(type) {
+ case *syntax.Lit:
+ result.WriteString(p.Value)
+ case *syntax.SglQuoted:
+ result.WriteString(p.Value)
+ case *syntax.DblQuoted:
+ for _, qpart := range p.Parts {
+ if lit, ok := qpart.(*syntax.Lit); ok {
+ result.WriteString(lit.Value)
+ }
+ }
+ }
+ }
+ return result.String()
+}
+
+func validateCommand(command string) error {
+ parser := syntax.NewParser()
+ file, err := parser.Parse(strings.NewReader(command), "")
+ if err != nil {
+ parts := strings.Fields(command)
+ if len(parts) > 0 {
+ baseCmd := parts[0]
+ for _, banned := range bannedCommands {
+ if strings.EqualFold(baseCmd, banned) {
+ return fmt.Errorf("command '%s' is not allowed", baseCmd)
+ }
+ }
+ }
+ return nil
+ }
+
+ for _, stmt := range file.Stmts {
+ if containsBannedCommand(stmt) {
+ return fmt.Errorf("command contains banned operations")
+ }
+ }
+ return nil
+}
+
var bannedCommands = []string{
"alias", "curl", "curlie", "wget", "axel", "aria2c",
"nc", "telnet", "lynx", "w3m", "links", "httpie", "xh",
- "http-prompt", "chrome", "firefox", "safari",
+ "http-prompt", "chrome", "firefox", "safari", "sudo",
}
// getSafeReadOnlyCommands returns platform-appropriate safe commands
@@ -289,11 +515,9 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
return NewTextErrorResponse("missing command"), nil
}
- baseCmd := strings.Fields(params.Command)[0]
- for _, banned := range bannedCommands {
- if strings.EqualFold(baseCmd, banned) {
- return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil
- }
+
+ if err := validateCommand(params.Command); err != nil {
+ return NewTextErrorResponse(err.Error()), nil
}
isSafeReadOnly := false
@@ -94,3 +94,168 @@ func TestPlatformSpecificSafeCommands(t *testing.T) {
}
}
}
+
+func TestValidateCommand(t *testing.T) {
+ tests := []struct {
+ name string
+ command string
+ shouldError bool
+ }{
+ // Commands that should be blocked
+ {
+ name: "direct sudo",
+ command: "sudo ls",
+ shouldError: true,
+ },
+ {
+ name: "sudo in script",
+ command: "bash -c 'sudo ls'",
+ shouldError: true,
+ },
+ {
+ name: "sudo in command substitution",
+ command: "$(sudo whoami)",
+ shouldError: true,
+ },
+ {
+ name: "sudo in echo command substitution",
+ command: "echo $(sudo id)",
+ shouldError: true,
+ },
+ {
+ name: "sudo in command chain",
+ command: "ls && sudo rm file",
+ shouldError: true,
+ },
+ {
+ name: "sudo in if statement",
+ command: "if true; then sudo ls; fi",
+ shouldError: true,
+ },
+ {
+ name: "sudo in for loop",
+ command: "for i in 1; do sudo echo $i; done",
+ shouldError: true,
+ },
+ {
+ name: "direct curl",
+ command: "curl http://example.com",
+ shouldError: true,
+ },
+ {
+ name: "curl in script",
+ command: "bash -c 'curl malicious.com'",
+ shouldError: true,
+ },
+ {
+ name: "wget command",
+ command: "wget http://example.com",
+ shouldError: true,
+ },
+ {
+ name: "nc command",
+ command: "nc -l 8080",
+ shouldError: true,
+ },
+ // Commands that should be allowed
+ {
+ name: "simple ls",
+ command: "ls -la",
+ shouldError: false,
+ },
+ {
+ name: "echo command",
+ command: "echo hello",
+ shouldError: false,
+ },
+ {
+ name: "git status",
+ command: "git status",
+ shouldError: false,
+ },
+ {
+ name: "go build",
+ command: "go build",
+ shouldError: false,
+ },
+ {
+ name: "sudo as literal text",
+ command: "echo 'sudo is just text here'",
+ shouldError: false,
+ },
+ {
+ name: "complex allowed command",
+ command: "find . -name '*.go' | head -10",
+ shouldError: false,
+ },
+ {
+ name: "command with environment variables",
+ command: "FOO=bar go test",
+ shouldError: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := validateCommand(tt.command)
+ if tt.shouldError && err == nil {
+ t.Errorf("Expected error for command %q, but got none", tt.command)
+ }
+ if !tt.shouldError && err != nil {
+ t.Errorf("Expected no error for command %q, but got: %v", tt.command, err)
+ }
+ })
+ }
+}
+
+func TestContainsBannedCommand(t *testing.T) {
+ // Test the helper functions directly with some edge cases
+ tests := []struct {
+ name string
+ command string
+ shouldError bool
+ }{
+ {
+ name: "nested command substitution",
+ command: "echo $(echo $(sudo id))",
+ shouldError: true,
+ },
+ {
+ name: "subshell with banned command",
+ command: "(sudo ls)",
+ shouldError: true,
+ },
+ {
+ name: "case statement with banned command",
+ command: "case $1 in start) sudo systemctl start service ;; esac",
+ shouldError: true,
+ },
+ {
+ name: "while loop with banned command",
+ command: "while true; do sudo echo test; done",
+ shouldError: true,
+ },
+ {
+ name: "function with banned command",
+ command: "function test() { sudo ls; }",
+ shouldError: true,
+ },
+ {
+ name: "complex valid command",
+ command: "if [ -f file ]; then echo exists; else echo missing; fi",
+ shouldError: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := validateCommand(tt.command)
+ if tt.shouldError && err == nil {
+ t.Errorf("Expected error for command %q, but got none", tt.command)
+ }
+ if !tt.shouldError && err != nil {
+ t.Errorf("Expected no error for command %q, but got: %v", tt.command, err)
+ }
+ })
+ }
+}