permission_test.go

  1package permission
  2
  3import (
  4	"sync"
  5	"testing"
  6
  7	"github.com/stretchr/testify/assert"
  8	"github.com/stretchr/testify/require"
  9)
 10
 11func TestPermissionService_AllowedCommands(t *testing.T) {
 12	tests := []struct {
 13		name         string
 14		allowedTools []string
 15		toolName     string
 16		action       string
 17		expected     bool
 18	}{
 19		{
 20			name:         "tool in allowlist",
 21			allowedTools: []string{"bash", "view"},
 22			toolName:     "bash",
 23			action:       "execute",
 24			expected:     true,
 25		},
 26		{
 27			name:         "tool:action in allowlist",
 28			allowedTools: []string{"bash:execute", "edit:create"},
 29			toolName:     "bash",
 30			action:       "execute",
 31			expected:     true,
 32		},
 33		{
 34			name:         "tool not in allowlist",
 35			allowedTools: []string{"view", "ls"},
 36			toolName:     "bash",
 37			action:       "execute",
 38			expected:     false,
 39		},
 40		{
 41			name:         "tool:action not in allowlist",
 42			allowedTools: []string{"bash:read", "edit:create"},
 43			toolName:     "bash",
 44			action:       "execute",
 45			expected:     false,
 46		},
 47		{
 48			name:         "empty allowlist",
 49			allowedTools: []string{},
 50			toolName:     "bash",
 51			action:       "execute",
 52			expected:     false,
 53		},
 54	}
 55
 56	for _, tt := range tests {
 57		t.Run(tt.name, func(t *testing.T) {
 58			service := NewPermissionService("/tmp", false, tt.allowedTools)
 59
 60			// Create a channel to capture the permission request
 61			// Since we're testing the allowlist logic, we need to simulate the request
 62			ps := service.(*permissionService)
 63
 64			// Test the allowlist logic directly
 65			commandKey := tt.toolName + ":" + tt.action
 66			allowed := false
 67			for _, cmd := range ps.allowedTools {
 68				if cmd == commandKey || cmd == tt.toolName {
 69					allowed = true
 70					break
 71				}
 72			}
 73
 74			if allowed != tt.expected {
 75				t.Errorf("expected %v, got %v for tool %s action %s with allowlist %v",
 76					tt.expected, allowed, tt.toolName, tt.action, tt.allowedTools)
 77			}
 78		})
 79	}
 80}
 81
 82func TestPermissionService_SkipMode(t *testing.T) {
 83	service := NewPermissionService("/tmp", true, []string{})
 84
 85	result, err := service.Request(t.Context(), CreatePermissionRequest{
 86		SessionID:   "test-session",
 87		ToolName:    "bash",
 88		Action:      "execute",
 89		Description: "test command",
 90		Path:        "/tmp",
 91	})
 92	if err != nil {
 93		t.Errorf("unexpected error: %v", err)
 94	}
 95	if !result {
 96		t.Error("expected permission to be granted in skip mode")
 97	}
 98}
 99
