From b92d8800e3451a908bcb568bc97d9a3fb5a53232 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Fri, 25 Jul 2025 11:25:45 -0300 Subject: [PATCH 1/3] refactor: use csync.Map instead of sync.Map --- internal/config/config.go | 2 +- internal/csync/maps.go | 24 ++ internal/csync/maps_test.go | 225 ++++++++++++++++++ internal/llm/agent/agent.go | 54 ++--- internal/permission/permission.go | 20 +- .../tui/components/chat/sidebar/sidebar.go | 73 +++--- 6 files changed, 314 insertions(+), 84 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index 9a0da2a376abc88c5e584d7d39744da6f1890ce3..0f9fc99b5ce7677b0009933c447c0f7959825501 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -270,7 +270,7 @@ func (c *Config) WorkingDir() string { func (c *Config) EnabledProviders() []ProviderConfig { var enabled []ProviderConfig - for _, p := range c.Providers.Seq2() { + for p := range c.Providers.Seq() { if !p.Disable { enabled = append(enabled, p) } diff --git a/internal/csync/maps.go b/internal/csync/maps.go index 45e426630a4e50b45125d41dcca54d4e183b4f6f..ddc735b2624500899ca25670f7934326dfd9bdf3 100644 --- a/internal/csync/maps.go +++ b/internal/csync/maps.go @@ -56,6 +56,15 @@ func (m *Map[K, V]) Len() int { return len(m.inner) } +// Take gets an item and then deletes it. +func (m *Map[K, V]) Take(key K) (V, bool) { + m.mu.Lock() + defer m.mu.Unlock() + v, ok := m.inner[key] + delete(m.inner, key) + return v, ok +} + // Seq2 returns an iter.Seq2 that yields key-value pairs from the map. func (m *Map[K, V]) Seq2() iter.Seq2[K, V] { dst := make(map[K]V) @@ -71,6 +80,21 @@ func (m *Map[K, V]) Seq2() iter.Seq2[K, V] { } } +// Seq returns an iter.Seq that yields values from the map. +func (m *Map[K, V]) Seq() iter.Seq[V] { + dst := make(map[K]V) + m.mu.RLock() + maps.Copy(dst, m.inner) + m.mu.RUnlock() + return func(yield func(V) bool) { + for _, v := range dst { + if !yield(v) { + return + } + } + } +} + var ( _ json.Unmarshaler = &Map[string, any]{} _ json.Marshaler = &Map[string, any]{} diff --git a/internal/csync/maps_test.go b/internal/csync/maps_test.go index 73e6f1db245231e9fad82103366d96a326acc4f6..2882b30cb6ab8f71f969f03687766ca190467c0d 100644 --- a/internal/csync/maps_test.go +++ b/internal/csync/maps_test.go @@ -110,6 +110,72 @@ func TestMap_Len(t *testing.T) { assert.Equal(t, 0, m.Len()) } +func TestMap_Take(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 42) + m.Set("key2", 100) + + assert.Equal(t, 2, m.Len()) + + value, ok := m.Take("key1") + assert.True(t, ok) + assert.Equal(t, 42, value) + assert.Equal(t, 1, m.Len()) + + _, exists := m.Get("key1") + assert.False(t, exists) + + value, ok = m.Get("key2") + assert.True(t, ok) + assert.Equal(t, 100, value) +} + +func TestMap_Take_NonexistentKey(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 42) + + value, ok := m.Take("nonexistent") + assert.False(t, ok) + assert.Equal(t, 0, value) + assert.Equal(t, 1, m.Len()) + + value, ok = m.Get("key1") + assert.True(t, ok) + assert.Equal(t, 42, value) +} + +func TestMap_Take_EmptyMap(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + + value, ok := m.Take("key1") + assert.False(t, ok) + assert.Equal(t, 0, value) + assert.Equal(t, 0, m.Len()) +} + +func TestMap_Take_SameKeyTwice(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 42) + + value, ok := m.Take("key1") + assert.True(t, ok) + assert.Equal(t, 42, value) + assert.Equal(t, 0, m.Len()) + + value, ok = m.Take("key1") + assert.False(t, ok) + assert.Equal(t, 0, value) + assert.Equal(t, 0, m.Len()) +} + func TestMap_Seq2(t *testing.T) { t.Parallel() @@ -158,6 +224,57 @@ func TestMap_Seq2_EmptyMap(t *testing.T) { assert.Equal(t, 0, count) } +func TestMap_Seq(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 1) + m.Set("key2", 2) + m.Set("key3", 3) + + collected := make([]int, 0) + for v := range m.Seq() { + collected = append(collected, v) + } + + assert.Equal(t, 3, len(collected)) + assert.Contains(t, collected, 1) + assert.Contains(t, collected, 2) + assert.Contains(t, collected, 3) +} + +func TestMap_Seq_EarlyReturn(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 1) + m.Set("key2", 2) + m.Set("key3", 3) + + count := 0 + for range m.Seq() { + count++ + if count == 2 { + break + } + } + + assert.Equal(t, 2, count) +} + +func TestMap_Seq_EmptyMap(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + + count := 0 + for range m.Seq() { + count++ + } + + assert.Equal(t, 0, count) +} + func TestMap_MarshalJSON(t *testing.T) { t.Parallel() @@ -371,6 +488,82 @@ func TestMap_ConcurrentSeq2(t *testing.T) { wg.Wait() } +func TestMap_ConcurrentSeq(t *testing.T) { + t.Parallel() + + m := NewMap[int, int]() + for i := range 100 { + m.Set(i, i*2) + } + + var wg sync.WaitGroup + const numIterators = 10 + + wg.Add(numIterators) + for range numIterators { + go func() { + defer wg.Done() + count := 0 + values := make(map[int]bool) + for v := range m.Seq() { + values[v] = true + count++ + } + assert.Equal(t, 100, count) + for i := range 100 { + assert.True(t, values[i*2]) + } + }() + } + + wg.Wait() +} + +func TestMap_ConcurrentTake(t *testing.T) { + t.Parallel() + + m := NewMap[int, int]() + const numItems = 1000 + + for i := range numItems { + m.Set(i, i*2) + } + + var wg sync.WaitGroup + const numWorkers = 10 + taken := make([][]int, numWorkers) + + wg.Add(numWorkers) + for i := range numWorkers { + go func(workerID int) { + defer wg.Done() + taken[workerID] = make([]int, 0) + for j := workerID; j < numItems; j += numWorkers { + if value, ok := m.Take(j); ok { + taken[workerID] = append(taken[workerID], value) + } + } + }(i) + } + + wg.Wait() + + assert.Equal(t, 0, m.Len()) + + allTaken := make(map[int]bool) + for _, workerTaken := range taken { + for _, value := range workerTaken { + assert.False(t, allTaken[value], "Value %d was taken multiple times", value) + allTaken[value] = true + } + } + + assert.Equal(t, numItems, len(allTaken)) + for i := range numItems { + assert.True(t, allTaken[i*2], "Expected value %d to be taken", i*2) + } +} + func TestMap_TypeSafety(t *testing.T) { t.Parallel() @@ -431,6 +624,38 @@ func BenchmarkMap_Seq2(b *testing.B) { } } +func BenchmarkMap_Seq(b *testing.B) { + m := NewMap[int, int]() + for i := range 1000 { + m.Set(i, i*2) + } + + for b.Loop() { + for range m.Seq() { + } + } +} + +func BenchmarkMap_Take(b *testing.B) { + m := NewMap[int, int]() + for i := range 1000 { + m.Set(i, i*2) + } + + b.ResetTimer() + for i := 0; b.Loop(); i++ { + key := i % 1000 + m.Take(key) + if i%1000 == 999 { + b.StopTimer() + for j := range 1000 { + m.Set(j, j*2) + } + b.StartTimer() + } + } +} + func BenchmarkMap_ConcurrentReadWrite(b *testing.B) { m := NewMap[int, int]() for i := range 1000 { diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 2c3876ccac9ed028b1714ed96b0c6de0cce007c9..67bc861bbe5b106fa134abbf492cad855d2d362a 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -7,7 +7,6 @@ import ( "log/slog" "slices" "strings" - "sync" "time" "github.com/charmbracelet/catwalk/pkg/catwalk" @@ -78,7 +77,7 @@ type agent struct { summarizeProvider provider.Provider summarizeProviderID string - activeRequests sync.Map + activeRequests *csync.Map[string, context.CancelFunc] } var agentPromptMap = map[string]prompt.PromptID{ @@ -222,7 +221,7 @@ func NewAgent( titleProvider: titleProvider, summarizeProvider: summarizeProvider, summarizeProviderID: string(smallModelProviderCfg.ID), - activeRequests: sync.Map{}, + activeRequests: csync.NewMap[string, context.CancelFunc](), tools: csync.NewLazySlice(toolFn), }, nil } @@ -233,38 +232,30 @@ func (a *agent) Model() catwalk.Model { func (a *agent) Cancel(sessionID string) { // Cancel regular requests - if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists { - if cancel, ok := cancelFunc.(context.CancelFunc); ok { - slog.Info("Request cancellation initiated", "session_id", sessionID) - cancel() - } + if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil { + slog.Info("Request cancellation initiated", "session_id", sessionID) + cancel() } // Also check for summarize requests - if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists { - if cancel, ok := cancelFunc.(context.CancelFunc); ok { - slog.Info("Summarize cancellation initiated", "session_id", sessionID) - cancel() - } + if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil { + slog.Info("Summarize cancellation initiated", "session_id", sessionID) + cancel() } } func (a *agent) IsBusy() bool { - busy := false - a.activeRequests.Range(func(key, value any) bool { - if cancelFunc, ok := value.(context.CancelFunc); ok { - if cancelFunc != nil { - busy = true - return false - } + var busy bool + for cancelFunc := range a.activeRequests.Seq() { + if cancelFunc != nil { + busy = true } - return true - }) + } return busy } func (a *agent) IsSessionBusy(sessionID string) bool { - _, busy := a.activeRequests.Load(sessionID) + _, busy := a.activeRequests.Get(sessionID) return busy } @@ -335,7 +326,7 @@ func (a *agent) Run(ctx context.Context, sessionID string, content string, attac genCtx, cancel := context.WithCancel(ctx) - a.activeRequests.Store(sessionID, cancel) + a.activeRequests.Set(sessionID, cancel) go func() { slog.Debug("Request started", "sessionID", sessionID) defer log.RecoverPanic("agent.Run", func() { @@ -350,7 +341,7 @@ func (a *agent) Run(ctx context.Context, sessionID string, content string, attac slog.Error(result.Error.Error()) } slog.Debug("Request completed", "sessionID", sessionID) - a.activeRequests.Delete(sessionID) + a.activeRequests.Del(sessionID) cancel() a.Publish(pubsub.CreatedEvent, result) events <- result @@ -682,10 +673,10 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error { summarizeCtx, cancel := context.WithCancel(ctx) // Store the cancel function in activeRequests to allow cancellation - a.activeRequests.Store(sessionID+"-summarize", cancel) + a.activeRequests.Set(sessionID+"-summarize", cancel) go func() { - defer a.activeRequests.Delete(sessionID + "-summarize") + defer a.activeRequests.Del(sessionID + "-summarize") defer cancel() event := AgentEvent{ Type: AgentEventTypeSummarize, @@ -850,10 +841,9 @@ func (a *agent) CancelAll() { if !a.IsBusy() { return } - a.activeRequests.Range(func(key, value any) bool { - a.Cancel(key.(string)) // key is sessionID - return true - }) + for key := range a.activeRequests.Seq2() { + a.Cancel(key) // key is sessionID + } timeout := time.After(5 * time.Second) for a.IsBusy() { @@ -907,7 +897,7 @@ func (a *agent) UpdateModel() error { smallModelCfg := cfg.Models[config.SelectedModelTypeSmall] var smallModelProviderCfg config.ProviderConfig - for _, p := range cfg.Providers.Seq2() { + for p := range cfg.Providers.Seq() { if p.ID == smallModelCfg.Provider { smallModelProviderCfg = p break diff --git a/internal/permission/permission.go b/internal/permission/permission.go index cd149a49890b54086bd52e562eed0d44f00c407e..c5d001075a8cd01a91cccee0afcd44f89a5d4bcc 100644 --- a/internal/permission/permission.go +++ b/internal/permission/permission.go @@ -6,6 +6,7 @@ import ( "slices" "sync" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/pubsub" "github.com/google/uuid" ) @@ -46,7 +47,7 @@ type permissionService struct { workingDir string sessionPermissions []PermissionRequest sessionPermissionsMu sync.RWMutex - pendingRequests sync.Map + pendingRequests *csync.Map[string, chan bool] autoApproveSessions []string autoApproveSessionsMu sync.RWMutex skip bool @@ -54,9 +55,9 @@ type permissionService struct { } func (s *permissionService) GrantPersistent(permission PermissionRequest) { - respCh, ok := s.pendingRequests.Load(permission.ID) + respCh, ok := s.pendingRequests.Get(permission.ID) if ok { - respCh.(chan bool) <- true + respCh <- true } s.sessionPermissionsMu.Lock() @@ -65,16 +66,16 @@ func (s *permissionService) GrantPersistent(permission PermissionRequest) { } func (s *permissionService) Grant(permission PermissionRequest) { - respCh, ok := s.pendingRequests.Load(permission.ID) + respCh, ok := s.pendingRequests.Get(permission.ID) if ok { - respCh.(chan bool) <- true + respCh <- true } } func (s *permissionService) Deny(permission PermissionRequest) { - respCh, ok := s.pendingRequests.Load(permission.ID) + respCh, ok := s.pendingRequests.Get(permission.ID) if ok { - respCh.(chan bool) <- false + respCh <- false } } @@ -122,8 +123,8 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool { respCh := make(chan bool, 1) - s.pendingRequests.Store(permission.ID, respCh) - defer s.pendingRequests.Delete(permission.ID) + s.pendingRequests.Set(permission.ID, respCh) + defer s.pendingRequests.Del(permission.ID) s.Publish(pubsub.CreatedEvent, permission) @@ -144,5 +145,6 @@ func NewPermissionService(workingDir string, skip bool, allowedTools []string) S sessionPermissions: make([]PermissionRequest, 0), skip: skip, allowedTools: allowedTools, + pendingRequests: csync.NewMap[string, chan bool](), } } diff --git a/internal/tui/components/chat/sidebar/sidebar.go b/internal/tui/components/chat/sidebar/sidebar.go index 1aa239bdc15cec6898a4cba1e4dc7a867b5e4ce0..3ab2e9563420e5f3c6365bc71a66e35ab7b79f11 100644 --- a/internal/tui/components/chat/sidebar/sidebar.go +++ b/internal/tui/components/chat/sidebar/sidebar.go @@ -4,13 +4,14 @@ import ( "context" "fmt" "os" + "slices" "sort" "strings" - "sync" tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/diff" "github.com/charmbracelet/crush/internal/fsext" "github.com/charmbracelet/crush/internal/history" @@ -71,8 +72,7 @@ type sidebarCmp struct { lspClients map[string]*lsp.Client compactMode bool history history.Service - // Using a sync map here because we might receive file history events concurrently - files sync.Map + files *csync.Map[string, SessionFile] } func New(history history.Service, lspClients map[string]*lsp.Client, compact bool) Sidebar { @@ -80,6 +80,7 @@ func New(history history.Service, lspClients map[string]*lsp.Client, compact boo lspClients: lspClients, history: history, compactMode: compact, + files: csync.NewMap[string, SessionFile](), } } @@ -90,9 +91,9 @@ func (m *sidebarCmp) Init() tea.Cmd { func (m *sidebarCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { case SessionFilesMsg: - m.files = sync.Map{} + m.files = csync.NewMap[string, SessionFile]() for _, file := range msg.Files { - m.files.Store(file.FilePath, file) + m.files.Set(file.FilePath, file) } return m, nil @@ -178,31 +179,29 @@ func (m *sidebarCmp) handleFileHistoryEvent(event pubsub.Event[history.File]) te return func() tea.Msg { file := event.Payload found := false - m.files.Range(func(key, value any) bool { - existing := value.(SessionFile) - if existing.FilePath == file.Path { - if existing.History.latestVersion.Version < file.Version { - existing.History.latestVersion = file - } else if file.Version == 0 { - existing.History.initialVersion = file - } else { - // If the version is not greater than the latest, we ignore it - return true - } - before := existing.History.initialVersion.Content - after := existing.History.latestVersion.Content - path := existing.History.initialVersion.Path - cwd := config.Get().WorkingDir() - path = strings.TrimPrefix(path, cwd) - _, additions, deletions := diff.GenerateDiff(before, after, path) - existing.Additions = additions - existing.Deletions = deletions - m.files.Store(file.Path, existing) - found = true - return false + for existing := range m.files.Seq() { + if existing.FilePath != file.Path { + continue } - return true - }) + if existing.History.latestVersion.Version < file.Version { + existing.History.latestVersion = file + } else if file.Version == 0 { + existing.History.initialVersion = file + } else { + // If the version is not greater than the latest, we ignore it + continue + } + before := existing.History.initialVersion.Content + after := existing.History.latestVersion.Content + path := existing.History.initialVersion.Path + cwd := config.Get().WorkingDir() + path = strings.TrimPrefix(path, cwd) + _, additions, deletions := diff.GenerateDiff(before, after, path) + existing.Additions = additions + existing.Deletions = deletions + m.files.Set(file.Path, existing) + found = true + } if found { return nil } @@ -215,7 +214,7 @@ func (m *sidebarCmp) handleFileHistoryEvent(event pubsub.Event[history.File]) te Additions: 0, Deletions: 0, } - m.files.Store(file.Path, sf) + m.files.Set(file.Path, sf) return nil } } @@ -386,12 +385,7 @@ func (m *sidebarCmp) filesBlockCompact(maxWidth int) string { section := t.S().Subtle.Render("Modified Files") - files := make([]SessionFile, 0) - m.files.Range(func(key, value any) bool { - file := value.(SessionFile) - files = append(files, file) - return true - }) + files := slices.Collect(m.files.Seq()) if len(files) == 0 { content := lipgloss.JoinVertical( @@ -620,12 +614,7 @@ func (m *sidebarCmp) filesBlock() string { core.Section("Modified Files", m.getMaxWidth()), ) - files := make([]SessionFile, 0) - m.files.Range(func(key, value any) bool { - file := value.(SessionFile) - files = append(files, file) - return true // continue iterating - }) + files := slices.Collect(m.files.Seq()) if len(files) == 0 { return lipgloss.JoinVertical( lipgloss.Left, From 615c8d222f92812eb464ad1d9b0fa5f1c3a7a084 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Fri, 25 Jul 2025 11:40:04 -0300 Subject: [PATCH 2/3] fix: breaks --- internal/llm/agent/agent.go | 1 + internal/tui/components/chat/sidebar/sidebar.go | 1 + 2 files changed, 2 insertions(+) diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 67bc861bbe5b106fa134abbf492cad855d2d362a..17a67f810b335f1dad105321a0bb0a8b354c9bfc 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -249,6 +249,7 @@ func (a *agent) IsBusy() bool { for cancelFunc := range a.activeRequests.Seq() { if cancelFunc != nil { busy = true + break } } return busy diff --git a/internal/tui/components/chat/sidebar/sidebar.go b/internal/tui/components/chat/sidebar/sidebar.go index 3ab2e9563420e5f3c6365bc71a66e35ab7b79f11..1f5fd2a672e3d643efbed4ca35b08ed88c55d2eb 100644 --- a/internal/tui/components/chat/sidebar/sidebar.go +++ b/internal/tui/components/chat/sidebar/sidebar.go @@ -201,6 +201,7 @@ func (m *sidebarCmp) handleFileHistoryEvent(event pubsub.Event[history.File]) te existing.Deletions = deletions m.files.Set(file.Path, existing) found = true + break } if found { return nil From 61ea243e489bb519de123feab56986f15347ce4c Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Fri, 25 Jul 2025 12:07:58 -0300 Subject: [PATCH 3/3] feat: slices as well --- internal/app/app.go | 10 +- internal/app/lsp.go | 4 +- internal/csync/maps.go | 6 +- internal/csync/slices.go | 127 +++++++++++++++++++ internal/csync/slices_test.go | 209 ++++++++++++++++++++++++++++++++ internal/llm/agent/mcp-tools.go | 19 ++- internal/llm/prompt/prompt.go | 24 +--- internal/lsp/watcher/watcher.go | 19 ++- 8 files changed, 363 insertions(+), 55 deletions(-) diff --git a/internal/app/app.go b/internal/app/app.go index 50e117ea1ae272156dbd11baa1a5f157a74333f1..67bfd8f5d7f8cd6b8b54a354a426b7fa3b0b01bb 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -12,6 +12,7 @@ import ( tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/db" "github.com/charmbracelet/crush/internal/format" "github.com/charmbracelet/crush/internal/history" @@ -37,8 +38,7 @@ type App struct { clientsMutex sync.RWMutex - watcherCancelFuncs []context.CancelFunc - cancelFuncsMutex sync.Mutex + watcherCancelFuncs *csync.Slice[context.CancelFunc] lspWatcherWG sync.WaitGroup config *config.Config @@ -76,6 +76,8 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { config: cfg, + watcherCancelFuncs: csync.NewSlice[context.CancelFunc](), + events: make(chan tea.Msg, 100), serviceEventsWG: &sync.WaitGroup{}, tuiWG: &sync.WaitGroup{}, @@ -305,11 +307,9 @@ func (app *App) Shutdown() { app.CoderAgent.CancelAll() } - app.cancelFuncsMutex.Lock() - for _, cancel := range app.watcherCancelFuncs { + for cancel := range app.watcherCancelFuncs.Seq() { cancel() } - app.cancelFuncsMutex.Unlock() // Wait for all LSP watchers to finish. app.lspWatcherWG.Wait() diff --git a/internal/app/lsp.go b/internal/app/lsp.go index 946a373e5a7a69dc78fcbcc894629ecf3e9485ac..afe76a68460d262a3f57f214ad3c0c153ddbd807 100644 --- a/internal/app/lsp.go +++ b/internal/app/lsp.go @@ -63,9 +63,7 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, comman workspaceWatcher := watcher.NewWorkspaceWatcher(name, lspClient) // Store the cancel function to be called during cleanup. - app.cancelFuncsMutex.Lock() - app.watcherCancelFuncs = append(app.watcherCancelFuncs, cancelFunc) - app.cancelFuncsMutex.Unlock() + app.watcherCancelFuncs.Append(cancelFunc) // Add to map with mutex protection before starting goroutine app.clientsMutex.Lock() diff --git a/internal/csync/maps.go b/internal/csync/maps.go index ddc735b2624500899ca25670f7934326dfd9bdf3..108c8a4cbb6f855687d6117b1764b85e27279bc9 100644 --- a/internal/csync/maps.go +++ b/internal/csync/maps.go @@ -82,12 +82,8 @@ func (m *Map[K, V]) Seq2() iter.Seq2[K, V] { // Seq returns an iter.Seq that yields values from the map. func (m *Map[K, V]) Seq() iter.Seq[V] { - dst := make(map[K]V) - m.mu.RLock() - maps.Copy(dst, m.inner) - m.mu.RUnlock() return func(yield func(V) bool) { - for _, v := range dst { + for _, v := range m.Seq2() { if !yield(v) { return } diff --git a/internal/csync/slices.go b/internal/csync/slices.go index be723655079ccc6b07f55c3237b706a17bb14d40..3913a054c166c2bd29b3fafb7e6a0fa1998463a8 100644 --- a/internal/csync/slices.go +++ b/internal/csync/slices.go @@ -2,6 +2,7 @@ package csync import ( "iter" + "slices" "sync" ) @@ -34,3 +35,129 @@ func (s *LazySlice[K]) Seq() iter.Seq[K] { } } } + +// Slice is a thread-safe slice implementation that provides concurrent access. +type Slice[T any] struct { + inner []T + mu sync.RWMutex +} + +// NewSlice creates a new thread-safe slice. +func NewSlice[T any]() *Slice[T] { + return &Slice[T]{ + inner: make([]T, 0), + } +} + +// NewSliceFrom creates a new thread-safe slice from an existing slice. +func NewSliceFrom[T any](s []T) *Slice[T] { + inner := make([]T, len(s)) + copy(inner, s) + return &Slice[T]{ + inner: inner, + } +} + +// Append adds an element to the end of the slice. +func (s *Slice[T]) Append(items ...T) { + s.mu.Lock() + defer s.mu.Unlock() + s.inner = append(s.inner, items...) +} + +// Prepend adds an element to the beginning of the slice. +func (s *Slice[T]) Prepend(item T) { + s.mu.Lock() + defer s.mu.Unlock() + s.inner = append([]T{item}, s.inner...) +} + +// Delete removes the element at the specified index. +func (s *Slice[T]) Delete(index int) bool { + s.mu.Lock() + defer s.mu.Unlock() + if index < 0 || index >= len(s.inner) { + return false + } + s.inner = slices.Delete(s.inner, index, index+1) + return true +} + +// Get returns the element at the specified index. +func (s *Slice[T]) Get(index int) (T, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + var zero T + if index < 0 || index >= len(s.inner) { + return zero, false + } + return s.inner[index], true +} + +// Set updates the element at the specified index. +func (s *Slice[T]) Set(index int, item T) bool { + s.mu.Lock() + defer s.mu.Unlock() + if index < 0 || index >= len(s.inner) { + return false + } + s.inner[index] = item + return true +} + +// Len returns the number of elements in the slice. +func (s *Slice[T]) Len() int { + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.inner) +} + +// Slice returns a copy of the underlying slice. +func (s *Slice[T]) Slice() []T { + s.mu.RLock() + defer s.mu.RUnlock() + result := make([]T, len(s.inner)) + copy(result, s.inner) + return result +} + +// SetSlice replaces the entire slice with a new one. +func (s *Slice[T]) SetSlice(items []T) { + s.mu.Lock() + defer s.mu.Unlock() + s.inner = make([]T, len(items)) + copy(s.inner, items) +} + +// Clear removes all elements from the slice. +func (s *Slice[T]) Clear() { + s.mu.Lock() + defer s.mu.Unlock() + s.inner = s.inner[:0] +} + +// Seq returns an iterator that yields elements from the slice. +func (s *Slice[T]) Seq() iter.Seq[T] { + return func(yield func(T) bool) { + for _, v := range s.Seq2() { + if !yield(v) { + return + } + } + } +} + +// Seq2 returns an iterator that yields index-value pairs from the slice. +func (s *Slice[T]) Seq2() iter.Seq2[int, T] { + s.mu.RLock() + items := make([]T, len(s.inner)) + copy(items, s.inner) + s.mu.RUnlock() + return func(yield func(int, T) bool) { + for i, v := range items { + if !yield(i, v) { + return + } + } + } +} diff --git a/internal/csync/slices_test.go b/internal/csync/slices_test.go index 731cb96f55dd24cae74f55c0ef8e97ebd28aacaa..d86b02537fe0534ed60de0e314e1e1d9b7b866b6 100644 --- a/internal/csync/slices_test.go +++ b/internal/csync/slices_test.go @@ -1,11 +1,13 @@ package csync import ( + "sync" "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestLazySlice_Seq(t *testing.T) { @@ -85,3 +87,210 @@ func TestLazySlice_EarlyBreak(t *testing.T) { assert.Equal(t, []string{"a", "b"}, result) } + +func TestSlice(t *testing.T) { + t.Run("NewSlice", func(t *testing.T) { + s := NewSlice[int]() + assert.Equal(t, 0, s.Len()) + }) + + t.Run("NewSliceFrom", func(t *testing.T) { + original := []int{1, 2, 3} + s := NewSliceFrom(original) + assert.Equal(t, 3, s.Len()) + + // Verify it's a copy, not a reference + original[0] = 999 + val, ok := s.Get(0) + require.True(t, ok) + assert.Equal(t, 1, val) + }) + + t.Run("Append", func(t *testing.T) { + s := NewSlice[string]() + s.Append("hello") + s.Append("world") + + assert.Equal(t, 2, s.Len()) + val, ok := s.Get(0) + require.True(t, ok) + assert.Equal(t, "hello", val) + + val, ok = s.Get(1) + require.True(t, ok) + assert.Equal(t, "world", val) + }) + + t.Run("Prepend", func(t *testing.T) { + s := NewSlice[string]() + s.Append("world") + s.Prepend("hello") + + assert.Equal(t, 2, s.Len()) + val, ok := s.Get(0) + require.True(t, ok) + assert.Equal(t, "hello", val) + + val, ok = s.Get(1) + require.True(t, ok) + assert.Equal(t, "world", val) + }) + + t.Run("Delete", func(t *testing.T) { + s := NewSliceFrom([]int{1, 2, 3, 4, 5}) + + // Delete middle element + ok := s.Delete(2) + assert.True(t, ok) + assert.Equal(t, 4, s.Len()) + + expected := []int{1, 2, 4, 5} + actual := s.Slice() + assert.Equal(t, expected, actual) + + // Delete out of bounds + ok = s.Delete(10) + assert.False(t, ok) + assert.Equal(t, 4, s.Len()) + + // Delete negative index + ok = s.Delete(-1) + assert.False(t, ok) + assert.Equal(t, 4, s.Len()) + }) + + t.Run("Get", func(t *testing.T) { + s := NewSliceFrom([]string{"a", "b", "c"}) + + val, ok := s.Get(1) + require.True(t, ok) + assert.Equal(t, "b", val) + + // Out of bounds + _, ok = s.Get(10) + assert.False(t, ok) + + // Negative index + _, ok = s.Get(-1) + assert.False(t, ok) + }) + + t.Run("Set", func(t *testing.T) { + s := NewSliceFrom([]string{"a", "b", "c"}) + + ok := s.Set(1, "modified") + assert.True(t, ok) + + val, ok := s.Get(1) + require.True(t, ok) + assert.Equal(t, "modified", val) + + // Out of bounds + ok = s.Set(10, "invalid") + assert.False(t, ok) + + // Negative index + ok = s.Set(-1, "invalid") + assert.False(t, ok) + }) + + t.Run("SetSlice", func(t *testing.T) { + s := NewSlice[int]() + s.Append(1) + s.Append(2) + + newItems := []int{10, 20, 30} + s.SetSlice(newItems) + + assert.Equal(t, 3, s.Len()) + assert.Equal(t, newItems, s.Slice()) + + // Verify it's a copy + newItems[0] = 999 + val, ok := s.Get(0) + require.True(t, ok) + assert.Equal(t, 10, val) + }) + + t.Run("Clear", func(t *testing.T) { + s := NewSliceFrom([]int{1, 2, 3}) + assert.Equal(t, 3, s.Len()) + + s.Clear() + assert.Equal(t, 0, s.Len()) + }) + + t.Run("Slice", func(t *testing.T) { + original := []int{1, 2, 3} + s := NewSliceFrom(original) + + copy := s.Slice() + assert.Equal(t, original, copy) + + // Verify it's a copy + copy[0] = 999 + val, ok := s.Get(0) + require.True(t, ok) + assert.Equal(t, 1, val) + }) + + t.Run("Seq", func(t *testing.T) { + s := NewSliceFrom([]int{1, 2, 3}) + + var result []int + for v := range s.Seq() { + result = append(result, v) + } + + assert.Equal(t, []int{1, 2, 3}, result) + }) + + t.Run("SeqWithIndex", func(t *testing.T) { + s := NewSliceFrom([]string{"a", "b", "c"}) + + var indices []int + var values []string + for i, v := range s.Seq2() { + indices = append(indices, i) + values = append(values, v) + } + + assert.Equal(t, []int{0, 1, 2}, indices) + assert.Equal(t, []string{"a", "b", "c"}, values) + }) + + t.Run("ConcurrentAccess", func(t *testing.T) { + s := NewSlice[int]() + const numGoroutines = 100 + const itemsPerGoroutine = 10 + + var wg sync.WaitGroup + + // Concurrent appends + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(start int) { + defer wg.Done() + for j := 0; j < itemsPerGoroutine; j++ { + s.Append(start*itemsPerGoroutine + j) + } + }(i) + } + + // Concurrent reads + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < itemsPerGoroutine; j++ { + s.Len() // Just read the length + } + }() + } + + wg.Wait() + + // Should have all items + assert.Equal(t, numGoroutines*itemsPerGoroutine, s.Len()) + }) +} diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index 0165b0f7194d029a6dee9113f82877820ce96c00..8a462eb1496bc6501f6f96d43307aec65eb40e97 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -5,9 +5,11 @@ import ( "encoding/json" "fmt" "log/slog" + "slices" "sync" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/permission" @@ -195,9 +197,8 @@ func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *confi } func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool { - var mu sync.Mutex var wg sync.WaitGroup - var result []tools.BaseTool + result := csync.NewSlice[tools.BaseTool]() for name, m := range cfg.MCP { if m.Disabled { slog.Debug("skipping disabled mcp", "name", name) @@ -218,9 +219,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con return } - mu.Lock() - result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...) - mu.Unlock() + result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...) case config.MCPHttp: c, err := client.NewStreamableHttpClient( m.URL, @@ -230,9 +229,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con slog.Error("error creating mcp client", "error", err) return } - mu.Lock() - result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...) - mu.Unlock() + result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...) case config.MCPSse: c, err := client.NewSSEMCPClient( m.URL, @@ -242,12 +239,10 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con slog.Error("error creating mcp client", "error", err) return } - mu.Lock() - result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...) - mu.Unlock() + result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...) } }(name, m) } wg.Wait() - return result + return slices.Collect(result.Seq()) } diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go index 4a2661bb9f663d9f93cf0371ac5d71dd513392c7..8c87482a71679f5bc682e6fdd8c1f5a03b89c184 100644 --- a/internal/llm/prompt/prompt.go +++ b/internal/llm/prompt/prompt.go @@ -7,6 +7,7 @@ import ( "sync" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" ) @@ -74,8 +75,7 @@ func processContextPaths(workDir string, paths []string) string { ) // Track processed files to avoid duplicates - processedFiles := make(map[string]bool) - var processedMutex sync.Mutex + processedFiles := csync.NewMap[string, bool]() for _, path := range paths { wg.Add(1) @@ -106,14 +106,8 @@ func processContextPaths(workDir string, paths []string) string { // Check if we've already processed this file (case-insensitive) lowerPath := strings.ToLower(path) - processedMutex.Lock() - alreadyProcessed := processedFiles[lowerPath] - if !alreadyProcessed { - processedFiles[lowerPath] = true - } - processedMutex.Unlock() - - if !alreadyProcessed { + if alreadyProcessed, _ := processedFiles.Get(lowerPath); !alreadyProcessed { + processedFiles.Set(lowerPath, true) if result := processFile(path); result != "" { resultCh <- result } @@ -126,14 +120,8 @@ func processContextPaths(workDir string, paths []string) string { // Check if we've already processed this file (case-insensitive) lowerPath := strings.ToLower(fullPath) - processedMutex.Lock() - alreadyProcessed := processedFiles[lowerPath] - if !alreadyProcessed { - processedFiles[lowerPath] = true - } - processedMutex.Unlock() - - if !alreadyProcessed { + if alreadyProcessed, _ := processedFiles.Get(lowerPath); !alreadyProcessed { + processedFiles.Set(lowerPath, true) result := processFile(fullPath) if result != "" { resultCh <- result diff --git a/internal/lsp/watcher/watcher.go b/internal/lsp/watcher/watcher.go index 080870937bee98be852979748dab456fa6a53b66..43cbd06dcc2136420385f40f97fb5e4e579c539b 100644 --- a/internal/lsp/watcher/watcher.go +++ b/internal/lsp/watcher/watcher.go @@ -12,6 +12,7 @@ import ( "github.com/bmatcuk/doublestar/v4" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/lsp/protocol" @@ -25,8 +26,7 @@ type WorkspaceWatcher struct { workspacePath string debounceTime time.Duration - debounceMap map[string]*time.Timer - debounceMu sync.Mutex + debounceMap *csync.Map[string, *time.Timer] // File watchers registered by the server registrations []protocol.FileSystemWatcher @@ -46,7 +46,7 @@ func NewWorkspaceWatcher(name string, client *lsp.Client) *WorkspaceWatcher { name: name, client: client, debounceTime: 300 * time.Millisecond, - debounceMap: make(map[string]*time.Timer), + debounceMap: csync.NewMap[string, *time.Timer](), registrations: []protocol.FileSystemWatcher{}, } } @@ -635,26 +635,21 @@ func (w *WorkspaceWatcher) matchesPattern(path string, pattern protocol.GlobPatt // debounceHandleFileEvent handles file events with debouncing to reduce notifications func (w *WorkspaceWatcher) debounceHandleFileEvent(ctx context.Context, uri string, changeType protocol.FileChangeType) { - w.debounceMu.Lock() - defer w.debounceMu.Unlock() - // Create a unique key based on URI and change type key := fmt.Sprintf("%s:%d", uri, changeType) // Cancel existing timer if any - if timer, exists := w.debounceMap[key]; exists { + if timer, exists := w.debounceMap.Get(key); exists { timer.Stop() } // Create new timer - w.debounceMap[key] = time.AfterFunc(w.debounceTime, func() { + w.debounceMap.Set(key, time.AfterFunc(w.debounceTime, func() { w.handleFileEvent(ctx, uri, changeType) // Cleanup timer after execution - w.debounceMu.Lock() - delete(w.debounceMap, key) - w.debounceMu.Unlock() - }) + w.debounceMap.Del(key) + })) } // handleFileEvent sends file change notifications