permission_test.go

  1package permission
  2
  3import (
  4	"context"
  5	"sync"
  6	"testing"
  7	"time"
  8
  9	"github.com/stretchr/testify/assert"
 10	"github.com/stretchr/testify/require"
 11)
 12
 13func TestPermissionService_AllowedCommands(t *testing.T) {
 14	tests := []struct {
 15		name         string
 16		allowedTools []string
 17		toolName     string
 18		action       string
 19		expected     bool
 20	}{
 21		{
 22			name:         "tool in allowlist",
 23			allowedTools: []string{"bash", "view"},
 24			toolName:     "bash",
 25			action:       "execute",
 26			expected:     true,
 27		},
 28		{
 29			name:         "tool:action in allowlist",
 30			allowedTools: []string{"bash:execute", "edit:create"},
 31			toolName:     "bash",
 32			action:       "execute",
 33			expected:     true,
 34		},
 35		{
 36			name:         "tool not in allowlist",
 37			allowedTools: []string{"view", "ls"},
 38			toolName:     "bash",
 39			action:       "execute",
 40			expected:     false,
 41		},
 42		{
 43			name:         "tool:action not in allowlist",
 44			allowedTools: []string{"bash:read", "edit:create"},
 45			toolName:     "bash",
 46			action:       "execute",
 47			expected:     false,
 48		},
 49		{
 50			name:         "empty allowlist",
 51			allowedTools: []string{},
 52			toolName:     "bash",
 53			action:       "execute",
 54			expected:     false,
 55		},
 56	}
 57
 58	for _, tt := range tests {
 59		t.Run(tt.name, func(t *testing.T) {
 60			service := NewPermissionService(t.Context(), "/tmp", false, tt.allowedTools, nil)
 61
 62			// Create a channel to capture the permission request
 63			// Since we're testing the allowlist logic, we need to simulate the request
 64			ps := service.(*permissionService)
 65
 66			// Test the allowlist logic directly
 67			commandKey := tt.toolName + ":" + tt.action
 68			allowed := false
 69			for _, cmd := range ps.allowedTools {
 70				if cmd == commandKey || cmd == tt.toolName {
 71					allowed = true
 72					break
 73				}
 74			}
 75
 76			if allowed != tt.expected {
 77				t.Errorf("expected %v, got %v for tool %s action %s with allowlist %v",
 78					tt.expected, allowed, tt.toolName, tt.action, tt.allowedTools)
 79			}
 80		})
 81	}
 82}
 83
 84func TestPermissionService_SkipMode(t *testing.T) {
 85	service := NewPermissionService(t.Context(), "/tmp", true, []string{}, nil)
 86
 87	result := service.Request(CreatePermissionRequest{
 88		SessionID:   "test-session",
 89		ToolName:    "bash",
 90		Action:      "execute",
 91		Description: "test command",
 92		Path:        "/tmp",
 93	})
 94
 95	if !result {
 96		t.Error("expected permission to be granted in skip mode")
 97	}
 98}
 99
