diff --git a/internal/permission/permission.go b/internal/permission/permission.go index 6619ce37b05576d049c6ae402d9d946c6affca1f..3fe475bab9fa2245067bac70ce689b2942a3747b 100644 --- a/internal/permission/permission.go +++ b/internal/permission/permission.go @@ -6,6 +6,7 @@ import ( "path/filepath" "slices" "sync" + "sync/atomic" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/pubsub" @@ -90,7 +91,7 @@ type permissionService struct { pendingRequests *csync.Map[string, chan bool] autoApproveSessions map[string]bool autoApproveSessionsMu sync.RWMutex - skip bool + skip atomic.Bool allowedTools []string // used to make sure we only process one request at a time @@ -159,7 +160,7 @@ func (s *permissionService) Deny(permission PermissionRequest) { } func (s *permissionService) Request(ctx context.Context, opts CreatePermissionRequest) (bool, error) { - if s.skip { + if s.skip.Load() { return true, nil } @@ -268,22 +269,23 @@ func (s *permissionService) SubscribeNotifications(ctx context.Context) <-chan p } func (s *permissionService) SetSkipRequests(skip bool) { - s.skip = skip + s.skip.Store(skip) } func (s *permissionService) SkipRequests() bool { - return s.skip + return s.skip.Load() } func NewPermissionService(workingDir string, skip bool, allowedTools []string) Service { - return &permissionService{ + svc := &permissionService{ Broker: pubsub.NewBroker[PermissionRequest](), notificationBroker: pubsub.NewBroker[PermissionNotification](), workingDir: workingDir, sessionPermissions: csync.NewMap[PermissionKey, bool](), autoApproveSessions: make(map[string]bool), - skip: skip, allowedTools: allowedTools, pendingRequests: csync.NewMap[string, chan bool](), } + svc.skip.Store(skip) + return svc } diff --git a/internal/permission/permission_test.go b/internal/permission/permission_test.go index 34b06cfe58c4f0e86d23780aa7b9a4b14e51be1a..de08f6beae901172dd3c821a9ff7e544cbc7c6c5 100644 --- a/internal/permission/permission_test.go +++ b/internal/permission/permission_test.go @@ -79,6 +79,21 @@ func TestPermissionService_AllowedCommands(t *testing.T) { } } +func TestSkipRace(t *testing.T) { + svc := NewPermissionService("/tmp", false, nil) + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + svc.SetSkipRequests(true) + }() + go func() { + defer wg.Done() + svc.SkipRequests() + }() + wg.Wait() +} + func TestPermissionService_SkipMode(t *testing.T) { service := NewPermissionService("/tmp", true, []string{})