100func TestPermissionService_HookApproval(t *testing.T) {
101	t.Parallel()
102
103	t.Run("matching tool call ID short-circuits the prompt", func(t *testing.T) {
104		t.Parallel()
105		service := NewPermissionService("/tmp", false, nil)
106
107		ctx := WithHookApproval(t.Context(), "call-42")
108		granted, err := service.Request(ctx, CreatePermissionRequest{
109			SessionID:   "s1",
110			ToolCallID:  "call-42",
111			ToolName:    "bash",
112			Action:      "execute",
113			Description: "hook-approved command",
114			Path:        "/tmp",
115		})
116		require.NoError(t, err)
117		assert.True(t, granted, "hook-approved call should bypass the prompt")
118	})
119
120	t.Run("approval is scoped to the stamped tool call ID", func(t *testing.T) {
121		t.Parallel()
122		service := NewPermissionService("/tmp", false, nil)
123
124		// Stamp for call-42, ask for a different call ID — must not leak.
125		ctx := WithHookApproval(t.Context(), "call-42")
126
127		// Kick off a real request that will need a subscriber to resolve it.
128		events := service.Subscribe(t.Context())
129		var (
130			wg      sync.WaitGroup
131			granted bool
132			err     error
133		)
134		wg.Go(func() {
135			granted, err = service.Request(ctx, CreatePermissionRequest{
136				SessionID:   "s1",
137				ToolCallID:  "call-other",
138				ToolName:    "bash",
139				Action:      "execute",
140				Description: "unrelated call",
141				Path:        "/tmp",
142			})
143		})
144
145		// Confirm the service published a real request (i.e. didn't bypass).
146		event := <-events
147		service.Deny(event.Payload)
148		wg.Wait()
149		require.NoError(t, err)
150		assert.False(t, granted, "stamped approval must not apply to a different tool call")
151	})
152
153	t.Run("notifies subscribers that permission was granted", func(t *testing.T) {
154		t.Parallel()
155		service := NewPermissionService("/tmp", false, nil)
156
157		notifications := service.SubscribeNotifications(t.Context())
158
159		ctx := WithHookApproval(t.Context(), "call-99")
160		granted, err := service.Request(ctx, CreatePermissionRequest{
161			SessionID:  "s1",
162			ToolCallID: "call-99",
163			ToolName:   "view",
164			Action:     "read",
165			Path:       "/tmp",
166		})
167		require.NoError(t, err)
168		assert.True(t, granted)
169
170		event := <-notifications
171		assert.Equal(t, "call-99", event.Payload.ToolCallID)
172		assert.True(t, event.Payload.Granted, "subscribers should see a granted notification")
173	})
174}
175
176func TestPermissionService_SequentialProperties(t *testing.T) {
177	t.Run("Sequential permission requests with persistent grants", func(t *testing.T) {
178		service := NewPermissionService("/tmp", false, []string{})
179
180		req1 := CreatePermissionRequest{
181			SessionID:   "session1",
182			ToolName:    "file_tool",
183			Description: "Read file",
184			Action:      "read",
185			Params:      map[string]string{"file": "test.txt"},
186			Path:        "/tmp/test.txt",
187		}
188
189		var result1 bool
190		var wg sync.WaitGroup
191		wg.Add(1)
192
193		events := service.Subscribe(t.Context())
194
195		go func() {
196			defer wg.Done()
197			result1, _ = service.Request(t.Context(), req1)
198		}()
199
200		var permissionReq PermissionRequest
201		event := <-events
202
203		permissionReq = event.Payload
204		service.GrantPersistent(permissionReq)
205
206		wg.Wait()
207		assert.True(t, result1, "First request should be granted")
208
209		// Second identical request should be automatically approved due to persistent permission
210		req2 := CreatePermissionRequest{
211			SessionID:   "session1",
212			ToolName:    "file_tool",
213			Description: "Read file again",
214			Action:      "read",
215			Params:      map[string]string{"file": "test.txt"},
216			Path:        "/tmp/test.txt",
217		}
218		result2, err := service.Request(t.Context(), req2)
219		require.NoError(t, err)
220		assert.True(t, result2, "Second request should be auto-approved")
221	})
222	t.Run("Sequential requests with temporary grants", func(t *testing.T) {
223		service := NewPermissionService("/tmp", false, []string{})
224
225		req := CreatePermissionRequest{
226			SessionID:   "session2",
227			ToolName:    "file_tool",
228			Description: "Write file",
229			Action:      "write",
230			Params:      map[string]string{"file": "test.txt"},
231			Path:        "/tmp/test.txt",
232		}
233
234		events := service.Subscribe(t.Context())
235		var result1 bool
236		var wg sync.WaitGroup
237
238		wg.Go(func() {
239			result1, _ = service.Request(t.Context(), req)
240		})
241
242		var permissionReq PermissionRequest
243		event := <-events
244		permissionReq = event.Payload
245
246		service.Grant(permissionReq)
247		wg.Wait()
248		assert.True(t, result1, "First request should be granted")
249
250		var result2 bool
251
252		wg.Go(func() {
253			result2, _ = service.Request(t.Context(), req)
254		})
255
256		event = <-events
257		permissionReq = event.Payload
258		service.Deny(permissionReq)
259		wg.Wait()
260		assert.False(t, result2, "Second request should be denied")
261	})
262	t.Run("Concurrent requests with different outcomes", func(t *testing.T) {
263		service := NewPermissionService("/tmp", false, []string{})
264
265		events := service.Subscribe(t.Context())
266
267		var wg sync.WaitGroup
268		results := make([]bool, 3)
269
270		requests := []CreatePermissionRequest{
271			{
272				SessionID:   "concurrent1",
273				ToolName:    "tool1",
274				Action:      "action1",
275				Path:        "/tmp/file1.txt",
276				Description: "First concurrent request",
277			},
278			{
279				SessionID:   "concurrent2",
280				ToolName:    "tool2",
281				Action:      "action2",
282				Path:        "/tmp/file2.txt",
283				Description: "Second concurrent request",
284			},
285			{
286				SessionID:   "concurrent3",
287				ToolName:    "tool3",
288				Action:      "action3",
289				Path:        "/tmp/file3.txt",
290				Description: "Third concurrent request",
291			},
292		}
293
294		for i, req := range requests {
295			wg.Add(1)
296			go func(index int, request CreatePermissionRequest) {
297				defer wg.Done()
298				result, _ := service.Request(t.Context(), request)
299				results[index] = result
300			}(i, req)
301		}
302
303		for range 3 {
304			event := <-events
305			switch event.Payload.ToolName {
306			case "tool1":
307				service.Grant(event.Payload)
308			case "tool2":
309				service.GrantPersistent(event.Payload)
310			case "tool3":
311				service.Deny(event.Payload)
312			}
313		}
314		wg.Wait()
315		grantedCount := 0
316		for _, result := range results {
317			if result {
318				grantedCount++
319			}
320		}
321
322		assert.Equal(t, 2, grantedCount, "Should have 2 granted and 1 denied")
323		secondReq := requests[1]
324		secondReq.Description = "Repeat of second request"
325		result, err := service.Request(t.Context(), secondReq)
326		require.NoError(t, err)
327		assert.True(t, result, "Repeated request should be auto-approved due to persistent permission")
328	})
329}