diff --git a/README.md b/README.md index 5fed716c8c6bf437e75ca65401c15e5be64441d5..ccc85d1e8605426f7c79d858c28cb06d83203ece 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,7 @@ # Crush
diff --git a/internal/app/app.go b/internal/app/app.go index f636395e58d65c100b5f68e31c704f2189bcf995..f3362c7276389b6669d6c9977d3565f482a44062 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -12,6 +12,7 @@ import ( tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/db" "github.com/charmbracelet/crush/internal/format" "github.com/charmbracelet/crush/internal/history" @@ -37,8 +38,7 @@ type App struct { clientsMutex sync.RWMutex - watcherCancelFuncs []context.CancelFunc - cancelFuncsMutex sync.Mutex + watcherCancelFuncs *csync.Slice[context.CancelFunc] lspWatcherWG sync.WaitGroup config *config.Config @@ -76,6 +76,8 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { config: cfg, + watcherCancelFuncs: csync.NewSlice[context.CancelFunc](), + events: make(chan tea.Msg, 100), serviceEventsWG: &sync.WaitGroup{}, tuiWG: &sync.WaitGroup{}, @@ -306,11 +308,9 @@ func (app *App) Shutdown() { app.CoderAgent.CancelAll() } - app.cancelFuncsMutex.Lock() - for _, cancel := range app.watcherCancelFuncs { + for cancel := range app.watcherCancelFuncs.Seq() { cancel() } - app.cancelFuncsMutex.Unlock() // Wait for all LSP watchers to finish. app.lspWatcherWG.Wait() diff --git a/internal/app/lsp.go b/internal/app/lsp.go index 946a373e5a7a69dc78fcbcc894629ecf3e9485ac..afe76a68460d262a3f57f214ad3c0c153ddbd807 100644 --- a/internal/app/lsp.go +++ b/internal/app/lsp.go @@ -63,9 +63,7 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, comman workspaceWatcher := watcher.NewWorkspaceWatcher(name, lspClient) // Store the cancel function to be called during cleanup. - app.cancelFuncsMutex.Lock() - app.watcherCancelFuncs = append(app.watcherCancelFuncs, cancelFunc) - app.cancelFuncsMutex.Unlock() + app.watcherCancelFuncs.Append(cancelFunc) // Add to map with mutex protection before starting goroutine app.clientsMutex.Lock() diff --git a/internal/config/config.go b/internal/config/config.go index 9a0da2a376abc88c5e584d7d39744da6f1890ce3..0f9fc99b5ce7677b0009933c447c0f7959825501 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -270,7 +270,7 @@ func (c *Config) WorkingDir() string { func (c *Config) EnabledProviders() []ProviderConfig { var enabled []ProviderConfig - for _, p := range c.Providers.Seq2() { + for p := range c.Providers.Seq() { if !p.Disable { enabled = append(enabled, p) } diff --git a/internal/csync/maps.go b/internal/csync/maps.go index 45e426630a4e50b45125d41dcca54d4e183b4f6f..108c8a4cbb6f855687d6117b1764b85e27279bc9 100644 --- a/internal/csync/maps.go +++ b/internal/csync/maps.go @@ -56,6 +56,15 @@ func (m *Map[K, V]) Len() int { return len(m.inner) } +// Take gets an item and then deletes it. +func (m *Map[K, V]) Take(key K) (V, bool) { + m.mu.Lock() + defer m.mu.Unlock() + v, ok := m.inner[key] + delete(m.inner, key) + return v, ok +} + // Seq2 returns an iter.Seq2 that yields key-value pairs from the map. func (m *Map[K, V]) Seq2() iter.Seq2[K, V] { dst := make(map[K]V) @@ -71,6 +80,17 @@ func (m *Map[K, V]) Seq2() iter.Seq2[K, V] { } } +// Seq returns an iter.Seq that yields values from the map. +func (m *Map[K, V]) Seq() iter.Seq[V] { + return func(yield func(V) bool) { + for _, v := range m.Seq2() { + if !yield(v) { + return + } + } + } +} + var ( _ json.Unmarshaler = &Map[string, any]{} _ json.Marshaler = &Map[string, any]{} diff --git a/internal/csync/maps_test.go b/internal/csync/maps_test.go index 73e6f1db245231e9fad82103366d96a326acc4f6..2882b30cb6ab8f71f969f03687766ca190467c0d 100644 --- a/internal/csync/maps_test.go +++ b/internal/csync/maps_test.go @@ -110,6 +110,72 @@ func TestMap_Len(t *testing.T) { assert.Equal(t, 0, m.Len()) } +func TestMap_Take(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 42) + m.Set("key2", 100) + + assert.Equal(t, 2, m.Len()) + + value, ok := m.Take("key1") + assert.True(t, ok) + assert.Equal(t, 42, value) + assert.Equal(t, 1, m.Len()) + + _, exists := m.Get("key1") + assert.False(t, exists) + + value, ok = m.Get("key2") + assert.True(t, ok) + assert.Equal(t, 100, value) +} + +func TestMap_Take_NonexistentKey(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 42) + + value, ok := m.Take("nonexistent") + assert.False(t, ok) + assert.Equal(t, 0, value) + assert.Equal(t, 1, m.Len()) + + value, ok = m.Get("key1") + assert.True(t, ok) + assert.Equal(t, 42, value) +} + +func TestMap_Take_EmptyMap(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + + value, ok := m.Take("key1") + assert.False(t, ok) + assert.Equal(t, 0, value) + assert.Equal(t, 0, m.Len()) +} + +func TestMap_Take_SameKeyTwice(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 42) + + value, ok := m.Take("key1") + assert.True(t, ok) + assert.Equal(t, 42, value) + assert.Equal(t, 0, m.Len()) + + value, ok = m.Take("key1") + assert.False(t, ok) + assert.Equal(t, 0, value) + assert.Equal(t, 0, m.Len()) +} + func TestMap_Seq2(t *testing.T) { t.Parallel() @@ -158,6 +224,57 @@ func TestMap_Seq2_EmptyMap(t *testing.T) { assert.Equal(t, 0, count) } +func TestMap_Seq(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 1) + m.Set("key2", 2) + m.Set("key3", 3) + + collected := make([]int, 0) + for v := range m.Seq() { + collected = append(collected, v) + } + + assert.Equal(t, 3, len(collected)) + assert.Contains(t, collected, 1) + assert.Contains(t, collected, 2) + assert.Contains(t, collected, 3) +} + +func TestMap_Seq_EarlyReturn(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 1) + m.Set("key2", 2) + m.Set("key3", 3) + + count := 0 + for range m.Seq() { + count++ + if count == 2 { + break + } + } + + assert.Equal(t, 2, count) +} + +func TestMap_Seq_EmptyMap(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + + count := 0 + for range m.Seq() { + count++ + } + + assert.Equal(t, 0, count) +} + func TestMap_MarshalJSON(t *testing.T) { t.Parallel() @@ -371,6 +488,82 @@ func TestMap_ConcurrentSeq2(t *testing.T) { wg.Wait() } +func TestMap_ConcurrentSeq(t *testing.T) { + t.Parallel() + + m := NewMap[int, int]() + for i := range 100 { + m.Set(i, i*2) + } + + var wg sync.WaitGroup + const numIterators = 10 + + wg.Add(numIterators) + for range numIterators { + go func() { + defer wg.Done() + count := 0 + values := make(map[int]bool) + for v := range m.Seq() { + values[v] = true + count++ + } + assert.Equal(t, 100, count) + for i := range 100 { + assert.True(t, values[i*2]) + } + }() + } + + wg.Wait() +} + +func TestMap_ConcurrentTake(t *testing.T) { + t.Parallel() + + m := NewMap[int, int]() + const numItems = 1000 + + for i := range numItems { + m.Set(i, i*2) + } + + var wg sync.WaitGroup + const numWorkers = 10 + taken := make([][]int, numWorkers) + + wg.Add(numWorkers) + for i := range numWorkers { + go func(workerID int) { + defer wg.Done() + taken[workerID] = make([]int, 0) + for j := workerID; j < numItems; j += numWorkers { + if value, ok := m.Take(j); ok { + taken[workerID] = append(taken[workerID], value) + } + } + }(i) + } + + wg.Wait() + + assert.Equal(t, 0, m.Len()) + + allTaken := make(map[int]bool) + for _, workerTaken := range taken { + for _, value := range workerTaken { + assert.False(t, allTaken[value], "Value %d was taken multiple times", value) + allTaken[value] = true + } + } + + assert.Equal(t, numItems, len(allTaken)) + for i := range numItems { + assert.True(t, allTaken[i*2], "Expected value %d to be taken", i*2) + } +} + func TestMap_TypeSafety(t *testing.T) { t.Parallel() @@ -431,6 +624,38 @@ func BenchmarkMap_Seq2(b *testing.B) { } } +func BenchmarkMap_Seq(b *testing.B) { + m := NewMap[int, int]() + for i := range 1000 { + m.Set(i, i*2) + } + + for b.Loop() { + for range m.Seq() { + } + } +} + +func BenchmarkMap_Take(b *testing.B) { + m := NewMap[int, int]() + for i := range 1000 { + m.Set(i, i*2) + } + + b.ResetTimer() + for i := 0; b.Loop(); i++ { + key := i % 1000 + m.Take(key) + if i%1000 == 999 { + b.StopTimer() + for j := range 1000 { + m.Set(j, j*2) + } + b.StartTimer() + } + } +} + func BenchmarkMap_ConcurrentReadWrite(b *testing.B) { m := NewMap[int, int]() for i := range 1000 { diff --git a/internal/csync/slices.go b/internal/csync/slices.go index 176e53d13585517df72b28969a9ced54383706f9..3913a054c166c2bd29b3fafb7e6a0fa1998463a8 100644 --- a/internal/csync/slices.go +++ b/internal/csync/slices.go @@ -59,10 +59,10 @@ func NewSliceFrom[T any](s []T) *Slice[T] { } // Append adds an element to the end of the slice. -func (s *Slice[T]) Append(item T) { +func (s *Slice[T]) Append(items ...T) { s.mu.Lock() defer s.mu.Unlock() - s.inner = append(s.inner, item) + s.inner = append(s.inner, items...) } // Prepend adds an element to the beginning of the slice. @@ -112,6 +112,15 @@ func (s *Slice[T]) Len() int { return len(s.inner) } +// Slice returns a copy of the underlying slice. +func (s *Slice[T]) Slice() []T { + s.mu.RLock() + defer s.mu.RUnlock() + result := make([]T, len(s.inner)) + copy(result, s.inner) + return result +} + // SetSlice replaces the entire slice with a new one. func (s *Slice[T]) SetSlice(items []T) { s.mu.Lock() diff --git a/internal/csync/slices_test.go b/internal/csync/slices_test.go index ad4d7b408deb55ad68bc62d5327096573df8f8b6..d86b02537fe0534ed60de0e314e1e1d9b7b866b6 100644 --- a/internal/csync/slices_test.go +++ b/internal/csync/slices_test.go @@ -1,7 +1,6 @@ package csync import ( - "slices" "sync" "sync/atomic" "testing" @@ -146,7 +145,7 @@ func TestSlice(t *testing.T) { assert.Equal(t, 4, s.Len()) expected := []int{1, 2, 4, 5} - actual := slices.Collect(s.Seq()) + actual := s.Slice() assert.Equal(t, expected, actual) // Delete out of bounds @@ -204,7 +203,7 @@ func TestSlice(t *testing.T) { s.SetSlice(newItems) assert.Equal(t, 3, s.Len()) - assert.Equal(t, newItems, slices.Collect(s.Seq())) + assert.Equal(t, newItems, s.Slice()) // Verify it's a copy newItems[0] = 999 @@ -225,7 +224,7 @@ func TestSlice(t *testing.T) { original := []int{1, 2, 3} s := NewSliceFrom(original) - copy := slices.Collect(s.Seq()) + copy := s.Slice() assert.Equal(t, original, copy) // Verify it's a copy diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 2c3876ccac9ed028b1714ed96b0c6de0cce007c9..17a67f810b335f1dad105321a0bb0a8b354c9bfc 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -7,7 +7,6 @@ import ( "log/slog" "slices" "strings" - "sync" "time" "github.com/charmbracelet/catwalk/pkg/catwalk" @@ -78,7 +77,7 @@ type agent struct { summarizeProvider provider.Provider summarizeProviderID string - activeRequests sync.Map + activeRequests *csync.Map[string, context.CancelFunc] } var agentPromptMap = map[string]prompt.PromptID{ @@ -222,7 +221,7 @@ func NewAgent( titleProvider: titleProvider, summarizeProvider: summarizeProvider, summarizeProviderID: string(smallModelProviderCfg.ID), - activeRequests: sync.Map{}, + activeRequests: csync.NewMap[string, context.CancelFunc](), tools: csync.NewLazySlice(toolFn), }, nil } @@ -233,38 +232,31 @@ func (a *agent) Model() catwalk.Model { func (a *agent) Cancel(sessionID string) { // Cancel regular requests - if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists { - if cancel, ok := cancelFunc.(context.CancelFunc); ok { - slog.Info("Request cancellation initiated", "session_id", sessionID) - cancel() - } + if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil { + slog.Info("Request cancellation initiated", "session_id", sessionID) + cancel() } // Also check for summarize requests - if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists { - if cancel, ok := cancelFunc.(context.CancelFunc); ok { - slog.Info("Summarize cancellation initiated", "session_id", sessionID) - cancel() - } + if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil { + slog.Info("Summarize cancellation initiated", "session_id", sessionID) + cancel() } } func (a *agent) IsBusy() bool { - busy := false - a.activeRequests.Range(func(key, value any) bool { - if cancelFunc, ok := value.(context.CancelFunc); ok { - if cancelFunc != nil { - busy = true - return false - } + var busy bool + for cancelFunc := range a.activeRequests.Seq() { + if cancelFunc != nil { + busy = true + break } - return true - }) + } return busy } func (a *agent) IsSessionBusy(sessionID string) bool { - _, busy := a.activeRequests.Load(sessionID) + _, busy := a.activeRequests.Get(sessionID) return busy } @@ -335,7 +327,7 @@ func (a *agent) Run(ctx context.Context, sessionID string, content string, attac genCtx, cancel := context.WithCancel(ctx) - a.activeRequests.Store(sessionID, cancel) + a.activeRequests.Set(sessionID, cancel) go func() { slog.Debug("Request started", "sessionID", sessionID) defer log.RecoverPanic("agent.Run", func() { @@ -350,7 +342,7 @@ func (a *agent) Run(ctx context.Context, sessionID string, content string, attac slog.Error(result.Error.Error()) } slog.Debug("Request completed", "sessionID", sessionID) - a.activeRequests.Delete(sessionID) + a.activeRequests.Del(sessionID) cancel() a.Publish(pubsub.CreatedEvent, result) events <- result @@ -682,10 +674,10 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error { summarizeCtx, cancel := context.WithCancel(ctx) // Store the cancel function in activeRequests to allow cancellation - a.activeRequests.Store(sessionID+"-summarize", cancel) + a.activeRequests.Set(sessionID+"-summarize", cancel) go func() { - defer a.activeRequests.Delete(sessionID + "-summarize") + defer a.activeRequests.Del(sessionID + "-summarize") defer cancel() event := AgentEvent{ Type: AgentEventTypeSummarize, @@ -850,10 +842,9 @@ func (a *agent) CancelAll() { if !a.IsBusy() { return } - a.activeRequests.Range(func(key, value any) bool { - a.Cancel(key.(string)) // key is sessionID - return true - }) + for key := range a.activeRequests.Seq2() { + a.Cancel(key) // key is sessionID + } timeout := time.After(5 * time.Second) for a.IsBusy() { @@ -907,7 +898,7 @@ func (a *agent) UpdateModel() error { smallModelCfg := cfg.Models[config.SelectedModelTypeSmall] var smallModelProviderCfg config.ProviderConfig - for _, p := range cfg.Providers.Seq2() { + for p := range cfg.Providers.Seq() { if p.ID == smallModelCfg.Provider { smallModelProviderCfg = p break diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index 05b4fada88973608b94eb840a18d65efae70fccf..e17a5527fb46979a8cd056473b3bcd184c014d60 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -5,9 +5,11 @@ import ( "encoding/json" "fmt" "log/slog" + "slices" "sync" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/permission" @@ -196,9 +198,8 @@ func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *confi } func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool { - var mu sync.Mutex var wg sync.WaitGroup - var result []tools.BaseTool + result := csync.NewSlice[tools.BaseTool]() for name, m := range cfg.MCP { if m.Disabled { slog.Debug("skipping disabled mcp", "name", name) @@ -219,9 +220,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con return } - mu.Lock() - result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...) - mu.Unlock() + result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...) case config.MCPHttp: c, err := client.NewStreamableHttpClient( m.URL, @@ -231,9 +230,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con slog.Error("error creating mcp client", "error", err) return } - mu.Lock() - result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...) - mu.Unlock() + result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...) case config.MCPSse: c, err := client.NewSSEMCPClient( m.URL, @@ -243,12 +240,10 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con slog.Error("error creating mcp client", "error", err) return } - mu.Lock() - result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...) - mu.Unlock() + result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...) } }(name, m) } wg.Wait() - return result + return slices.Collect(result.Seq()) } diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go index 4a2661bb9f663d9f93cf0371ac5d71dd513392c7..8c87482a71679f5bc682e6fdd8c1f5a03b89c184 100644 --- a/internal/llm/prompt/prompt.go +++ b/internal/llm/prompt/prompt.go @@ -7,6 +7,7 @@ import ( "sync" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" ) @@ -74,8 +75,7 @@ func processContextPaths(workDir string, paths []string) string { ) // Track processed files to avoid duplicates - processedFiles := make(map[string]bool) - var processedMutex sync.Mutex + processedFiles := csync.NewMap[string, bool]() for _, path := range paths { wg.Add(1) @@ -106,14 +106,8 @@ func processContextPaths(workDir string, paths []string) string { // Check if we've already processed this file (case-insensitive) lowerPath := strings.ToLower(path) - processedMutex.Lock() - alreadyProcessed := processedFiles[lowerPath] - if !alreadyProcessed { - processedFiles[lowerPath] = true - } - processedMutex.Unlock() - - if !alreadyProcessed { + if alreadyProcessed, _ := processedFiles.Get(lowerPath); !alreadyProcessed { + processedFiles.Set(lowerPath, true) if result := processFile(path); result != "" { resultCh <- result } @@ -126,14 +120,8 @@ func processContextPaths(workDir string, paths []string) string { // Check if we've already processed this file (case-insensitive) lowerPath := strings.ToLower(fullPath) - processedMutex.Lock() - alreadyProcessed := processedFiles[lowerPath] - if !alreadyProcessed { - processedFiles[lowerPath] = true - } - processedMutex.Unlock() - - if !alreadyProcessed { + if alreadyProcessed, _ := processedFiles.Get(lowerPath); !alreadyProcessed { + processedFiles.Set(lowerPath, true) result := processFile(fullPath) if result != "" { resultCh <- result diff --git a/internal/lsp/watcher/watcher.go b/internal/lsp/watcher/watcher.go index 58b6d41decd1f551c4a474a6921e075b54e75a6e..6173d6e18e046345cc097052f6a06ff44b3e1e61 100644 --- a/internal/lsp/watcher/watcher.go +++ b/internal/lsp/watcher/watcher.go @@ -12,6 +12,7 @@ import ( "github.com/bmatcuk/doublestar/v4" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/lsp/protocol" @@ -25,8 +26,7 @@ type WorkspaceWatcher struct { workspacePath string debounceTime time.Duration - debounceMap map[string]*time.Timer - debounceMu sync.Mutex + debounceMap *csync.Map[string, *time.Timer] // File watchers registered by the server registrations []protocol.FileSystemWatcher @@ -46,7 +46,7 @@ func NewWorkspaceWatcher(name string, client *lsp.Client) *WorkspaceWatcher { name: name, client: client, debounceTime: 300 * time.Millisecond, - debounceMap: make(map[string]*time.Timer), + debounceMap: csync.NewMap[string, *time.Timer](), registrations: []protocol.FileSystemWatcher{}, } } @@ -639,26 +639,21 @@ func (w *WorkspaceWatcher) matchesPattern(path string, pattern protocol.GlobPatt // debounceHandleFileEvent handles file events with debouncing to reduce notifications func (w *WorkspaceWatcher) debounceHandleFileEvent(ctx context.Context, uri string, changeType protocol.FileChangeType) { - w.debounceMu.Lock() - defer w.debounceMu.Unlock() - // Create a unique key based on URI and change type key := fmt.Sprintf("%s:%d", uri, changeType) // Cancel existing timer if any - if timer, exists := w.debounceMap[key]; exists { + if timer, exists := w.debounceMap.Get(key); exists { timer.Stop() } // Create new timer - w.debounceMap[key] = time.AfterFunc(w.debounceTime, func() { + w.debounceMap.Set(key, time.AfterFunc(w.debounceTime, func() { w.handleFileEvent(ctx, uri, changeType) // Cleanup timer after execution - w.debounceMu.Lock() - delete(w.debounceMap, key) - w.debounceMu.Unlock() - }) + w.debounceMap.Del(key) + })) } // handleFileEvent sends file change notifications diff --git a/internal/permission/permission.go b/internal/permission/permission.go index 4a2b70d32e7d4a5387f479a2ccd97a06fe2e7ba4..476f33598feea326c42630b1ad54e012fc867bf4 100644 --- a/internal/permission/permission.go +++ b/internal/permission/permission.go @@ -7,6 +7,7 @@ import ( "slices" "sync" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/pubsub" "github.com/google/uuid" ) @@ -57,7 +58,7 @@ type permissionService struct { workingDir string sessionPermissions []PermissionRequest sessionPermissionsMu sync.RWMutex - pendingRequests sync.Map + pendingRequests *csync.Map[string, chan bool] autoApproveSessions map[string]bool autoApproveSessionsMu sync.RWMutex skip bool @@ -73,9 +74,9 @@ func (s *permissionService) GrantPersistent(permission PermissionRequest) { ToolCallID: permission.ToolCallID, Granted: true, }) - respCh, ok := s.pendingRequests.Load(permission.ID) + respCh, ok := s.pendingRequests.Get(permission.ID) if ok { - respCh.(chan bool) <- true + respCh <- true } s.sessionPermissionsMu.Lock() @@ -92,9 +93,9 @@ func (s *permissionService) Grant(permission PermissionRequest) { ToolCallID: permission.ToolCallID, Granted: true, }) - respCh, ok := s.pendingRequests.Load(permission.ID) + respCh, ok := s.pendingRequests.Get(permission.ID) if ok { - respCh.(chan bool) <- true + respCh <- true } if s.activeRequest != nil && s.activeRequest.ID == permission.ID { @@ -108,9 +109,9 @@ func (s *permissionService) Deny(permission PermissionRequest) { Granted: false, Denied: true, }) - respCh, ok := s.pendingRequests.Load(permission.ID) + respCh, ok := s.pendingRequests.Get(permission.ID) if ok { - respCh.(chan bool) <- false + respCh <- false } if s.activeRequest != nil && s.activeRequest.ID == permission.ID { @@ -180,8 +181,8 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool { s.activeRequest = &permission respCh := make(chan bool, 1) - s.pendingRequests.Store(permission.ID, respCh) - defer s.pendingRequests.Delete(permission.ID) + s.pendingRequests.Set(permission.ID, respCh) + defer s.pendingRequests.Del(permission.ID) // Publish the request s.Publish(pubsub.CreatedEvent, permission) @@ -208,5 +209,6 @@ func NewPermissionService(workingDir string, skip bool, allowedTools []string) S autoApproveSessions: make(map[string]bool), skip: skip, allowedTools: allowedTools, + pendingRequests: csync.NewMap[string, chan bool](), } } diff --git a/internal/tui/components/chat/sidebar/sidebar.go b/internal/tui/components/chat/sidebar/sidebar.go index 1aa239bdc15cec6898a4cba1e4dc7a867b5e4ce0..1f5fd2a672e3d643efbed4ca35b08ed88c55d2eb 100644 --- a/internal/tui/components/chat/sidebar/sidebar.go +++ b/internal/tui/components/chat/sidebar/sidebar.go @@ -4,13 +4,14 @@ import ( "context" "fmt" "os" + "slices" "sort" "strings" - "sync" tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/diff" "github.com/charmbracelet/crush/internal/fsext" "github.com/charmbracelet/crush/internal/history" @@ -71,8 +72,7 @@ type sidebarCmp struct { lspClients map[string]*lsp.Client compactMode bool history history.Service - // Using a sync map here because we might receive file history events concurrently - files sync.Map + files *csync.Map[string, SessionFile] } func New(history history.Service, lspClients map[string]*lsp.Client, compact bool) Sidebar { @@ -80,6 +80,7 @@ func New(history history.Service, lspClients map[string]*lsp.Client, compact boo lspClients: lspClients, history: history, compactMode: compact, + files: csync.NewMap[string, SessionFile](), } } @@ -90,9 +91,9 @@ func (m *sidebarCmp) Init() tea.Cmd { func (m *sidebarCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { case SessionFilesMsg: - m.files = sync.Map{} + m.files = csync.NewMap[string, SessionFile]() for _, file := range msg.Files { - m.files.Store(file.FilePath, file) + m.files.Set(file.FilePath, file) } return m, nil @@ -178,31 +179,30 @@ func (m *sidebarCmp) handleFileHistoryEvent(event pubsub.Event[history.File]) te return func() tea.Msg { file := event.Payload found := false - m.files.Range(func(key, value any) bool { - existing := value.(SessionFile) - if existing.FilePath == file.Path { - if existing.History.latestVersion.Version < file.Version { - existing.History.latestVersion = file - } else if file.Version == 0 { - existing.History.initialVersion = file - } else { - // If the version is not greater than the latest, we ignore it - return true - } - before := existing.History.initialVersion.Content - after := existing.History.latestVersion.Content - path := existing.History.initialVersion.Path - cwd := config.Get().WorkingDir() - path = strings.TrimPrefix(path, cwd) - _, additions, deletions := diff.GenerateDiff(before, after, path) - existing.Additions = additions - existing.Deletions = deletions - m.files.Store(file.Path, existing) - found = true - return false + for existing := range m.files.Seq() { + if existing.FilePath != file.Path { + continue } - return true - }) + if existing.History.latestVersion.Version < file.Version { + existing.History.latestVersion = file + } else if file.Version == 0 { + existing.History.initialVersion = file + } else { + // If the version is not greater than the latest, we ignore it + continue + } + before := existing.History.initialVersion.Content + after := existing.History.latestVersion.Content + path := existing.History.initialVersion.Path + cwd := config.Get().WorkingDir() + path = strings.TrimPrefix(path, cwd) + _, additions, deletions := diff.GenerateDiff(before, after, path) + existing.Additions = additions + existing.Deletions = deletions + m.files.Set(file.Path, existing) + found = true + break + } if found { return nil } @@ -215,7 +215,7 @@ func (m *sidebarCmp) handleFileHistoryEvent(event pubsub.Event[history.File]) te Additions: 0, Deletions: 0, } - m.files.Store(file.Path, sf) + m.files.Set(file.Path, sf) return nil } } @@ -386,12 +386,7 @@ func (m *sidebarCmp) filesBlockCompact(maxWidth int) string { section := t.S().Subtle.Render("Modified Files") - files := make([]SessionFile, 0) - m.files.Range(func(key, value any) bool { - file := value.(SessionFile) - files = append(files, file) - return true - }) + files := slices.Collect(m.files.Seq()) if len(files) == 0 { content := lipgloss.JoinVertical( @@ -620,12 +615,7 @@ func (m *sidebarCmp) filesBlock() string { core.Section("Modified Files", m.getMaxWidth()), ) - files := make([]SessionFile, 0) - m.files.Range(func(key, value any) bool { - file := value.(SessionFile) - files = append(files, file) - return true // continue iterating - }) + files := slices.Collect(m.files.Seq()) if len(files) == 0 { return lipgloss.JoinVertical( lipgloss.Left,