From 96728b15d037f187424db63cd6bdb20a01fbe08f Mon Sep 17 00:00:00 2001 From: Kieran Klukas Date: Tue, 12 May 2026 18:23:47 -0400 Subject: [PATCH] feat(permissions): require a permission prompt for chained commands --- internal/agent/tools/bash.go | 12 +++-- internal/agent/tools/bash_test.go | 77 +++++++++++++++++++++++++++++++ internal/agent/tools/safe.go | 22 ++++++++- internal/agent/tools/safe_test.go | 47 +++++++++++++++++++ 4 files changed, 152 insertions(+), 6 deletions(-) create mode 100644 internal/agent/tools/safe_test.go diff --git a/internal/agent/tools/bash.go b/internal/agent/tools/bash.go index ba6b224ef82ea28be7ea299be49703dcba82174f..6b91c584ce30daa66eba8b4c55da5dae0ce64529 100644 --- a/internal/agent/tools/bash.go +++ b/internal/agent/tools/bash.go @@ -205,11 +205,13 @@ func NewBashTool(permissions permission.Service, workingDir string, attribution isSafeReadOnly := false cmdLower := strings.ToLower(params.Command) - for _, safe := range safeCommands { - if strings.HasPrefix(cmdLower, safe) { - if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' { - isSafeReadOnly = true - break + if !containsCommandChaining(params.Command) { + for _, safe := range safeCommands { + if strings.HasPrefix(cmdLower, safe) { + if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' { + isSafeReadOnly = true + break + } } } } diff --git a/internal/agent/tools/bash_test.go b/internal/agent/tools/bash_test.go index 899a4055217b2ee7a7e0573e7009d6519297146e..b9c4a13adbb1f948c9fb85f5cb762bd79906bd68 100644 --- a/internal/agent/tools/bash_test.go +++ b/internal/agent/tools/bash_test.go @@ -79,12 +79,89 @@ func TestBashTool_CustomAutoBackgroundThreshold(t *testing.T) { require.NoError(t, bgManager.Kill(meta.ShellID)) } +type recordingPermissionService struct { + *pubsub.Broker[permission.PermissionRequest] + requestCount int + allow bool +} + +func (m *recordingPermissionService) Request(ctx context.Context, req permission.CreatePermissionRequest) (bool, error) { + m.requestCount++ + return m.allow, nil +} + +func (m *recordingPermissionService) Grant(req permission.PermissionRequest) {} + +func (m *recordingPermissionService) Deny(req permission.PermissionRequest) {} + +func (m *recordingPermissionService) GrantPersistent(req permission.PermissionRequest) {} + +func (m *recordingPermissionService) AutoApproveSession(sessionID string) {} + +func (m *recordingPermissionService) SetSkipRequests(skip bool) {} + +func (m *recordingPermissionService) SkipRequests() bool { + return false +} + +func (m *recordingPermissionService) SubscribeNotifications(ctx context.Context) <-chan pubsub.Event[permission.PermissionNotification] { + return make(<-chan pubsub.Event[permission.PermissionNotification]) +} + func newBashToolForTest(workingDir string) fantasy.AgentTool { permissions := &mockBashPermissionService{Broker: pubsub.NewBroker[permission.PermissionRequest]()} attribution := &config.Attribution{TrailerStyle: config.TrailerStyleNone} return NewBashTool(permissions, workingDir, attribution, "test-model") } +func newBashToolWithRecordingPerms(workingDir string, allow bool) (fantasy.AgentTool, *recordingPermissionService) { + perms := &recordingPermissionService{ + Broker: pubsub.NewBroker[permission.PermissionRequest](), + allow: allow, + } + attribution := &config.Attribution{TrailerStyle: config.TrailerStyleNone} + return NewBashTool(perms, workingDir, attribution, "test-model"), perms +} + +func TestBashTool_ChainedCommandsRequirePermission(t *testing.T) { + workingDir := t.TempDir() + tool, perms := newBashToolWithRecordingPerms(workingDir, true) + ctx := context.WithValue(context.Background(), SessionIDContextKey, "test-session") + + // ls && echo should trigger permission check. + resp := runBashTool(t, tool, ctx, BashParams{ + Description: "chained ls", + Command: "ls && echo done", + }) + + require.False(t, resp.IsError) + require.Equal(t, 1, perms.requestCount, "chained command should trigger permission request") + + // Plain ls should NOT trigger permission check. + perms.requestCount = 0 + resp = runBashTool(t, tool, ctx, BashParams{ + Description: "plain ls", + Command: "ls -la", + }) + + require.False(t, resp.IsError) + require.Equal(t, 0, perms.requestCount, "plain ls should not trigger permission request") +} + +func TestBashTool_ChainedCommandsDenied(t *testing.T) { + workingDir := t.TempDir() + tool, perms := newBashToolWithRecordingPerms(workingDir, false) + ctx := context.WithValue(context.Background(), SessionIDContextKey, "test-session") + + resp := runBashTool(t, tool, ctx, BashParams{ + Description: "chained ls denied", + Command: "ls && rm -rf /", + }) + + require.Equal(t, 1, perms.requestCount) + require.Contains(t, resp.Content, "User denied permission") +} + func runBashTool(t *testing.T, tool fantasy.AgentTool, ctx context.Context, params BashParams) fantasy.ToolResponse { t.Helper() diff --git a/internal/agent/tools/safe.go b/internal/agent/tools/safe.go index b0e6635393632390cba1e09d1d5df336fb1979cb..4a8925c024906ef97b4431211a2941d054ac134c 100644 --- a/internal/agent/tools/safe.go +++ b/internal/agent/tools/safe.go @@ -1,6 +1,10 @@ package tools -import "runtime" +import ( + "runtime" + "slices" + "strings" +) var safeCommands = []string{ // Bash builtins and core utils @@ -54,6 +58,22 @@ var safeCommands = []string{ "git tag", } +var chainingMetacharacters = []string{ + ";", + "|", + "&&", + "$(", + "`", +} + +// containsCommandChaining reports whether s contains shell metacharacters +// that enable command chaining or substitution. +func containsCommandChaining(s string) bool { + return slices.ContainsFunc(chainingMetacharacters, func(c string) bool { + return strings.Contains(s, c) + }) +} + func init() { if runtime.GOOS == "windows" { safeCommands = append( diff --git a/internal/agent/tools/safe_test.go b/internal/agent/tools/safe_test.go new file mode 100644 index 0000000000000000000000000000000000000000..6a568c9e3ff1eede670fb00957dc66d3944694d8 --- /dev/null +++ b/internal/agent/tools/safe_test.go @@ -0,0 +1,47 @@ +package tools + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestContainsCommandChaining(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected bool + }{ + {"plain ls", "ls -la", false}, + {"plain echo", "echo hello world", false}, + {"plain pwd", "pwd", false}, + {"plain git status", "git status", false}, + {"ls with redirect", "ls > /tmp/out", false}, + {"ls with pipe", "ls | grep foo", true}, + {"ls with double ampersand", "ls && echo done", true}, + {"ls with semicolon", "ls; echo done", true}, + {"ls with pipe pipe", "ls || echo fail", true}, + {"ls with backticks", "ls `echo foo`", true}, + {"ls with subshell", "ls $(echo foo)", true}, + {"ls with background ampersand", "ls & echo done", false}, + {"rm -rf with && ls (rm first)", "rm -rf / && ls", true}, + {"redirect with ampersand gt", "ls &> /dev/null", false}, + {"redirect with gt ampersand", "ls >& /dev/null", false}, + {"simple kill", "kill 1234", false}, + {"kill with pipe", "kill 1234 | echo foo", true}, + {"git log", "git log --oneline", false}, + {"git log with pipe", "git log | head", true}, + {"empty string", "", false}, + {"dollar sign in argument", "echo $HOME", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := containsCommandChaining(tt.input) + assert.Equal(t, tt.expected, got, "containsCommandChaining(%q)", tt.input) + }) + } +}