@@ -10,7 +10,6 @@ import (
"github.com/charmbracelet/crush/internal/permission"
"github.com/charmbracelet/crush/internal/shell"
- "mvdan.cc/sh/v3/syntax"
)
type BashParams struct {
@@ -41,235 +40,10 @@ const (
BashNoOutput = "no output"
)
-func containsBannedCommand(node syntax.Node) bool {
- if node == nil {
- return false
- }
-
- switch n := node.(type) {
- case *syntax.CallExpr:
- if len(n.Args) > 0 {
- cmdName := getWordValue(n.Args[0])
- for _, banned := range bannedCommands {
- if strings.EqualFold(cmdName, banned) {
- return true
- }
- }
- }
- for _, arg := range n.Args {
- if containsBannedCommand(arg) {
- return true
- }
- }
- case *syntax.Word:
- if checkWordForBannedCommands(n) {
- return true
- }
- for _, part := range n.Parts {
- if containsBannedCommand(part) {
- return true
- }
- }
- case *syntax.CmdSubst:
- for _, stmt := range n.Stmts {
- if containsBannedCommand(stmt) {
- return true
- }
- }
- case *syntax.Subshell:
- for _, stmt := range n.Stmts {
- if containsBannedCommand(stmt) {
- return true
- }
- }
- case *syntax.Stmt:
- if containsBannedCommand(n.Cmd) {
- return true
- }
- for _, redir := range n.Redirs {
- if containsBannedCommand(redir) {
- return true
- }
- }
- case *syntax.BinaryCmd:
- return containsBannedCommand(n.X) || containsBannedCommand(n.Y)
- case *syntax.Block:
- for _, stmt := range n.Stmts {
- if containsBannedCommand(stmt) {
- return true
- }
- }
- case *syntax.IfClause:
- for _, stmt := range n.Cond {
- if containsBannedCommand(stmt) {
- return true
- }
- }
- for _, stmt := range n.Then {
- if containsBannedCommand(stmt) {
- return true
- }
- }
- if n.Else != nil && containsBannedCommand(n.Else) {
- return true
- }
- case *syntax.WhileClause:
- for _, stmt := range n.Cond {
- if containsBannedCommand(stmt) {
- return true
- }
- }
- for _, stmt := range n.Do {
- if containsBannedCommand(stmt) {
- return true
- }
- }
- case *syntax.ForClause:
- for _, stmt := range n.Do {
- if containsBannedCommand(stmt) {
- return true
- }
- }
- if containsBannedCommand(n.Loop) {
- return true
- }
- case *syntax.CaseClause:
- for _, item := range n.Items {
- for _, stmt := range item.Stmts {
- if containsBannedCommand(stmt) {
- return true
- }
- }
- }
- case *syntax.FuncDecl:
- return containsBannedCommand(n.Body)
- case *syntax.ArithmExp:
- return containsBannedCommand(n.X)
- case *syntax.Redirect:
- return containsBannedCommand(n.Word)
- }
- return false
-}
-
-func checkWordForBannedCommands(word *syntax.Word) bool {
- if word == nil {
- return false
- }
-
- for _, part := range word.Parts {
- switch p := part.(type) {
- case *syntax.SglQuoted:
- if checkQuotedStringForBannedCommands(p.Value) {
- return true
- }
- case *syntax.DblQuoted:
- var content strings.Builder
- for _, qpart := range p.Parts {
- if lit, ok := qpart.(*syntax.Lit); ok {
- content.WriteString(lit.Value)
- }
- }
- if checkQuotedStringForBannedCommands(content.String()) {
- return true
- }
- }
- }
- return false
-}
-
-func checkQuotedStringForBannedCommands(content string) bool {
- parser := syntax.NewParser()
- file, err := parser.Parse(strings.NewReader(content), "")
- if err != nil {
- return false
- }
-
- if len(file.Stmts) == 0 {
- return false
- }
-
- // Simple heuristic: if it looks like prose rather than commands, don't flag it
- if len(file.Stmts) == 1 {
- stmt := file.Stmts[0]
- if callExpr, ok := stmt.Cmd.(*syntax.CallExpr); ok {
- if len(callExpr.Args) > 2 {
- allText := true
- for i, arg := range callExpr.Args {
- if i == 0 {
- continue
- }
- argStr := getWordValue(arg)
- if strings.HasPrefix(argStr, "-") {
- allText = false
- break
- }
- }
- if allText {
- return false
- }
- }
- }
- }
-
- for _, stmt := range file.Stmts {
- if containsBannedCommand(stmt) {
- return true
- }
- }
- return false
-}
-
-func getWordValue(word *syntax.Word) string {
- if word == nil || len(word.Parts) == 0 {
- return ""
- }
-
- var result strings.Builder
- for _, part := range word.Parts {
- switch p := part.(type) {
- case *syntax.Lit:
- result.WriteString(p.Value)
- case *syntax.SglQuoted:
- result.WriteString(p.Value)
- case *syntax.DblQuoted:
- for _, qpart := range p.Parts {
- if lit, ok := qpart.(*syntax.Lit); ok {
- result.WriteString(lit.Value)
- }
- }
- }
- }
- return result.String()
-}
-
-func validateCommand(command string) error {
- parser := syntax.NewParser()
- file, err := parser.Parse(strings.NewReader(command), "")
- if err != nil {
- parts := strings.Fields(command)
- if len(parts) > 0 {
- baseCmd := parts[0]
- for _, banned := range bannedCommands {
- if strings.EqualFold(baseCmd, banned) {
- return fmt.Errorf("command '%s' is not allowed", baseCmd)
- }
- }
- }
- return nil
- }
-
- for _, stmt := range file.Stmts {
- if containsBannedCommand(stmt) {
- return fmt.Errorf("command contains banned operations")
- }
- }
- return nil
-}
-
var bannedCommands = []string{
"alias", "curl", "curlie", "wget", "axel", "aria2c",
"nc", "telnet", "lynx", "w3m", "links", "httpie", "xh",
- "http-prompt", "chrome", "firefox", "safari", "sudo",
+ "http-prompt", "chrome", "firefox", "safari",
}
// getSafeReadOnlyCommands returns platform-appropriate safe commands
@@ -515,9 +289,11 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
return NewTextErrorResponse("missing command"), nil
}
-
- if err := validateCommand(params.Command); err != nil {
- return NewTextErrorResponse(err.Error()), nil
+ baseCmd := strings.Fields(params.Command)[0]
+ for _, banned := range bannedCommands {
+ if strings.EqualFold(baseCmd, banned) {
+ return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil
+ }
}
isSafeReadOnly := false
@@ -94,168 +94,3 @@ func TestPlatformSpecificSafeCommands(t *testing.T) {
}
}
}
-
-func TestValidateCommand(t *testing.T) {
- tests := []struct {
- name string
- command string
- shouldError bool
- }{
- // Commands that should be blocked
- {
- name: "direct sudo",
- command: "sudo ls",
- shouldError: true,
- },
- {
- name: "sudo in script",
- command: "bash -c 'sudo ls'",
- shouldError: true,
- },
- {
- name: "sudo in command substitution",
- command: "$(sudo whoami)",
- shouldError: true,
- },
- {
- name: "sudo in echo command substitution",
- command: "echo $(sudo id)",
- shouldError: true,
- },
- {
- name: "sudo in command chain",
- command: "ls && sudo rm file",
- shouldError: true,
- },
- {
- name: "sudo in if statement",
- command: "if true; then sudo ls; fi",
- shouldError: true,
- },
- {
- name: "sudo in for loop",
- command: "for i in 1; do sudo echo $i; done",
- shouldError: true,
- },
- {
- name: "direct curl",
- command: "curl http://example.com",
- shouldError: true,
- },
- {
- name: "curl in script",
- command: "bash -c 'curl malicious.com'",
- shouldError: true,
- },
- {
- name: "wget command",
- command: "wget http://example.com",
- shouldError: true,
- },
- {
- name: "nc command",
- command: "nc -l 8080",
- shouldError: true,
- },
- // Commands that should be allowed
- {
- name: "simple ls",
- command: "ls -la",
- shouldError: false,
- },
- {
- name: "echo command",
- command: "echo hello",
- shouldError: false,
- },
- {
- name: "git status",
- command: "git status",
- shouldError: false,
- },
- {
- name: "go build",
- command: "go build",
- shouldError: false,
- },
- {
- name: "sudo as literal text",
- command: "echo 'sudo is just text here'",
- shouldError: false,
- },
- {
- name: "complex allowed command",
- command: "find . -name '*.go' | head -10",
- shouldError: false,
- },
- {
- name: "command with environment variables",
- command: "FOO=bar go test",
- shouldError: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- err := validateCommand(tt.command)
- if tt.shouldError && err == nil {
- t.Errorf("Expected error for command %q, but got none", tt.command)
- }
- if !tt.shouldError && err != nil {
- t.Errorf("Expected no error for command %q, but got: %v", tt.command, err)
- }
- })
- }
-}
-
-func TestContainsBannedCommand(t *testing.T) {
- // Test the helper functions directly with some edge cases
- tests := []struct {
- name string
- command string
- shouldError bool
- }{
- {
- name: "nested command substitution",
- command: "echo $(echo $(sudo id))",
- shouldError: true,
- },
- {
- name: "subshell with banned command",
- command: "(sudo ls)",
- shouldError: true,
- },
- {
- name: "case statement with banned command",
- command: "case $1 in start) sudo systemctl start service ;; esac",
- shouldError: true,
- },
- {
- name: "while loop with banned command",
- command: "while true; do sudo echo test; done",
- shouldError: true,
- },
- {
- name: "function with banned command",
- command: "function test() { sudo ls; }",
- shouldError: true,
- },
- {
- name: "complex valid command",
- command: "if [ -f file ]; then echo exists; else echo missing; fi",
- shouldError: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- err := validateCommand(tt.command)
- if tt.shouldError && err == nil {
- t.Errorf("Expected error for command %q, but got none", tt.command)
- }
- if !tt.shouldError && err != nil {
- t.Errorf("Expected no error for command %q, but got: %v", tt.command, err)
- }
- })
- }
-}