permission_test.go

  1package permission
  2
  3import (
  4	"sync"
  5	"testing"
  6
  7	"github.com/stretchr/testify/assert"
  8	"github.com/stretchr/testify/require"
  9)
 10
 11func TestPermissionService_AllowedCommands(t *testing.T) {
 12	tests := []struct {
 13		name         string
 14		allowedTools []string
 15		toolName     string
 16		action       string
 17		expected     bool
 18	}{
 19		{
 20			name:         "tool in allowlist",
 21			allowedTools: []string{"bash", "view"},
 22			toolName:     "bash",
 23			action:       "execute",
 24			expected:     true,
 25		},
 26		{
 27			name:         "tool:action in allowlist",
 28			allowedTools: []string{"bash:execute", "edit:create"},
 29			toolName:     "bash",
 30			action:       "execute",
 31			expected:     true,
 32		},
 33		{
 34			name:         "tool not in allowlist",
 35			allowedTools: []string{"view", "ls"},
 36			toolName:     "bash",
 37			action:       "execute",
 38			expected:     false,
 39		},
 40		{
 41			name:         "tool:action not in allowlist",
 42			allowedTools: []string{"bash:read", "edit:create"},
 43			toolName:     "bash",
 44			action:       "execute",
 45			expected:     false,
 46		},
 47		{
 48			name:         "empty allowlist",
 49			allowedTools: []string{},
 50			toolName:     "bash",
 51			action:       "execute",
 52			expected:     false,
 53		},
 54	}
 55
 56	for _, tt := range tests {
 57		t.Run(tt.name, func(t *testing.T) {
 58			service := NewPermissionService("/tmp", false, tt.allowedTools)
 59
 60			// Create a channel to capture the permission request
 61			// Since we're testing the allowlist logic, we need to simulate the request
 62			ps := service.(*permissionService)
 63
 64			// Test the allowlist logic directly
 65			commandKey := tt.toolName + ":" + tt.action
 66			allowed := false
 67			for _, cmd := range ps.allowedTools {
 68				if cmd == commandKey || cmd == tt.toolName {
 69					allowed = true
 70					break
 71				}
 72			}
 73
 74			if allowed != tt.expected {
 75				t.Errorf("expected %v, got %v for tool %s action %s with allowlist %v",
 76					tt.expected, allowed, tt.toolName, tt.action, tt.allowedTools)
 77			}
 78		})
 79	}
 80}
 81
 82func TestSkipRace(t *testing.T) {
 83	svc := NewPermissionService("/tmp", false, nil)
 84	var wg sync.WaitGroup
 85	wg.Add(2)
 86	go func() {
 87		defer wg.Done()
 88		svc.SetSkipRequests(true)
 89	}()
 90	go func() {
 91		defer wg.Done()
 92		svc.SkipRequests()
 93	}()
 94	wg.Wait()
 95}
 96
 97func TestPermissionService_SkipMode(t *testing.T) {
 98	service := NewPermissionService("/tmp", true, []string{})
 99
100	result, err := service.Request(t.Context(), CreatePermissionRequest{
101		SessionID:   "test-session",
102		ToolName:    "bash",
103		Action:      "execute",
104		Description: "test command",
105		Path:        "/tmp",
106	})
107	if err != nil {
108		t.Errorf("unexpected error: %v", err)
109	}
110	if !result {
111		t.Error("expected permission to be granted in skip mode")
112	}
113}
114
115func TestPermissionService_HookApproval(t *testing.T) {
116	t.Parallel()
117
118	t.Run("matching tool call ID short-circuits the prompt", func(t *testing.T) {
119		t.Parallel()
120		service := NewPermissionService("/tmp", false, nil)
121
122		ctx := WithHookApproval(t.Context(), "call-42")
123		granted, err := service.Request(ctx, CreatePermissionRequest{
124			SessionID:   "s1",
125			ToolCallID:  "call-42",
126			ToolName:    "bash",
127			Action:      "execute",
128			Description: "hook-approved command",
129			Path:        "/tmp",
130		})
131		require.NoError(t, err)
132		assert.True(t, granted, "hook-approved call should bypass the prompt")
133	})
134
135	t.Run("approval is scoped to the stamped tool call ID", func(t *testing.T) {
136		t.Parallel()
137		service := NewPermissionService("/tmp", false, nil)
138
139		// Stamp for call-42, ask for a different call ID — must not leak.
140		ctx := WithHookApproval(t.Context(), "call-42")
141
142		// Kick off a real request that will need a subscriber to resolve it.
143		events := service.Subscribe(t.Context())
144		var (
145			wg      sync.WaitGroup
146			granted bool
147			err     error
148		)
149		wg.Go(func() {
150			granted, err = service.Request(ctx, CreatePermissionRequest{
151				SessionID:   "s1",
152				ToolCallID:  "call-other",
153				ToolName:    "bash",
154				Action:      "execute",
155				Description: "unrelated call",
156				Path:        "/tmp",
157			})
158		})
159
160		// Confirm the service published a real request (i.e. didn't bypass).
161		event := <-events
162		service.Deny(event.Payload)
163		wg.Wait()
164		require.NoError(t, err)
165		assert.False(t, granted, "stamped approval must not apply to a different tool call")
166	})
167
168	t.Run("notifies subscribers that permission was granted", func(t *testing.T) {
169		t.Parallel()
170		service := NewPermissionService("/tmp", false, nil)
171
172		notifications := service.SubscribeNotifications(t.Context())
173
174		ctx := WithHookApproval(t.Context(), "call-99")
175		granted, err := service.Request(ctx, CreatePermissionRequest{
176			SessionID:  "s1",
177			ToolCallID: "call-99",
178			ToolName:   "view",
179			Action:     "read",
180			Path:       "/tmp",
181		})
182		require.NoError(t, err)
183		assert.True(t, granted)
184
185		event := <-notifications
186		assert.Equal(t, "call-99", event.Payload.ToolCallID)
187		assert.True(t, event.Payload.Granted, "subscribers should see a granted notification")
188	})
189}
190
191func TestPermissionService_SequentialProperties(t *testing.T) {
192	t.Run("Sequential permission requests with persistent grants", func(t *testing.T) {
193		service := NewPermissionService("/tmp", false, []string{})
194
195		req1 := CreatePermissionRequest{
196			SessionID:   "session1",
197			ToolName:    "file_tool",
198			Description: "Read file",
199			Action:      "read",
200			Params:      map[string]string{"file": "test.txt"},
201			Path:        "/tmp/test.txt",
202		}
203
204		var result1 bool
205		var wg sync.WaitGroup
206		wg.Add(1)
207
208		events := service.Subscribe(t.Context())
209
210		go func() {
211			defer wg.Done()
212			result1, _ = service.Request(t.Context(), req1)
213		}()
214
215		var permissionReq PermissionRequest
216		event := <-events
217
218		permissionReq = event.Payload
219		service.GrantPersistent(permissionReq)
220
221		wg.Wait()
222		assert.True(t, result1, "First request should be granted")
223
224		// Second identical request should be automatically approved due to persistent permission
225		req2 := CreatePermissionRequest{
226			SessionID:   "session1",
227			ToolName:    "file_tool",
228			Description: "Read file again",
229			Action:      "read",
230			Params:      map[string]string{"file": "test.txt"},
231			Path:        "/tmp/test.txt",
232		}
233		result2, err := service.Request(t.Context(), req2)
234		require.NoError(t, err)
235		assert.True(t, result2, "Second request should be auto-approved")
236	})
237	t.Run("Sequential requests with temporary grants", func(t *testing.T) {
238		service := NewPermissionService("/tmp", false, []string{})
239
240		req := CreatePermissionRequest{
241			SessionID:   "session2",
242			ToolName:    "file_tool",
243			Description: "Write file",
244			Action:      "write",
245			Params:      map[string]string{"file": "test.txt"},
246			Path:        "/tmp/test.txt",
247		}
248
249		events := service.Subscribe(t.Context())
250		var result1 bool
251		var wg sync.WaitGroup
252
253		wg.Go(func() {
254			result1, _ = service.Request(t.Context(), req)
255		})
256
257		var permissionReq PermissionRequest
258		event := <-events
259		permissionReq = event.Payload
260
261		service.Grant(permissionReq)
262		wg.Wait()
263		assert.True(t, result1, "First request should be granted")
264
265		var result2 bool
266
267		wg.Go(func() {
268			result2, _ = service.Request(t.Context(), req)
269		})
270
271		event = <-events
272		permissionReq = event.Payload
273		service.Deny(permissionReq)
274		wg.Wait()
275		assert.False(t, result2, "Second request should be denied")
276	})
277	t.Run("Concurrent requests with different outcomes", func(t *testing.T) {
278		service := NewPermissionService("/tmp", false, []string{})
279
280		events := service.Subscribe(t.Context())
281
282		var wg sync.WaitGroup
283		results := make([]bool, 3)
284
285		requests := []CreatePermissionRequest{
286			{
287				SessionID:   "concurrent1",
288				ToolName:    "tool1",
289				Action:      "action1",
290				Path:        "/tmp/file1.txt",
291				Description: "First concurrent request",
292			},
293			{
294				SessionID:   "concurrent2",
295				ToolName:    "tool2",
296				Action:      "action2",
297				Path:        "/tmp/file2.txt",
298				Description: "Second concurrent request",
299			},
300			{
301				SessionID:   "concurrent3",
302				ToolName:    "tool3",
303				Action:      "action3",
304				Path:        "/tmp/file3.txt",
305				Description: "Third concurrent request",
306			},
307		}
308
309		for i, req := range requests {
310			wg.Add(1)
311			go func(index int, request CreatePermissionRequest) {
312				defer wg.Done()
313				result, _ := service.Request(t.Context(), request)
314				results[index] = result
315			}(i, req)
316		}
317
318		for range 3 {
319			event := <-events
320			switch event.Payload.ToolName {
321			case "tool1":
322				service.Grant(event.Payload)
323			case "tool2":
324				service.GrantPersistent(event.Payload)
325			case "tool3":
326				service.Deny(event.Payload)
327			}
328		}
329		wg.Wait()
330		grantedCount := 0
331		for _, result := range results {
332			if result {
333				grantedCount++
334			}
335		}
336
337		assert.Equal(t, 2, grantedCount, "Should have 2 granted and 1 denied")
338		secondReq := requests[1]
339		secondReq.Description = "Repeat of second request"
340		result, err := service.Request(t.Context(), secondReq)
341		require.NoError(t, err)
342		assert.True(t, result, "Repeated request should be auto-approved due to persistent permission")
343	})
344}