1package permission
2
3import (
4 "errors"
5 "path/filepath"
6 "slices"
7 "sync"
8
9 "github.com/charmbracelet/crush/internal/csync"
10 "github.com/charmbracelet/crush/internal/pubsub"
11 "github.com/google/uuid"
12)
13
// ErrorPermissionDenied is the sentinel error for a rejected permission
// prompt. Callers can report it (for example after Request returns false)
// and match it with errors.Is.
var (
	ErrorPermissionDenied = errors.New("permission denied")
)
15
// CreatePermissionRequest describes a tool invocation that needs approval.
// It is the input to Service.Request, which copies these fields into a
// PermissionRequest before prompting.
type CreatePermissionRequest struct {
	// SessionID is matched against auto-approved sessions and against
	// persistent grants recorded for the same session.
	SessionID string `json:"session_id"`
	// ToolName is matched, alone or as "ToolName:Action", against the
	// allowlist given to NewPermissionService.
	ToolName    string `json:"tool_name"`
	Description string `json:"description"`
	Action      string `json:"action"`
	// Params is opaque to this package; it is copied into the published
	// PermissionRequest unchanged.
	Params any `json:"params"`
	// Path is presumably the file the tool acts on (confirm with callers).
	// Request scopes the permission to filepath.Dir(Path), substituting the
	// service's working directory when that yields "." (e.g. an empty Path).
	Path string `json:"path"`
}
24
// PermissionRequest is a permission prompt. Request builds one from a
// CreatePermissionRequest and publishes it to subscribers; Grant,
// GrantPersistent, and Deny take it back and answer the prompt by its ID.
type PermissionRequest struct {
	// ID is a UUID string generated by Request (uuid.New); answers are
	// routed to the waiting Request call by this value.
	ID          string `json:"id"`
	SessionID   string `json:"session_id"`
	ToolName    string `json:"tool_name"`
	Description string `json:"description"`
	Action      string `json:"action"`
	Params      any    `json:"params"`
	// Path is the directory the permission applies to, not the original
	// file path. Persistent grants are matched against it exactly.
	Path string `json:"path"`
}
34
35type Service interface {
36 pubsub.Suscriber[PermissionRequest]
37 GrantPersistent(permission PermissionRequest)
38 Grant(permission PermissionRequest)
39 Deny(permission PermissionRequest)
40 Request(opts CreatePermissionRequest) bool
41 AutoApproveSession(sessionID string)
42}
43
// permissionService is the Service implementation returned by
// NewPermissionService. It embeds a pubsub.Broker, which Request uses to
// publish prompts. The struct holds mutexes, so it must be used via pointer.
type permissionService struct {
	*pubsub.Broker[PermissionRequest]

	// workingDir replaces a request's directory scope when it resolves to ".".
	workingDir string
	// sessionPermissions holds grants recorded by GrantPersistent; guarded
	// by sessionPermissionsMu.
	sessionPermissions   []PermissionRequest
	sessionPermissionsMu sync.RWMutex
	// pendingRequests maps a request ID to the channel its Request call is
	// blocked on; entries are removed when that call returns.
	pendingRequests *csync.Map[string, chan bool]
	// autoApproveSessions lists session IDs approved without prompting;
	// guarded by autoApproveSessionsMu.
	autoApproveSessions   []string
	autoApproveSessionsMu sync.RWMutex
	// skip makes Request approve everything.
	skip bool
	// allowedTools holds "tool" or "tool:action" entries approved without
	// prompting. It is read without a lock, so it must not be mutated after
	// construction.
	allowedTools []string
}
56
57func (s *permissionService) GrantPersistent(permission PermissionRequest) {
58 respCh, ok := s.pendingRequests.Get(permission.ID)
59 if ok {
60 respCh <- true
61 }
62
63 s.sessionPermissionsMu.Lock()
64 s.sessionPermissions = append(s.sessionPermissions, permission)
65 s.sessionPermissionsMu.Unlock()
66}
67
68func (s *permissionService) Grant(permission PermissionRequest) {
69 respCh, ok := s.pendingRequests.Get(permission.ID)
70 if ok {
71 respCh <- true
72 }
73}
74
75func (s *permissionService) Deny(permission PermissionRequest) {
76 respCh, ok := s.pendingRequests.Get(permission.ID)
77 if ok {
78 respCh <- false
79 }
80}
81
82func (s *permissionService) Request(opts CreatePermissionRequest) bool {
83 if s.skip {
84 return true
85 }
86
87 // Check if the tool/action combination is in the allowlist
88 commandKey := opts.ToolName + ":" + opts.Action
89 if slices.Contains(s.allowedTools, commandKey) || slices.Contains(s.allowedTools, opts.ToolName) {
90 return true
91 }
92
93 s.autoApproveSessionsMu.RLock()
94 autoApprove := slices.Contains(s.autoApproveSessions, opts.SessionID)
95 s.autoApproveSessionsMu.RUnlock()
96
97 if autoApprove {
98 return true
99 }
100
101 dir := filepath.Dir(opts.Path)
102 if dir == "." {
103 dir = s.workingDir
104 }
105 permission := PermissionRequest{
106 ID: uuid.New().String(),
107 Path: dir,
108 SessionID: opts.SessionID,
109 ToolName: opts.ToolName,
110 Description: opts.Description,
111 Action: opts.Action,
112 Params: opts.Params,
113 }
114
115 s.sessionPermissionsMu.RLock()
116 for _, p := range s.sessionPermissions {
117 if p.ToolName == permission.ToolName && p.Action == permission.Action && p.SessionID == permission.SessionID && p.Path == permission.Path {
118 s.sessionPermissionsMu.RUnlock()
119 return true
120 }
121 }
122 s.sessionPermissionsMu.RUnlock()
123
124 respCh := make(chan bool, 1)
125
126 s.pendingRequests.Set(permission.ID, respCh)
127 defer s.pendingRequests.Del(permission.ID)
128
129 s.Publish(pubsub.CreatedEvent, permission)
130
131 // Wait for the response indefinitely
132 return <-respCh
133}
134
135func (s *permissionService) AutoApproveSession(sessionID string) {
136 s.autoApproveSessionsMu.Lock()
137 s.autoApproveSessions = append(s.autoApproveSessions, sessionID)
138 s.autoApproveSessionsMu.Unlock()
139}
140
141func NewPermissionService(workingDir string, skip bool, allowedTools []string) Service {
142 return &permissionService{
143 Broker: pubsub.NewBroker[PermissionRequest](),
144 workingDir: workingDir,
145 sessionPermissions: make([]PermissionRequest, 0),
146 skip: skip,
147 allowedTools: allowedTools,
148 pendingRequests: csync.NewMap[string, chan bool](),
149 }
150}