From 28be21fd629cfc221127f75bde92de32d09fe85c Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Thu, 10 Jul 2025 23:42:08 +0200 Subject: [PATCH 1/6] chore: implement correct banned commands --- internal/llm/tools/bash.go | 236 +++++++++++++++++++++++++++++++- internal/llm/tools/bash_test.go | 165 ++++++++++++++++++++++ 2 files changed, 395 insertions(+), 6 deletions(-) diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index 0a10568a39315f6c4077385b8ca83f6b3e52691c..ea736e81bc05a67d1b42bb5927f537b06f4ada5f 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -10,6 +10,7 @@ import ( "github.com/charmbracelet/crush/internal/permission" "github.com/charmbracelet/crush/internal/shell" + "mvdan.cc/sh/v3/syntax" ) type BashParams struct { @@ -40,10 +41,235 @@ const ( BashNoOutput = "no output" ) +func containsBannedCommand(node syntax.Node) bool { + if node == nil { + return false + } + + switch n := node.(type) { + case *syntax.CallExpr: + if len(n.Args) > 0 { + cmdName := getWordValue(n.Args[0]) + for _, banned := range bannedCommands { + if strings.EqualFold(cmdName, banned) { + return true + } + } + } + for _, arg := range n.Args { + if containsBannedCommand(arg) { + return true + } + } + case *syntax.Word: + if checkWordForBannedCommands(n) { + return true + } + for _, part := range n.Parts { + if containsBannedCommand(part) { + return true + } + } + case *syntax.CmdSubst: + for _, stmt := range n.Stmts { + if containsBannedCommand(stmt) { + return true + } + } + case *syntax.Subshell: + for _, stmt := range n.Stmts { + if containsBannedCommand(stmt) { + return true + } + } + case *syntax.Stmt: + if containsBannedCommand(n.Cmd) { + return true + } + for _, redir := range n.Redirs { + if containsBannedCommand(redir) { + return true + } + } + case *syntax.BinaryCmd: + return containsBannedCommand(n.X) || containsBannedCommand(n.Y) + case *syntax.Block: + for _, stmt := range n.Stmts { + if containsBannedCommand(stmt) { + return true + } + } + case *syntax.IfClause: + for _, stmt := range n.Cond { + if containsBannedCommand(stmt) { + return true + } + } + for _, stmt := range n.Then { + if containsBannedCommand(stmt) { + return true + } + } + if n.Else != nil && containsBannedCommand(n.Else) { + return true + } + case *syntax.WhileClause: + for _, stmt := range n.Cond { + if containsBannedCommand(stmt) { + return true + } + } + for _, stmt := range n.Do { + if containsBannedCommand(stmt) { + return true + } + } + case *syntax.ForClause: + for _, stmt := range n.Do { + if containsBannedCommand(stmt) { + return true + } + } + if containsBannedCommand(n.Loop) { + return true + } + case *syntax.CaseClause: + for _, item := range n.Items { + for _, stmt := range item.Stmts { + if containsBannedCommand(stmt) { + return true + } + } + } + case *syntax.FuncDecl: + return containsBannedCommand(n.Body) + case *syntax.ArithmExp: + return containsBannedCommand(n.X) + case *syntax.Redirect: + return containsBannedCommand(n.Word) + } + return false +} + +func checkWordForBannedCommands(word *syntax.Word) bool { + if word == nil { + return false + } + + for _, part := range word.Parts { + switch p := part.(type) { + case *syntax.SglQuoted: + if checkQuotedStringForBannedCommands(p.Value) { + return true + } + case *syntax.DblQuoted: + var content strings.Builder + for _, qpart := range p.Parts { + if lit, ok := qpart.(*syntax.Lit); ok { + content.WriteString(lit.Value) + } + } + if checkQuotedStringForBannedCommands(content.String()) { + return true + } + } + } + return false +} + +func checkQuotedStringForBannedCommands(content string) bool { + parser := syntax.NewParser() + file, err := parser.Parse(strings.NewReader(content), "") + if err != nil { + return false + } + + if len(file.Stmts) == 0 { + return false + } + + // Simple heuristic: if it looks like prose rather than commands, don't flag it + if len(file.Stmts) == 1 { + stmt := file.Stmts[0] + if callExpr, ok := stmt.Cmd.(*syntax.CallExpr); ok { + if len(callExpr.Args) > 2 { + allText := true + for i, arg := range callExpr.Args { + if i == 0 { + continue + } + argStr := getWordValue(arg) + if strings.HasPrefix(argStr, "-") { + allText = false + break + } + } + if allText { + return false + } + } + } + } + + for _, stmt := range file.Stmts { + if containsBannedCommand(stmt) { + return true + } + } + return false +} + +func getWordValue(word *syntax.Word) string { + if word == nil || len(word.Parts) == 0 { + return "" + } + + var result strings.Builder + for _, part := range word.Parts { + switch p := part.(type) { + case *syntax.Lit: + result.WriteString(p.Value) + case *syntax.SglQuoted: + result.WriteString(p.Value) + case *syntax.DblQuoted: + for _, qpart := range p.Parts { + if lit, ok := qpart.(*syntax.Lit); ok { + result.WriteString(lit.Value) + } + } + } + } + return result.String() +} + +func validateCommand(command string) error { + parser := syntax.NewParser() + file, err := parser.Parse(strings.NewReader(command), "") + if err != nil { + parts := strings.Fields(command) + if len(parts) > 0 { + baseCmd := parts[0] + for _, banned := range bannedCommands { + if strings.EqualFold(baseCmd, banned) { + return fmt.Errorf("command '%s' is not allowed", baseCmd) + } + } + } + return nil + } + + for _, stmt := range file.Stmts { + if containsBannedCommand(stmt) { + return fmt.Errorf("command contains banned operations") + } + } + return nil +} + var bannedCommands = []string{ "alias", "curl", "curlie", "wget", "axel", "aria2c", "nc", "telnet", "lynx", "w3m", "links", "httpie", "xh", - "http-prompt", "chrome", "firefox", "safari", + "http-prompt", "chrome", "firefox", "safari", "sudo", } // getSafeReadOnlyCommands returns platform-appropriate safe commands @@ -289,11 +515,9 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return NewTextErrorResponse("missing command"), nil } - baseCmd := strings.Fields(params.Command)[0] - for _, banned := range bannedCommands { - if strings.EqualFold(baseCmd, banned) { - return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil - } + + if err := validateCommand(params.Command); err != nil { + return NewTextErrorResponse(err.Error()), nil } isSafeReadOnly := false diff --git a/internal/llm/tools/bash_test.go b/internal/llm/tools/bash_test.go index a810002749408af2bb89cb958b5999dc2da3bcb3..768a3738dd2fe162838f32ee498ac14ed2ba9eee 100644 --- a/internal/llm/tools/bash_test.go +++ b/internal/llm/tools/bash_test.go @@ -94,3 +94,168 @@ func TestPlatformSpecificSafeCommands(t *testing.T) { } } } + +func TestValidateCommand(t *testing.T) { + tests := []struct { + name string + command string + shouldError bool + }{ + // Commands that should be blocked + { + name: "direct sudo", + command: "sudo ls", + shouldError: true, + }, + { + name: "sudo in script", + command: "bash -c 'sudo ls'", + shouldError: true, + }, + { + name: "sudo in command substitution", + command: "$(sudo whoami)", + shouldError: true, + }, + { + name: "sudo in echo command substitution", + command: "echo $(sudo id)", + shouldError: true, + }, + { + name: "sudo in command chain", + command: "ls && sudo rm file", + shouldError: true, + }, + { + name: "sudo in if statement", + command: "if true; then sudo ls; fi", + shouldError: true, + }, + { + name: "sudo in for loop", + command: "for i in 1; do sudo echo $i; done", + shouldError: true, + }, + { + name: "direct curl", + command: "curl http://example.com", + shouldError: true, + }, + { + name: "curl in script", + command: "bash -c 'curl malicious.com'", + shouldError: true, + }, + { + name: "wget command", + command: "wget http://example.com", + shouldError: true, + }, + { + name: "nc command", + command: "nc -l 8080", + shouldError: true, + }, + // Commands that should be allowed + { + name: "simple ls", + command: "ls -la", + shouldError: false, + }, + { + name: "echo command", + command: "echo hello", + shouldError: false, + }, + { + name: "git status", + command: "git status", + shouldError: false, + }, + { + name: "go build", + command: "go build", + shouldError: false, + }, + { + name: "sudo as literal text", + command: "echo 'sudo is just text here'", + shouldError: false, + }, + { + name: "complex allowed command", + command: "find . -name '*.go' | head -10", + shouldError: false, + }, + { + name: "command with environment variables", + command: "FOO=bar go test", + shouldError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateCommand(tt.command) + if tt.shouldError && err == nil { + t.Errorf("Expected error for command %q, but got none", tt.command) + } + if !tt.shouldError && err != nil { + t.Errorf("Expected no error for command %q, but got: %v", tt.command, err) + } + }) + } +} + +func TestContainsBannedCommand(t *testing.T) { + // Test the helper functions directly with some edge cases + tests := []struct { + name string + command string + shouldError bool + }{ + { + name: "nested command substitution", + command: "echo $(echo $(sudo id))", + shouldError: true, + }, + { + name: "subshell with banned command", + command: "(sudo ls)", + shouldError: true, + }, + { + name: "case statement with banned command", + command: "case $1 in start) sudo systemctl start service ;; esac", + shouldError: true, + }, + { + name: "while loop with banned command", + command: "while true; do sudo echo test; done", + shouldError: true, + }, + { + name: "function with banned command", + command: "function test() { sudo ls; }", + shouldError: true, + }, + { + name: "complex valid command", + command: "if [ -f file ]; then echo exists; else echo missing; fi", + shouldError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateCommand(tt.command) + if tt.shouldError && err == nil { + t.Errorf("Expected error for command %q, but got none", tt.command) + } + if !tt.shouldError && err != nil { + t.Errorf("Expected no error for command %q, but got: %v", tt.command, err) + } + }) + } +} From de3b81532ae3c21551ca84cd3004acc54291739c Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 11 Jul 2025 00:00:44 +0200 Subject: [PATCH 2/6] Revert "chore: implement correct banned commands" This reverts commit 28be21fd629cfc221127f75bde92de32d09fe85c. --- internal/llm/tools/bash.go | 236 +------------------------------- internal/llm/tools/bash_test.go | 165 ---------------------- 2 files changed, 6 insertions(+), 395 deletions(-) diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index ea736e81bc05a67d1b42bb5927f537b06f4ada5f..0a10568a39315f6c4077385b8ca83f6b3e52691c 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -10,7 +10,6 @@ import ( "github.com/charmbracelet/crush/internal/permission" "github.com/charmbracelet/crush/internal/shell" - "mvdan.cc/sh/v3/syntax" ) type BashParams struct { @@ -41,235 +40,10 @@ const ( BashNoOutput = "no output" ) -func containsBannedCommand(node syntax.Node) bool { - if node == nil { - return false - } - - switch n := node.(type) { - case *syntax.CallExpr: - if len(n.Args) > 0 { - cmdName := getWordValue(n.Args[0]) - for _, banned := range bannedCommands { - if strings.EqualFold(cmdName, banned) { - return true - } - } - } - for _, arg := range n.Args { - if containsBannedCommand(arg) { - return true - } - } - case *syntax.Word: - if checkWordForBannedCommands(n) { - return true - } - for _, part := range n.Parts { - if containsBannedCommand(part) { - return true - } - } - case *syntax.CmdSubst: - for _, stmt := range n.Stmts { - if containsBannedCommand(stmt) { - return true - } - } - case *syntax.Subshell: - for _, stmt := range n.Stmts { - if containsBannedCommand(stmt) { - return true - } - } - case *syntax.Stmt: - if containsBannedCommand(n.Cmd) { - return true - } - for _, redir := range n.Redirs { - if containsBannedCommand(redir) { - return true - } - } - case *syntax.BinaryCmd: - return containsBannedCommand(n.X) || containsBannedCommand(n.Y) - case *syntax.Block: - for _, stmt := range n.Stmts { - if containsBannedCommand(stmt) { - return true - } - } - case *syntax.IfClause: - for _, stmt := range n.Cond { - if containsBannedCommand(stmt) { - return true - } - } - for _, stmt := range n.Then { - if containsBannedCommand(stmt) { - return true - } - } - if n.Else != nil && containsBannedCommand(n.Else) { - return true - } - case *syntax.WhileClause: - for _, stmt := range n.Cond { - if containsBannedCommand(stmt) { - return true - } - } - for _, stmt := range n.Do { - if containsBannedCommand(stmt) { - return true - } - } - case *syntax.ForClause: - for _, stmt := range n.Do { - if containsBannedCommand(stmt) { - return true - } - } - if containsBannedCommand(n.Loop) { - return true - } - case *syntax.CaseClause: - for _, item := range n.Items { - for _, stmt := range item.Stmts { - if containsBannedCommand(stmt) { - return true - } - } - } - case *syntax.FuncDecl: - return containsBannedCommand(n.Body) - case *syntax.ArithmExp: - return containsBannedCommand(n.X) - case *syntax.Redirect: - return containsBannedCommand(n.Word) - } - return false -} - -func checkWordForBannedCommands(word *syntax.Word) bool { - if word == nil { - return false - } - - for _, part := range word.Parts { - switch p := part.(type) { - case *syntax.SglQuoted: - if checkQuotedStringForBannedCommands(p.Value) { - return true - } - case *syntax.DblQuoted: - var content strings.Builder - for _, qpart := range p.Parts { - if lit, ok := qpart.(*syntax.Lit); ok { - content.WriteString(lit.Value) - } - } - if checkQuotedStringForBannedCommands(content.String()) { - return true - } - } - } - return false -} - -func checkQuotedStringForBannedCommands(content string) bool { - parser := syntax.NewParser() - file, err := parser.Parse(strings.NewReader(content), "") - if err != nil { - return false - } - - if len(file.Stmts) == 0 { - return false - } - - // Simple heuristic: if it looks like prose rather than commands, don't flag it - if len(file.Stmts) == 1 { - stmt := file.Stmts[0] - if callExpr, ok := stmt.Cmd.(*syntax.CallExpr); ok { - if len(callExpr.Args) > 2 { - allText := true - for i, arg := range callExpr.Args { - if i == 0 { - continue - } - argStr := getWordValue(arg) - if strings.HasPrefix(argStr, "-") { - allText = false - break - } - } - if allText { - return false - } - } - } - } - - for _, stmt := range file.Stmts { - if containsBannedCommand(stmt) { - return true - } - } - return false -} - -func getWordValue(word *syntax.Word) string { - if word == nil || len(word.Parts) == 0 { - return "" - } - - var result strings.Builder - for _, part := range word.Parts { - switch p := part.(type) { - case *syntax.Lit: - result.WriteString(p.Value) - case *syntax.SglQuoted: - result.WriteString(p.Value) - case *syntax.DblQuoted: - for _, qpart := range p.Parts { - if lit, ok := qpart.(*syntax.Lit); ok { - result.WriteString(lit.Value) - } - } - } - } - return result.String() -} - -func validateCommand(command string) error { - parser := syntax.NewParser() - file, err := parser.Parse(strings.NewReader(command), "") - if err != nil { - parts := strings.Fields(command) - if len(parts) > 0 { - baseCmd := parts[0] - for _, banned := range bannedCommands { - if strings.EqualFold(baseCmd, banned) { - return fmt.Errorf("command '%s' is not allowed", baseCmd) - } - } - } - return nil - } - - for _, stmt := range file.Stmts { - if containsBannedCommand(stmt) { - return fmt.Errorf("command contains banned operations") - } - } - return nil -} - var bannedCommands = []string{ "alias", "curl", "curlie", "wget", "axel", "aria2c", "nc", "telnet", "lynx", "w3m", "links", "httpie", "xh", - "http-prompt", "chrome", "firefox", "safari", "sudo", + "http-prompt", "chrome", "firefox", "safari", } // getSafeReadOnlyCommands returns platform-appropriate safe commands @@ -515,9 +289,11 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return NewTextErrorResponse("missing command"), nil } - - if err := validateCommand(params.Command); err != nil { - return NewTextErrorResponse(err.Error()), nil + baseCmd := strings.Fields(params.Command)[0] + for _, banned := range bannedCommands { + if strings.EqualFold(baseCmd, banned) { + return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil + } } isSafeReadOnly := false diff --git a/internal/llm/tools/bash_test.go b/internal/llm/tools/bash_test.go index 768a3738dd2fe162838f32ee498ac14ed2ba9eee..a810002749408af2bb89cb958b5999dc2da3bcb3 100644 --- a/internal/llm/tools/bash_test.go +++ b/internal/llm/tools/bash_test.go @@ -94,168 +94,3 @@ func TestPlatformSpecificSafeCommands(t *testing.T) { } } } - -func TestValidateCommand(t *testing.T) { - tests := []struct { - name string - command string - shouldError bool - }{ - // Commands that should be blocked - { - name: "direct sudo", - command: "sudo ls", - shouldError: true, - }, - { - name: "sudo in script", - command: "bash -c 'sudo ls'", - shouldError: true, - }, - { - name: "sudo in command substitution", - command: "$(sudo whoami)", - shouldError: true, - }, - { - name: "sudo in echo command substitution", - command: "echo $(sudo id)", - shouldError: true, - }, - { - name: "sudo in command chain", - command: "ls && sudo rm file", - shouldError: true, - }, - { - name: "sudo in if statement", - command: "if true; then sudo ls; fi", - shouldError: true, - }, - { - name: "sudo in for loop", - command: "for i in 1; do sudo echo $i; done", - shouldError: true, - }, - { - name: "direct curl", - command: "curl http://example.com", - shouldError: true, - }, - { - name: "curl in script", - command: "bash -c 'curl malicious.com'", - shouldError: true, - }, - { - name: "wget command", - command: "wget http://example.com", - shouldError: true, - }, - { - name: "nc command", - command: "nc -l 8080", - shouldError: true, - }, - // Commands that should be allowed - { - name: "simple ls", - command: "ls -la", - shouldError: false, - }, - { - name: "echo command", - command: "echo hello", - shouldError: false, - }, - { - name: "git status", - command: "git status", - shouldError: false, - }, - { - name: "go build", - command: "go build", - shouldError: false, - }, - { - name: "sudo as literal text", - command: "echo 'sudo is just text here'", - shouldError: false, - }, - { - name: "complex allowed command", - command: "find . -name '*.go' | head -10", - shouldError: false, - }, - { - name: "command with environment variables", - command: "FOO=bar go test", - shouldError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := validateCommand(tt.command) - if tt.shouldError && err == nil { - t.Errorf("Expected error for command %q, but got none", tt.command) - } - if !tt.shouldError && err != nil { - t.Errorf("Expected no error for command %q, but got: %v", tt.command, err) - } - }) - } -} - -func TestContainsBannedCommand(t *testing.T) { - // Test the helper functions directly with some edge cases - tests := []struct { - name string - command string - shouldError bool - }{ - { - name: "nested command substitution", - command: "echo $(echo $(sudo id))", - shouldError: true, - }, - { - name: "subshell with banned command", - command: "(sudo ls)", - shouldError: true, - }, - { - name: "case statement with banned command", - command: "case $1 in start) sudo systemctl start service ;; esac", - shouldError: true, - }, - { - name: "while loop with banned command", - command: "while true; do sudo echo test; done", - shouldError: true, - }, - { - name: "function with banned command", - command: "function test() { sudo ls; }", - shouldError: true, - }, - { - name: "complex valid command", - command: "if [ -f file ]; then echo exists; else echo missing; fi", - shouldError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := validateCommand(tt.command) - if tt.shouldError && err == nil { - t.Errorf("Expected error for command %q, but got none", tt.command) - } - if !tt.shouldError && err != nil { - t.Errorf("Expected no error for command %q, but got: %v", tt.command, err) - } - }) - } -} From f99f50427a16e36701abf1e965dd1c605c42557f Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 11 Jul 2025 00:21:19 +0200 Subject: [PATCH 3/6] refactor: improve command blocking system and fix test isolation - Simplify command blocking logic by using utility functions instead of complex closures - Add sudo to banned commands list - Move command blocking from bash tool to shell layer for better separation of concerns - Add comprehensive tests for command blocking functionality - Fix test isolation by using temporary directories to prevent npm package files from polluting source tree - Remove redundant command validation logic from bash tool --- internal/llm/tools/bash.go | 24 ++++-- internal/shell/command_block_test.go | 123 +++++++++++++++++++++++++++ internal/shell/shell.go | 81 ++++++++++++++++-- 3 files changed, 213 insertions(+), 15 deletions(-) create mode 100644 internal/shell/command_block_test.go diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index 0a10568a39315f6c4077385b8ca83f6b3e52691c..2ae1c2956c46fef6f2cc1f0a8a20114bbb8785c1 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -43,7 +43,7 @@ const ( var bannedCommands = []string{ "alias", "curl", "curlie", "wget", "axel", "aria2c", "nc", "telnet", "lynx", "w3m", "links", "httpie", "xh", - "http-prompt", "chrome", "firefox", "safari", + "http-prompt", "chrome", "firefox", "safari", "sudo", } // getSafeReadOnlyCommands returns platform-appropriate safe commands @@ -244,7 +244,22 @@ Important: - Never update git config`, bannedCommandsStr, MaxOutputLength) } +func createCommandBlockFuncs() []shell.CommandBlockFunc { + return []shell.CommandBlockFunc{ + shell.CreateSimpleCommandBlocker(bannedCommands), + shell.CreateSubCommandBlocker([][]string{ + {"brew", "install"}, + {"npm", "install", "-g"}, + {"npm", "install", "--global"}, + }), + } +} + func NewBashTool(permission permission.Service, workingDir string) BaseTool { + // Set up command blocking on the persistent shell + persistentShell := shell.GetPersistentShell(workingDir) + persistentShell.SetBlockFuncs(createCommandBlockFuncs()) + return &bashTool{ permissions: permission, workingDir: workingDir, @@ -289,13 +304,6 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return NewTextErrorResponse("missing command"), nil } - baseCmd := strings.Fields(params.Command)[0] - for _, banned := range bannedCommands { - if strings.EqualFold(baseCmd, banned) { - return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil - } - } - isSafeReadOnly := false cmdLower := strings.ToLower(params.Command) diff --git a/internal/shell/command_block_test.go b/internal/shell/command_block_test.go new file mode 100644 index 0000000000000000000000000000000000000000..85971748f882cf79fba3ea86d2682ce6ce4f252d --- /dev/null +++ b/internal/shell/command_block_test.go @@ -0,0 +1,123 @@ +package shell + +import ( + "context" + "os" + "strings" + "testing" +) + +func TestCommandBlocking(t *testing.T) { + tests := []struct { + name string + blockFuncs []CommandBlockFunc + command string + shouldBlock bool + }{ + { + name: "block simple command", + blockFuncs: []CommandBlockFunc{ + func(args []string) bool { + return len(args) > 0 && args[0] == "curl" + }, + }, + command: "curl https://example.com", + shouldBlock: true, + }, + { + name: "allow non-blocked command", + blockFuncs: []CommandBlockFunc{ + func(args []string) bool { + return len(args) > 0 && args[0] == "curl" + }, + }, + command: "echo hello", + shouldBlock: false, + }, + { + name: "block subcommand", + blockFuncs: []CommandBlockFunc{ + func(args []string) bool { + return len(args) >= 2 && args[0] == "brew" && args[1] == "install" + }, + }, + command: "brew install wget", + shouldBlock: true, + }, + { + name: "allow different subcommand", + blockFuncs: []CommandBlockFunc{ + func(args []string) bool { + return len(args) >= 2 && args[0] == "brew" && args[1] == "install" + }, + }, + command: "brew list", + shouldBlock: false, + }, + { + name: "block npm global install with -g", + blockFuncs: []CommandBlockFunc{ + CreateSubCommandBlocker([][]string{ + {"npm", "install", "-g"}, + {"npm", "install", "--global"}, + }), + }, + command: "npm install -g typescript", + shouldBlock: true, + }, + { + name: "block npm global install with --global", + blockFuncs: []CommandBlockFunc{ + CreateSubCommandBlocker([][]string{ + {"npm", "install", "-g"}, + {"npm", "install", "--global"}, + }), + }, + command: "npm install --global typescript", + shouldBlock: true, + }, + { + name: "allow npm local install", + blockFuncs: []CommandBlockFunc{ + CreateSubCommandBlocker([][]string{ + {"npm", "install", "-g"}, + {"npm", "install", "--global"}, + }), + }, + command: "npm install typescript", + shouldBlock: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a temporary directory for each test + tmpDir, err := os.MkdirTemp("", "shell-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + shell := NewShell(&Options{ + WorkingDir: tmpDir, + BlockFuncs: tt.blockFuncs, + }) + + _, _, err = shell.Exec(context.Background(), tt.command) + + if tt.shouldBlock { + if err == nil { + t.Errorf("Expected command to be blocked, but it was allowed") + } else if !strings.Contains(err.Error(), "not allowed for security reasons") { + t.Errorf("Expected security error, got: %v", err) + } + } else { + // For non-blocked commands, we might get other errors (like command not found) + // but we shouldn't get the security error + if err != nil && strings.Contains(err.Error(), "not allowed for security reasons") { + t.Errorf("Command was unexpectedly blocked: %v", err) + } + } + }) + } +} \ No newline at end of file diff --git a/internal/shell/shell.go b/internal/shell/shell.go index 0467c9072c5111e4b4ea9a5439519e4edf76af46..815be0907a3fd05f996a24f84f751ce5d776b833 100644 --- a/internal/shell/shell.go +++ b/internal/shell/shell.go @@ -44,12 +44,16 @@ type noopLogger struct{} func (noopLogger) InfoPersist(msg string, keysAndValues ...interface{}) {} +// CommandBlockFunc is a function that determines if a command should be blocked +type CommandBlockFunc func(args []string) bool + // Shell provides cross-platform shell execution with optional state persistence type Shell struct { - env []string - cwd string - mu sync.Mutex - logger Logger + env []string + cwd string + mu sync.Mutex + logger Logger + blockFuncs []CommandBlockFunc } // Options for creating a new shell @@ -57,6 +61,7 @@ type Options struct { WorkingDir string Env []string Logger Logger + BlockFuncs []CommandBlockFunc } // NewShell creates a new shell instance with the given options @@ -81,9 +86,10 @@ func NewShell(opts *Options) *Shell { } return &Shell{ - cwd: cwd, - env: env, - logger: logger, + cwd: cwd, + env: env, + logger: logger, + blockFuncs: opts.BlockFuncs, } } @@ -152,6 +158,13 @@ func (s *Shell) SetEnv(key, value string) { s.env = append(s.env, keyPrefix+value) } +// SetBlockFuncs sets the command block functions for the shell +func (s *Shell) SetBlockFuncs(blockFuncs []CommandBlockFunc) { + s.mu.Lock() + defer s.mu.Unlock() + s.blockFuncs = blockFuncs +} + // Windows-specific commands that should use native shell var windowsNativeCommands = map[string]bool{ "dir": true, @@ -203,6 +216,59 @@ func (s *Shell) determineShellType(command string) ShellType { return ShellTypePOSIX } +// CreateSimpleCommandBlocker creates a CommandBlockFunc that blocks exact command matches +func CreateSimpleCommandBlocker(bannedCommands []string) CommandBlockFunc { + bannedSet := make(map[string]bool) + for _, cmd := range bannedCommands { + bannedSet[cmd] = true + } + + return func(args []string) bool { + if len(args) == 0 { + return false + } + return bannedSet[args[0]] + } +} + +// CreateSubCommandBlocker creates a CommandBlockFunc that blocks specific subcommands +func CreateSubCommandBlocker(blockedSubCommands [][]string) CommandBlockFunc { + return func(args []string) bool { + for _, blocked := range blockedSubCommands { + if len(args) >= len(blocked) { + match := true + for i, part := range blocked { + if args[i] != part { + match = false + break + } + } + if match { + return true + } + } + } + return false + } +} +func (s *Shell) createCommandBlockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { + return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { + return func(ctx context.Context, args []string) error { + if len(args) == 0 { + return next(ctx, args) + } + + for _, blockFunc := range s.blockFuncs { + if blockFunc(args) { + return fmt.Errorf("command '%s' is not allowed for security reasons", strings.Join(args, " ")) + } + } + + return next(ctx, args) + } + } +} + // execWindows executes commands using native Windows shells (cmd.exe or PowerShell) func (s *Shell) execWindows(ctx context.Context, command string, shell string) (string, string, error) { var cmd *exec.Cmd @@ -291,6 +357,7 @@ func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, interp.Interactive(false), interp.Env(expand.ListEnviron(s.env...)), interp.Dir(s.cwd), + interp.ExecHandlers(s.createCommandBlockHandler()), ) if err != nil { return "", "", fmt.Errorf("could not run command: %w", err) From 931029e1579b57948322bdeda0a7c013202b90c8 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 11 Jul 2025 00:28:42 +0200 Subject: [PATCH 4/6] chore: fix errors not showing in bash tool --- internal/llm/tools/bash.go | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index 2ae1c2956c46fef6f2cc1f0a8a20114bbb8785c1..feac86e01cc7769d886f90a926ad20b87741702d 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "log/slog" "runtime" "strings" "time" @@ -259,7 +260,7 @@ func NewBashTool(permission permission.Service, workingDir string) BaseTool { // Set up command blocking on the persistent shell persistentShell := shell.GetPersistentShell(workingDir) persistentShell.SetBlockFuncs(createCommandBlockFuncs()) - + return &bashTool{ permissions: permission, workingDir: workingDir, @@ -357,7 +358,20 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) stdout = truncateOutput(stdout) stderr = truncateOutput(stderr) + slog.Info("Bash command executed", + "command", params.Command, + "stdout", stdout, + "stderr", stderr, + "exit_code", exitCode, + "interrupted", interrupted, + "err", err, + ) + errorMessage := stderr + if errorMessage == "" && err != nil { + errorMessage = err.Error() + } + if interrupted { if errorMessage != "" { errorMessage += "\n" From 2fdbcac0290e1aa977de0139c335210ce35b0faf Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 11 Jul 2025 17:55:01 +0200 Subject: [PATCH 5/6] chore: better naming --- internal/llm/tools/bash.go | 101 ++++++++++++++++++++++++--- internal/shell/command_block_test.go | 24 +++---- internal/shell/shell.go | 25 +++---- 3 files changed, 118 insertions(+), 32 deletions(-) diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index feac86e01cc7769d886f90a926ad20b87741702d..6d7a9a32b3829da02021be80e6e41e28888efd83 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -42,9 +42,74 @@ const ( ) var bannedCommands = []string{ - "alias", "curl", "curlie", "wget", "axel", "aria2c", - "nc", "telnet", "lynx", "w3m", "links", "httpie", "xh", - "http-prompt", "chrome", "firefox", "safari", "sudo", + // Network/Download tools + "alias", + "aria2c", + "axel", + "chrome", + "curl", + "curlie", + "firefox", + "http-prompt", + "httpie", + "links", + "lynx", + "nc", + "safari", + "telnet", + "w3m", + "wget", + "xh", + + // System administration + "doas", + "su", + "sudo", + + // Package managers + "apk", + "apt", + "apt-cache", + "apt-get", + "dnf", + "dpkg", + "emerge", + "home-manager", + "makepkg", + "opkg", + "pacman", + "paru", + "pkg", + "pkg_add", + "pkg_delete", + "portage", + "rpm", + "yay", + "yum", + "zypper", + + // System modification + "at", + "batch", + "chkconfig", + "crontab", + "fdisk", + "mkfs", + "mount", + "parted", + "service", + "systemctl", + "umount", + + // Network configuration + "firewall-cmd", + "ifconfig", + "ip", + "iptables", + "netstat", + "pfctl", + "route", + "ufw", } // getSafeReadOnlyCommands returns platform-appropriate safe commands @@ -245,13 +310,33 @@ Important: - Never update git config`, bannedCommandsStr, MaxOutputLength) } -func createCommandBlockFuncs() []shell.CommandBlockFunc { - return []shell.CommandBlockFunc{ - shell.CreateSimpleCommandBlocker(bannedCommands), - shell.CreateSubCommandBlocker([][]string{ +func blockFuncs() []shell.BlockFunc { + return []shell.BlockFunc{ + shell.CommandsBlocker(bannedCommands), + shell.ArgumentsBlocker([][]string{ + // System package managers + {"apk", "add"}, + {"apt", "install"}, + {"apt-get", "install"}, + {"dnf", "install"}, + {"emerge"}, + {"pacman", "-S"}, + {"pkg", "install"}, + {"yum", "install"}, + {"zypper", "install"}, + + // Language-specific package managers {"brew", "install"}, + {"cargo", "install"}, + {"gem", "install"}, + {"go", "install"}, {"npm", "install", "-g"}, {"npm", "install", "--global"}, + {"pip", "install", "--user"}, + {"pip3", "install", "--user"}, + {"pnpm", "add", "-g"}, + {"pnpm", "add", "--global"}, + {"yarn", "global", "add"}, }), } } @@ -259,7 +344,7 @@ func createCommandBlockFuncs() []shell.CommandBlockFunc { func NewBashTool(permission permission.Service, workingDir string) BaseTool { // Set up command blocking on the persistent shell persistentShell := shell.GetPersistentShell(workingDir) - persistentShell.SetBlockFuncs(createCommandBlockFuncs()) + persistentShell.SetBlockFuncs(blockFuncs()) return &bashTool{ permissions: permission, diff --git a/internal/shell/command_block_test.go b/internal/shell/command_block_test.go index 85971748f882cf79fba3ea86d2682ce6ce4f252d..fd7c46bcd98e54f44abbe982e834f3cbb04cbfa4 100644 --- a/internal/shell/command_block_test.go +++ b/internal/shell/command_block_test.go @@ -10,13 +10,13 @@ import ( func TestCommandBlocking(t *testing.T) { tests := []struct { name string - blockFuncs []CommandBlockFunc + blockFuncs []BlockFunc command string shouldBlock bool }{ { name: "block simple command", - blockFuncs: []CommandBlockFunc{ + blockFuncs: []BlockFunc{ func(args []string) bool { return len(args) > 0 && args[0] == "curl" }, @@ -26,7 +26,7 @@ func TestCommandBlocking(t *testing.T) { }, { name: "allow non-blocked command", - blockFuncs: []CommandBlockFunc{ + blockFuncs: []BlockFunc{ func(args []string) bool { return len(args) > 0 && args[0] == "curl" }, @@ -36,7 +36,7 @@ func TestCommandBlocking(t *testing.T) { }, { name: "block subcommand", - blockFuncs: []CommandBlockFunc{ + blockFuncs: []BlockFunc{ func(args []string) bool { return len(args) >= 2 && args[0] == "brew" && args[1] == "install" }, @@ -46,7 +46,7 @@ func TestCommandBlocking(t *testing.T) { }, { name: "allow different subcommand", - blockFuncs: []CommandBlockFunc{ + blockFuncs: []BlockFunc{ func(args []string) bool { return len(args) >= 2 && args[0] == "brew" && args[1] == "install" }, @@ -56,8 +56,8 @@ func TestCommandBlocking(t *testing.T) { }, { name: "block npm global install with -g", - blockFuncs: []CommandBlockFunc{ - CreateSubCommandBlocker([][]string{ + blockFuncs: []BlockFunc{ + ArgumentsBlocker([][]string{ {"npm", "install", "-g"}, {"npm", "install", "--global"}, }), @@ -67,8 +67,8 @@ func TestCommandBlocking(t *testing.T) { }, { name: "block npm global install with --global", - blockFuncs: []CommandBlockFunc{ - CreateSubCommandBlocker([][]string{ + blockFuncs: []BlockFunc{ + ArgumentsBlocker([][]string{ {"npm", "install", "-g"}, {"npm", "install", "--global"}, }), @@ -78,8 +78,8 @@ func TestCommandBlocking(t *testing.T) { }, { name: "allow npm local install", - blockFuncs: []CommandBlockFunc{ - CreateSubCommandBlocker([][]string{ + blockFuncs: []BlockFunc{ + ArgumentsBlocker([][]string{ {"npm", "install", "-g"}, {"npm", "install", "--global"}, }), @@ -120,4 +120,4 @@ func TestCommandBlocking(t *testing.T) { } }) } -} \ No newline at end of file +} diff --git a/internal/shell/shell.go b/internal/shell/shell.go index 815be0907a3fd05f996a24f84f751ce5d776b833..097af74d0172264efbe899711ff57137abc6ee30 100644 --- a/internal/shell/shell.go +++ b/internal/shell/shell.go @@ -44,8 +44,8 @@ type noopLogger struct{} func (noopLogger) InfoPersist(msg string, keysAndValues ...interface{}) {} -// CommandBlockFunc is a function that determines if a command should be blocked -type CommandBlockFunc func(args []string) bool +// BlockFunc is a function that determines if a command should be blocked +type BlockFunc func(args []string) bool // Shell provides cross-platform shell execution with optional state persistence type Shell struct { @@ -53,7 +53,7 @@ type Shell struct { cwd string mu sync.Mutex logger Logger - blockFuncs []CommandBlockFunc + blockFuncs []BlockFunc } // Options for creating a new shell @@ -61,7 +61,7 @@ type Options struct { WorkingDir string Env []string Logger Logger - BlockFuncs []CommandBlockFunc + BlockFuncs []BlockFunc } // NewShell creates a new shell instance with the given options @@ -159,7 +159,7 @@ func (s *Shell) SetEnv(key, value string) { } // SetBlockFuncs sets the command block functions for the shell -func (s *Shell) SetBlockFuncs(blockFuncs []CommandBlockFunc) { +func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) { s.mu.Lock() defer s.mu.Unlock() s.blockFuncs = blockFuncs @@ -216,13 +216,13 @@ func (s *Shell) determineShellType(command string) ShellType { return ShellTypePOSIX } -// CreateSimpleCommandBlocker creates a CommandBlockFunc that blocks exact command matches -func CreateSimpleCommandBlocker(bannedCommands []string) CommandBlockFunc { +// CommandsBlocker creates a BlockFunc that blocks exact command matches +func CommandsBlocker(bannedCommands []string) BlockFunc { bannedSet := make(map[string]bool) for _, cmd := range bannedCommands { bannedSet[cmd] = true } - + return func(args []string) bool { if len(args) == 0 { return false @@ -231,8 +231,8 @@ func CreateSimpleCommandBlocker(bannedCommands []string) CommandBlockFunc { } } -// CreateSubCommandBlocker creates a CommandBlockFunc that blocks specific subcommands -func CreateSubCommandBlocker(blockedSubCommands [][]string) CommandBlockFunc { +// ArgumentsBlocker creates a BlockFunc that blocks specific subcommands +func ArgumentsBlocker(blockedSubCommands [][]string) BlockFunc { return func(args []string) bool { for _, blocked := range blockedSubCommands { if len(args) >= len(blocked) { @@ -251,7 +251,8 @@ func CreateSubCommandBlocker(blockedSubCommands [][]string) CommandBlockFunc { return false } } -func (s *Shell) createCommandBlockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { + +func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { return func(ctx context.Context, args []string) error { if len(args) == 0 { @@ -357,7 +358,7 @@ func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, interp.Interactive(false), interp.Env(expand.ListEnviron(s.env...)), interp.Dir(s.cwd), - interp.ExecHandlers(s.createCommandBlockHandler()), + interp.ExecHandlers(s.blockHandler()), ) if err != nil { return "", "", fmt.Errorf("could not run command: %w", err) From 0a5274840166e83a8d2e3fbbb8fcec2a6f324e26 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 11 Jul 2025 18:02:42 +0200 Subject: [PATCH 6/6] chore: small change --- internal/shell/shell.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/shell/shell.go b/internal/shell/shell.go index 097af74d0172264efbe899711ff57137abc6ee30..b655c5dbecf5b69c7ad102c53108733515138771 100644 --- a/internal/shell/shell.go +++ b/internal/shell/shell.go @@ -261,7 +261,7 @@ func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHand for _, blockFunc := range s.blockFuncs { if blockFunc(args) { - return fmt.Errorf("command '%s' is not allowed for security reasons", strings.Join(args, " ")) + return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " ")) } }