feat(permissions): require a permission prompt for chained commands

Kieran Klukas created

Change summary

internal/agent/tools/bash.go      | 12 +++--
internal/agent/tools/bash_test.go | 77 +++++++++++++++++++++++++++++++++
internal/agent/tools/safe.go      | 22 +++++++++
internal/agent/tools/safe_test.go | 47 ++++++++++++++++++++
4 files changed, 152 insertions(+), 6 deletions(-)

Detailed changes

internal/agent/tools/bash.go 🔗

@@ -205,11 +205,13 @@ func NewBashTool(permissions permission.Service, workingDir string, attribution
 			isSafeReadOnly := false
 			cmdLower := strings.ToLower(params.Command)
 
-			for _, safe := range safeCommands {
-				if strings.HasPrefix(cmdLower, safe) {
-					if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
-						isSafeReadOnly = true
-						break
+			if !containsCommandChaining(params.Command) {
+				for _, safe := range safeCommands {
+					if strings.HasPrefix(cmdLower, safe) {
+						if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
+							isSafeReadOnly = true
+							break
+						}
 					}
 				}
 			}

internal/agent/tools/bash_test.go 🔗

@@ -79,12 +79,89 @@ func TestBashTool_CustomAutoBackgroundThreshold(t *testing.T) {
 	require.NoError(t, bgManager.Kill(meta.ShellID))
 }
 
+type recordingPermissionService struct {
+	*pubsub.Broker[permission.PermissionRequest]
+	requestCount int
+	allow        bool
+}
+
+func (m *recordingPermissionService) Request(ctx context.Context, req permission.CreatePermissionRequest) (bool, error) {
+	m.requestCount++
+	return m.allow, nil
+}
+
+func (m *recordingPermissionService) Grant(req permission.PermissionRequest) {}
+
+func (m *recordingPermissionService) Deny(req permission.PermissionRequest) {}
+
+func (m *recordingPermissionService) GrantPersistent(req permission.PermissionRequest) {}
+
+func (m *recordingPermissionService) AutoApproveSession(sessionID string) {}
+
+func (m *recordingPermissionService) SetSkipRequests(skip bool) {}
+
+func (m *recordingPermissionService) SkipRequests() bool {
+	return false
+}
+
+func (m *recordingPermissionService) SubscribeNotifications(ctx context.Context) <-chan pubsub.Event[permission.PermissionNotification] {
+	return make(<-chan pubsub.Event[permission.PermissionNotification])
+}
+
 func newBashToolForTest(workingDir string) fantasy.AgentTool {
 	permissions := &mockBashPermissionService{Broker: pubsub.NewBroker[permission.PermissionRequest]()}
 	attribution := &config.Attribution{TrailerStyle: config.TrailerStyleNone}
 	return NewBashTool(permissions, workingDir, attribution, "test-model")
 }
 
+func newBashToolWithRecordingPerms(workingDir string, allow bool) (fantasy.AgentTool, *recordingPermissionService) {
+	perms := &recordingPermissionService{
+		Broker: pubsub.NewBroker[permission.PermissionRequest](),
+		allow:  allow,
+	}
+	attribution := &config.Attribution{TrailerStyle: config.TrailerStyleNone}
+	return NewBashTool(perms, workingDir, attribution, "test-model"), perms
+}
+
+func TestBashTool_ChainedCommandsRequirePermission(t *testing.T) {
+	workingDir := t.TempDir()
+	tool, perms := newBashToolWithRecordingPerms(workingDir, true)
+	ctx := context.WithValue(context.Background(), SessionIDContextKey, "test-session")
+
+	// ls && echo should trigger permission check.
+	resp := runBashTool(t, tool, ctx, BashParams{
+		Description: "chained ls",
+		Command:     "ls && echo done",
+	})
+
+	require.False(t, resp.IsError)
+	require.Equal(t, 1, perms.requestCount, "chained command should trigger permission request")
+
+	// Plain ls should NOT trigger permission check.
+	perms.requestCount = 0
+	resp = runBashTool(t, tool, ctx, BashParams{
+		Description: "plain ls",
+		Command:     "ls -la",
+	})
+
+	require.False(t, resp.IsError)
+	require.Equal(t, 0, perms.requestCount, "plain ls should not trigger permission request")
+}
+
+func TestBashTool_ChainedCommandsDenied(t *testing.T) {
+	workingDir := t.TempDir()
+	tool, perms := newBashToolWithRecordingPerms(workingDir, false)
+	ctx := context.WithValue(context.Background(), SessionIDContextKey, "test-session")
+
+	resp := runBashTool(t, tool, ctx, BashParams{
+		Description: "chained ls denied",
+		Command:     "ls && rm -rf /",
+	})
+
+	require.Equal(t, 1, perms.requestCount)
+	require.Contains(t, resp.Content, "User denied permission")
+}
+
 func runBashTool(t *testing.T, tool fantasy.AgentTool, ctx context.Context, params BashParams) fantasy.ToolResponse {
 	t.Helper()
 

internal/agent/tools/safe.go 🔗

@@ -1,6 +1,10 @@
 package tools
 
-import "runtime"
+import (
+	"runtime"
+	"slices"
+	"strings"
+)
 
 var safeCommands = []string{
 	// Bash builtins and core utils
@@ -54,6 +58,22 @@ var safeCommands = []string{
 	"git tag",
 }
 
+var chainingMetacharacters = []string{
+	";",
+	"|",
+	"&&",
+	"$(",
+	"`",
+}
+
+// containsCommandChaining reports whether s contains shell metacharacters
+// that enable command chaining or substitution.
+func containsCommandChaining(s string) bool {
+	return slices.ContainsFunc(chainingMetacharacters, func(c string) bool {
+		return strings.Contains(s, c)
+	})
+}
+
 func init() {
 	if runtime.GOOS == "windows" {
 		safeCommands = append(

internal/agent/tools/safe_test.go 🔗

@@ -0,0 +1,47 @@
+package tools
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestContainsCommandChaining(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name     string
+		input    string
+		expected bool
+	}{
+		{"plain ls", "ls -la", false},
+		{"plain echo", "echo hello world", false},
+		{"plain pwd", "pwd", false},
+		{"plain git status", "git status", false},
+		{"ls with redirect", "ls > /tmp/out", false},
+		{"ls with pipe", "ls | grep foo", true},
+		{"ls with double ampersand", "ls && echo done", true},
+		{"ls with semicolon", "ls; echo done", true},
+		{"ls with pipe pipe", "ls || echo fail", true},
+		{"ls with backticks", "ls `echo foo`", true},
+		{"ls with subshell", "ls $(echo foo)", true},
+		{"ls with background ampersand", "ls & echo done", false},
+		{"rm -rf with && ls (rm first)", "rm -rf / && ls", true},
+		{"redirect with ampersand gt", "ls &> /dev/null", false},
+		{"redirect with gt ampersand", "ls >& /dev/null", false},
+		{"simple kill", "kill 1234", false},
+		{"kill with pipe", "kill 1234 | echo foo", true},
+		{"git log", "git log --oneline", false},
+		{"git log with pipe", "git log | head", true},
+		{"empty string", "", false},
+		{"dollar sign in argument", "echo $HOME", false},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			got := containsCommandChaining(tt.input)
+			assert.Equal(t, tt.expected, got, "containsCommandChaining(%q)", tt.input)
+		})
+	}
+}