add command whitelist

Tai Groot created

Change summary

internal/app/app.go                    |  6 +
internal/config/config.go              |  5 
internal/permission/permission.go      | 10 ++
internal/permission/permission_test.go | 92 ++++++++++++++++++++++++++++
4 files changed, 109 insertions(+), 4 deletions(-)

Detailed changes

internal/app/app.go 🔗

@@ -60,12 +60,16 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
 	messages := message.NewService(q)
 	files := history.NewService(q, conn)
 	skipPermissionsRequests := cfg.Options != nil && cfg.Options.SkipPermissionsRequests
+	allowedCommands := []string{}
+	if cfg.Options != nil && cfg.Options.AllowedCommands != nil {
+		allowedCommands = cfg.Options.AllowedCommands
+	}
 
 	app := &App{
 		Sessions:    sessions,
 		Messages:    messages,
 		History:     files,
-		Permissions: permission.NewPermissionService(cfg.WorkingDir(), skipPermissionsRequests),
+		Permissions: permission.NewPermissionService(cfg.WorkingDir(), skipPermissionsRequests, allowedCommands),
 		LSPClients:  make(map[string]*lsp.Client),
 
 		globalCtx: ctx,

internal/config/config.go 🔗

@@ -126,8 +126,9 @@ type Options struct {
 	Debug                   bool        `json:"debug,omitempty"`
 	DebugLSP                bool        `json:"debug_lsp,omitempty"`
 	DisableAutoSummarize    bool        `json:"disable_auto_summarize,omitempty"`
-	DataDirectory           string      `json:"data_directory,omitempty"` // Relative to the cwd
-	SkipPermissionsRequests bool        `json:"-"`                        // Automatically accept all permissions (YOLO mode)
+	DataDirectory           string      `json:"data_directory,omitempty"`   // Relative to the cwd
+	SkipPermissionsRequests bool        `json:"-"`                          // Automatically accept all permissions (YOLO mode)
+	AllowedCommands         []string    `json:"allowed_commands,omitempty"` // Commands that don't require permission prompts
 }
 
 type MCPs map[string]MCPConfig

internal/permission/permission.go 🔗

@@ -50,6 +50,7 @@ type permissionService struct {
 	autoApproveSessions   []string
 	autoApproveSessionsMu sync.RWMutex
 	skip                  bool
+	allowedCommands       []string
 }
 
 func (s *permissionService) GrantPersistent(permission PermissionRequest) {
@@ -82,6 +83,12 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool {
 		return true
 	}
 
+	// Check if the tool/action combination is in the allowlist
+	commandKey := opts.ToolName + ":" + opts.Action
+	if slices.Contains(s.allowedCommands, commandKey) || slices.Contains(s.allowedCommands, opts.ToolName) {
+		return true
+	}
+
 	s.autoApproveSessionsMu.RLock()
 	autoApprove := slices.Contains(s.autoApproveSessions, opts.SessionID)
 	s.autoApproveSessionsMu.RUnlock()
@@ -130,11 +137,12 @@ func (s *permissionService) AutoApproveSession(sessionID string) {
 	s.autoApproveSessionsMu.Unlock()
 }
 
-func NewPermissionService(workingDir string, skip bool) Service {
+func NewPermissionService(workingDir string, skip bool, allowedCommands []string) Service {
 	return &permissionService{
 		Broker:             pubsub.NewBroker[PermissionRequest](),
 		workingDir:         workingDir,
 		sessionPermissions: make([]PermissionRequest, 0),
 		skip:               skip,
+		allowedCommands:    allowedCommands,
 	}
 }

internal/permission/permission_test.go 🔗

@@ -0,0 +1,92 @@
+package permission
+
+import (
+	"testing"
+)
+
+func TestPermissionService_AllowedCommands(t *testing.T) {
+	tests := []struct {
+		name            string
+		allowedCommands []string
+		toolName        string
+		action          string
+		expected        bool
+	}{
+		{
+			name:            "tool in allowlist",
+			allowedCommands: []string{"bash", "view"},
+			toolName:        "bash",
+			action:          "execute",
+			expected:        true,
+		},
+		{
+			name:            "tool:action in allowlist",
+			allowedCommands: []string{"bash:execute", "edit:create"},
+			toolName:        "bash",
+			action:          "execute",
+			expected:        true,
+		},
+		{
+			name:            "tool not in allowlist",
+			allowedCommands: []string{"view", "ls"},
+			toolName:        "bash",
+			action:          "execute",
+			expected:        false,
+		},
+		{
+			name:            "tool:action not in allowlist",
+			allowedCommands: []string{"bash:read", "edit:create"},
+			toolName:        "bash",
+			action:          "execute",
+			expected:        false,
+		},
+		{
+			name:            "empty allowlist",
+			allowedCommands: []string{},
+			toolName:        "bash",
+			action:          "execute",
+			expected:        false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			service := NewPermissionService("/tmp", false, tt.allowedCommands)
+
+			// Create a channel to capture the permission request
+			// Since we're testing the allowlist logic, we need to simulate the request
+			ps := service.(*permissionService)
+
+			// Test the allowlist logic directly
+			commandKey := tt.toolName + ":" + tt.action
+			allowed := false
+			for _, cmd := range ps.allowedCommands {
+				if cmd == commandKey || cmd == tt.toolName {
+					allowed = true
+					break
+				}
+			}
+
+			if allowed != tt.expected {
+				t.Errorf("expected %v, got %v for tool %s action %s with allowlist %v",
+					tt.expected, allowed, tt.toolName, tt.action, tt.allowedCommands)
+			}
+		})
+	}
+}
+
+func TestPermissionService_SkipMode(t *testing.T) {
+	service := NewPermissionService("/tmp", true, []string{})
+
+	result := service.Request(CreatePermissionRequest{
+		SessionID:   "test-session",
+		ToolName:    "bash",
+		Action:      "execute",
+		Description: "test command",
+		Path:        "/tmp",
+	})
+
+	if !result {
+		t.Error("expected permission to be granted in skip mode")
+	}
+}