100func TestPermissionService_SequentialProperties(t *testing.T) {
101	t.Run("Sequential permission requests with persistent grants", func(t *testing.T) {
102		service := NewPermissionService(t.Context(), "/tmp", false, []string{}, nil)
103
104		req1 := CreatePermissionRequest{
105			SessionID:   "session1",
106			ToolName:    "file_tool",
107			Description: "Read file",
108			Action:      "read",
109			Params:      map[string]string{"file": "test.txt"},
110			Path:        "/tmp/test.txt",
111		}
112
113		var result1 bool
114		var wg sync.WaitGroup
115		wg.Add(1)
116
117		events := service.Subscribe(t.Context())
118
119		go func() {
120			defer wg.Done()
121			result1 = service.Request(req1)
122		}()
123
124		var permissionReq PermissionRequest
125		event := <-events
126
127		permissionReq = event.Payload
128		service.GrantPersistent(permissionReq)
129
130		wg.Wait()
131		assert.True(t, result1, "First request should be granted")
132
133		// Second identical request should be automatically approved due to persistent permission
134		req2 := CreatePermissionRequest{
135			SessionID:   "session1",
136			ToolName:    "file_tool",
137			Description: "Read file again",
138			Action:      "read",
139			Params:      map[string]string{"file": "test.txt"},
140			Path:        "/tmp/test.txt",
141		}
142		result2 := service.Request(req2)
143		assert.True(t, result2, "Second request should be auto-approved")
144	})
145	t.Run("Sequential requests with temporary grants", func(t *testing.T) {
146		service := NewPermissionService(t.Context(), "/tmp", false, []string{}, nil)
147
148		req := CreatePermissionRequest{
149			SessionID:   "session2",
150			ToolName:    "file_tool",
151			Description: "Write file",
152			Action:      "write",
153			Params:      map[string]string{"file": "test.txt"},
154			Path:        "/tmp/test.txt",
155		}
156
157		events := service.Subscribe(t.Context())
158		var result1 bool
159		var wg sync.WaitGroup
160
161		wg.Go(func() {
162			result1 = service.Request(req)
163		})
164
165		var permissionReq PermissionRequest
166		event := <-events
167		permissionReq = event.Payload
168
169		service.Grant(permissionReq)
170		wg.Wait()
171		assert.True(t, result1, "First request should be granted")
172
173		var result2 bool
174
175		wg.Go(func() {
176			result2 = service.Request(req)
177		})
178
179		event = <-events
180		permissionReq = event.Payload
181		service.Deny(permissionReq)
182		wg.Wait()
183		assert.False(t, result2, "Second request should be denied")
184	})
185	t.Run("Concurrent requests with different outcomes", func(t *testing.T) {
186		service := NewPermissionService(t.Context(), "/tmp", false, []string{}, nil)
187
188		events := service.Subscribe(t.Context())
189
190		var wg sync.WaitGroup
191
192		requests := []CreatePermissionRequest{
193			{
194				SessionID:   "concurrent1",
195				ToolName:    "tool1",
196				Action:      "action1",
197				Path:        "/tmp/file1.txt",
198				Description: "First concurrent request",
199			},
200			{
201				SessionID:   "concurrent2",
202				ToolName:    "tool2",
203				Action:      "action2",
204				Path:        "/tmp/file2.txt",
205				Description: "Second concurrent request",
206			},
207			{
208				SessionID:   "concurrent3",
209				ToolName:    "tool3",
210				Action:      "action3",
211				Path:        "/tmp/file3.txt",
212				Description: "Third concurrent request",
213			},
214		}
215
216		results := make([]bool, len(requests))
217
218		for i, req := range requests {
219			wg.Add(1)
220			go func(index int, request CreatePermissionRequest) {
221				defer wg.Done()
222				results[index] = service.Request(request)
223			}(i, req)
224		}
225
226		for range 3 {
227			event := <-events
228			switch event.Payload.ToolName {
229			case "tool1":
230				service.Grant(event.Payload)
231			case "tool2":
232				service.GrantPersistent(event.Payload)
233			case "tool3":
234				service.Deny(event.Payload)
235			}
236		}
237		wg.Wait()
238		grantedCount := 0
239		for _, result := range results {
240			if result {
241				grantedCount++
242			}
243		}
244
245		assert.Equal(t, 2, grantedCount, "Should have 2 granted and 1 denied")
246		secondReq := requests[1]
247		secondReq.Description = "Repeat of second request"
248		result := service.Request(secondReq)
249		assert.True(t, result, "Repeated request should be auto-approved due to persistent permission")
250	})
251}
252
253func TestPermissionService_NotifierSchedulesAndCancelsOnGrant(t *testing.T) {
254	notifier := &testPermissionNotifier{}
255	service := NewPermissionService(t.Context(), "/tmp", false, []string{}, notifier)
256
257	events := service.Subscribe(t.Context())
258	req := CreatePermissionRequest{
259		SessionID:  "session-grant",
260		ToolCallID: "tool-call-grant",
261		ToolName:   "bash",
262		Action:     "execute",
263		Path:       "/tmp/script.sh",
264	}
265
266	var wg sync.WaitGroup
267	wg.Add(1)
268	go func() {
269		defer wg.Done()
270		result := service.Request(req)
271		require.True(t, result, "Request should be granted")
272	}()
273
274	event := <-events
275	calls := notifier.Calls()
276	require.Len(t, calls, 1)
277	assert.Equal(t, "💘 Crush is waiting", calls[0].title)
278	assert.Equal(t, "Permission required to execute \"bash\"", calls[0].message)
279	assert.Equal(t, permissionNotificationDelay, calls[0].delay)
280
281	service.Grant(event.Payload)
282	wg.Wait()
283
284	calls = notifier.Calls()
285	require.Len(t, calls, 1)
286	assert.Equal(t, 1, calls[0].cancelCount)
287}
288
289func TestPermissionService_NotifyInteractionCancelsNotification(t *testing.T) {
290	notifier := &testPermissionNotifier{}
291	service := NewPermissionService(t.Context(), "/tmp", false, []string{}, notifier)
292
293	events := service.Subscribe(t.Context())
294	req := CreatePermissionRequest{
295		SessionID:  "session-interact",
296		ToolCallID: "tool-call-interact",
297		ToolName:   "edit",
298		Action:     "write",
299		Path:       "/tmp/file.txt",
300	}
301
302	var wg sync.WaitGroup
303	wg.Add(1)
304	go func() {
305		defer wg.Done()
306		result := service.Request(req)
307		require.False(t, result, "Request should be denied")
308	}()
309
310	event := <-events
311	calls := notifier.Calls()
312	require.Len(t, calls, 1)
313
314	service.NotifyInteraction(event.Payload.ToolCallID)
315	calls = notifier.Calls()
316	require.Len(t, calls, 1)
317	assert.Equal(t, 1, calls[0].cancelCount)
318
319	service.Deny(event.Payload)
320	wg.Wait()
321}
322
// notificationCall records the arguments of one NotifyPermissionRequest call
// made to testPermissionNotifier. cancelCount is how many times the cancel
// func returned for that call has been invoked.
type notificationCall struct {
	title, message string
	delay          time.Duration
	cancelCount    int
}
329
330type testPermissionNotifier struct {
331	mu    sync.Mutex
332	calls []*notificationCall
333}
334
335func (n *testPermissionNotifier) NotifyPermissionRequest(ctx context.Context, title, message string, delay time.Duration) context.CancelFunc {
336	call := &notificationCall{
337		title:   title,
338		message: message,
339		delay:   delay,
340	}
341	n.mu.Lock()
342	n.calls = append(n.calls, call)
343	n.mu.Unlock()
344
345	return func() {
346		n.mu.Lock()
347		call.cancelCount++
348		n.mu.Unlock()
349	}
350}
351
352func (n *testPermissionNotifier) Calls() []notificationCall {
353	n.mu.Lock()
354	defer n.mu.Unlock()
355
356	result := make([]notificationCall, len(n.calls))
357	for i, call := range n.calls {
358		result[i] = *call
359	}
360	return result
361}