From b92d8800e3451a908bcb568bc97d9a3fb5a53232 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Fri, 25 Jul 2025 11:25:45 -0300 Subject: [PATCH] refactor: use csync.Map instead of sync.Map --- internal/config/config.go | 2 +- internal/csync/maps.go | 24 ++ internal/csync/maps_test.go | 225 ++++++++++++++++++ internal/llm/agent/agent.go | 54 ++--- internal/permission/permission.go | 20 +- .../tui/components/chat/sidebar/sidebar.go | 73 +++--- 6 files changed, 314 insertions(+), 84 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index 9a0da2a376abc88c5e584d7d39744da6f1890ce3..0f9fc99b5ce7677b0009933c447c0f7959825501 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -270,7 +270,7 @@ func (c *Config) WorkingDir() string { func (c *Config) EnabledProviders() []ProviderConfig { var enabled []ProviderConfig - for _, p := range c.Providers.Seq2() { + for p := range c.Providers.Seq() { if !p.Disable { enabled = append(enabled, p) } diff --git a/internal/csync/maps.go b/internal/csync/maps.go index 45e426630a4e50b45125d41dcca54d4e183b4f6f..ddc735b2624500899ca25670f7934326dfd9bdf3 100644 --- a/internal/csync/maps.go +++ b/internal/csync/maps.go @@ -56,6 +56,15 @@ func (m *Map[K, V]) Len() int { return len(m.inner) } +// Take gets an item and then deletes it. +func (m *Map[K, V]) Take(key K) (V, bool) { + m.mu.Lock() + defer m.mu.Unlock() + v, ok := m.inner[key] + delete(m.inner, key) + return v, ok +} + // Seq2 returns an iter.Seq2 that yields key-value pairs from the map. func (m *Map[K, V]) Seq2() iter.Seq2[K, V] { dst := make(map[K]V) @@ -71,6 +80,21 @@ func (m *Map[K, V]) Seq2() iter.Seq2[K, V] { } } +// Seq returns an iter.Seq that yields values from the map. +func (m *Map[K, V]) Seq() iter.Seq[V] { + dst := make(map[K]V) + m.mu.RLock() + maps.Copy(dst, m.inner) + m.mu.RUnlock() + return func(yield func(V) bool) { + for _, v := range dst { + if !yield(v) { + return + } + } + } +} + var ( _ json.Unmarshaler = &Map[string, any]{} _ json.Marshaler = &Map[string, any]{} diff --git a/internal/csync/maps_test.go b/internal/csync/maps_test.go index 73e6f1db245231e9fad82103366d96a326acc4f6..2882b30cb6ab8f71f969f03687766ca190467c0d 100644 --- a/internal/csync/maps_test.go +++ b/internal/csync/maps_test.go @@ -110,6 +110,72 @@ func TestMap_Len(t *testing.T) { assert.Equal(t, 0, m.Len()) } +func TestMap_Take(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 42) + m.Set("key2", 100) + + assert.Equal(t, 2, m.Len()) + + value, ok := m.Take("key1") + assert.True(t, ok) + assert.Equal(t, 42, value) + assert.Equal(t, 1, m.Len()) + + _, exists := m.Get("key1") + assert.False(t, exists) + + value, ok = m.Get("key2") + assert.True(t, ok) + assert.Equal(t, 100, value) +} + +func TestMap_Take_NonexistentKey(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 42) + + value, ok := m.Take("nonexistent") + assert.False(t, ok) + assert.Equal(t, 0, value) + assert.Equal(t, 1, m.Len()) + + value, ok = m.Get("key1") + assert.True(t, ok) + assert.Equal(t, 42, value) +} + +func TestMap_Take_EmptyMap(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + + value, ok := m.Take("key1") + assert.False(t, ok) + assert.Equal(t, 0, value) + assert.Equal(t, 0, m.Len()) +} + +func TestMap_Take_SameKeyTwice(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 42) + + value, ok := m.Take("key1") + assert.True(t, ok) + assert.Equal(t, 42, value) + assert.Equal(t, 0, m.Len()) + + value, ok = m.Take("key1") + assert.False(t, ok) + assert.Equal(t, 0, value) + assert.Equal(t, 0, m.Len()) +} + func TestMap_Seq2(t *testing.T) { t.Parallel() @@ -158,6 +224,57 @@ func TestMap_Seq2_EmptyMap(t *testing.T) { assert.Equal(t, 0, count) } +func TestMap_Seq(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 1) + m.Set("key2", 2) + m.Set("key3", 3) + + collected := make([]int, 0) + for v := range m.Seq() { + collected = append(collected, v) + } + + assert.Equal(t, 3, len(collected)) + assert.Contains(t, collected, 1) + assert.Contains(t, collected, 2) + assert.Contains(t, collected, 3) +} + +func TestMap_Seq_EarlyReturn(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 1) + m.Set("key2", 2) + m.Set("key3", 3) + + count := 0 + for range m.Seq() { + count++ + if count == 2 { + break + } + } + + assert.Equal(t, 2, count) +} + +func TestMap_Seq_EmptyMap(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + + count := 0 + for range m.Seq() { + count++ + } + + assert.Equal(t, 0, count) +} + func TestMap_MarshalJSON(t *testing.T) { t.Parallel() @@ -371,6 +488,82 @@ func TestMap_ConcurrentSeq2(t *testing.T) { wg.Wait() } +func TestMap_ConcurrentSeq(t *testing.T) { + t.Parallel() + + m := NewMap[int, int]() + for i := range 100 { + m.Set(i, i*2) + } + + var wg sync.WaitGroup + const numIterators = 10 + + wg.Add(numIterators) + for range numIterators { + go func() { + defer wg.Done() + count := 0 + values := make(map[int]bool) + for v := range m.Seq() { + values[v] = true + count++ + } + assert.Equal(t, 100, count) + for i := range 100 { + assert.True(t, values[i*2]) + } + }() + } + + wg.Wait() +} + +func TestMap_ConcurrentTake(t *testing.T) { + t.Parallel() + + m := NewMap[int, int]() + const numItems = 1000 + + for i := range numItems { + m.Set(i, i*2) + } + + var wg sync.WaitGroup + const numWorkers = 10 + taken := make([][]int, numWorkers) + + wg.Add(numWorkers) + for i := range numWorkers { + go func(workerID int) { + defer wg.Done() + taken[workerID] = make([]int, 0) + for j := workerID; j < numItems; j += numWorkers { + if value, ok := m.Take(j); ok { + taken[workerID] = append(taken[workerID], value) + } + } + }(i) + } + + wg.Wait() + + assert.Equal(t, 0, m.Len()) + + allTaken := make(map[int]bool) + for _, workerTaken := range taken { + for _, value := range workerTaken { + assert.False(t, allTaken[value], "Value %d was taken multiple times", value) + allTaken[value] = true + } + } + + assert.Equal(t, numItems, len(allTaken)) + for i := range numItems { + assert.True(t, allTaken[i*2], "Expected value %d to be taken", i*2) + } +} + func TestMap_TypeSafety(t *testing.T) { t.Parallel() @@ -431,6 +624,38 @@ func BenchmarkMap_Seq2(b *testing.B) { } } +func BenchmarkMap_Seq(b *testing.B) { + m := NewMap[int, int]() + for i := range 1000 { + m.Set(i, i*2) + } + + for b.Loop() { + for range m.Seq() { + } + } +} + +func BenchmarkMap_Take(b *testing.B) { + m := NewMap[int, int]() + for i := range 1000 { + m.Set(i, i*2) + } + + b.ResetTimer() + for i := 0; b.Loop(); i++ { + key := i % 1000 + m.Take(key) + if i%1000 == 999 { + b.StopTimer() + for j := range 1000 { + m.Set(j, j*2) + } + b.StartTimer() + } + } +} + func BenchmarkMap_ConcurrentReadWrite(b *testing.B) { m := NewMap[int, int]() for i := range 1000 { diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 2c3876ccac9ed028b1714ed96b0c6de0cce007c9..67bc861bbe5b106fa134abbf492cad855d2d362a 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -7,7 +7,6 @@ import ( "log/slog" "slices" "strings" - "sync" "time" "github.com/charmbracelet/catwalk/pkg/catwalk" @@ -78,7 +77,7 @@ type agent struct { summarizeProvider provider.Provider summarizeProviderID string - activeRequests sync.Map + activeRequests *csync.Map[string, context.CancelFunc] } var agentPromptMap = map[string]prompt.PromptID{ @@ -222,7 +221,7 @@ func NewAgent( titleProvider: titleProvider, summarizeProvider: summarizeProvider, summarizeProviderID: string(smallModelProviderCfg.ID), - activeRequests: sync.Map{}, + activeRequests: csync.NewMap[string, context.CancelFunc](), tools: csync.NewLazySlice(toolFn), }, nil } @@ -233,38 +232,30 @@ func (a *agent) Model() catwalk.Model { func (a *agent) Cancel(sessionID string) { // Cancel regular requests - if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists { - if cancel, ok := cancelFunc.(context.CancelFunc); ok { - slog.Info("Request cancellation initiated", "session_id", sessionID) - cancel() - } + if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil { + slog.Info("Request cancellation initiated", "session_id", sessionID) + cancel() } // Also check for summarize requests - if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists { - if cancel, ok := cancelFunc.(context.CancelFunc); ok { - slog.Info("Summarize cancellation initiated", "session_id", sessionID) - cancel() - } + if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil { + slog.Info("Summarize cancellation initiated", "session_id", sessionID) + cancel() } } func (a *agent) IsBusy() bool { - busy := false - a.activeRequests.Range(func(key, value any) bool { - if cancelFunc, ok := value.(context.CancelFunc); ok { - if cancelFunc != nil { - busy = true - return false - } + var busy bool + for cancelFunc := range a.activeRequests.Seq() { + if cancelFunc != nil { + busy = true } - return true - }) + } return busy } func (a *agent) IsSessionBusy(sessionID string) bool { - _, busy := a.activeRequests.Load(sessionID) + _, busy := a.activeRequests.Get(sessionID) return busy } @@ -335,7 +326,7 @@ func (a *agent) Run(ctx context.Context, sessionID string, content string, attac genCtx, cancel := context.WithCancel(ctx) - a.activeRequests.Store(sessionID, cancel) + a.activeRequests.Set(sessionID, cancel) go func() { slog.Debug("Request started", "sessionID", sessionID) defer log.RecoverPanic("agent.Run", func() { @@ -350,7 +341,7 @@ func (a *agent) Run(ctx context.Context, sessionID string, content string, attac slog.Error(result.Error.Error()) } slog.Debug("Request completed", "sessionID", sessionID) - a.activeRequests.Delete(sessionID) + a.activeRequests.Del(sessionID) cancel() a.Publish(pubsub.CreatedEvent, result) events <- result @@ -682,10 +673,10 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error { summarizeCtx, cancel := context.WithCancel(ctx) // Store the cancel function in activeRequests to allow cancellation - a.activeRequests.Store(sessionID+"-summarize", cancel) + a.activeRequests.Set(sessionID+"-summarize", cancel) go func() { - defer a.activeRequests.Delete(sessionID + "-summarize") + defer a.activeRequests.Del(sessionID + "-summarize") defer cancel() event := AgentEvent{ Type: AgentEventTypeSummarize, @@ -850,10 +841,9 @@ func (a *agent) CancelAll() { if !a.IsBusy() { return } - a.activeRequests.Range(func(key, value any) bool { - a.Cancel(key.(string)) // key is sessionID - return true - }) + for key := range a.activeRequests.Seq2() { + a.Cancel(key) // key is sessionID + } timeout := time.After(5 * time.Second) for a.IsBusy() { @@ -907,7 +897,7 @@ func (a *agent) UpdateModel() error { smallModelCfg := cfg.Models[config.SelectedModelTypeSmall] var smallModelProviderCfg config.ProviderConfig - for _, p := range cfg.Providers.Seq2() { + for p := range cfg.Providers.Seq() { if p.ID == smallModelCfg.Provider { smallModelProviderCfg = p break diff --git a/internal/permission/permission.go b/internal/permission/permission.go index cd149a49890b54086bd52e562eed0d44f00c407e..c5d001075a8cd01a91cccee0afcd44f89a5d4bcc 100644 --- a/internal/permission/permission.go +++ b/internal/permission/permission.go @@ -6,6 +6,7 @@ import ( "slices" "sync" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/pubsub" "github.com/google/uuid" ) @@ -46,7 +47,7 @@ type permissionService struct { workingDir string sessionPermissions []PermissionRequest sessionPermissionsMu sync.RWMutex - pendingRequests sync.Map + pendingRequests *csync.Map[string, chan bool] autoApproveSessions []string autoApproveSessionsMu sync.RWMutex skip bool @@ -54,9 +55,9 @@ type permissionService struct { } func (s *permissionService) GrantPersistent(permission PermissionRequest) { - respCh, ok := s.pendingRequests.Load(permission.ID) + respCh, ok := s.pendingRequests.Get(permission.ID) if ok { - respCh.(chan bool) <- true + respCh <- true } s.sessionPermissionsMu.Lock() @@ -65,16 +66,16 @@ func (s *permissionService) GrantPersistent(permission PermissionRequest) { } func (s *permissionService) Grant(permission PermissionRequest) { - respCh, ok := s.pendingRequests.Load(permission.ID) + respCh, ok := s.pendingRequests.Get(permission.ID) if ok { - respCh.(chan bool) <- true + respCh <- true } } func (s *permissionService) Deny(permission PermissionRequest) { - respCh, ok := s.pendingRequests.Load(permission.ID) + respCh, ok := s.pendingRequests.Get(permission.ID) if ok { - respCh.(chan bool) <- false + respCh <- false } } @@ -122,8 +123,8 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool { respCh := make(chan bool, 1) - s.pendingRequests.Store(permission.ID, respCh) - defer s.pendingRequests.Delete(permission.ID) + s.pendingRequests.Set(permission.ID, respCh) + defer s.pendingRequests.Del(permission.ID) s.Publish(pubsub.CreatedEvent, permission) @@ -144,5 +145,6 @@ func NewPermissionService(workingDir string, skip bool, allowedTools []string) S sessionPermissions: make([]PermissionRequest, 0), skip: skip, allowedTools: allowedTools, + pendingRequests: csync.NewMap[string, chan bool](), } } diff --git a/internal/tui/components/chat/sidebar/sidebar.go b/internal/tui/components/chat/sidebar/sidebar.go index 1aa239bdc15cec6898a4cba1e4dc7a867b5e4ce0..3ab2e9563420e5f3c6365bc71a66e35ab7b79f11 100644 --- a/internal/tui/components/chat/sidebar/sidebar.go +++ b/internal/tui/components/chat/sidebar/sidebar.go @@ -4,13 +4,14 @@ import ( "context" "fmt" "os" + "slices" "sort" "strings" - "sync" tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/diff" "github.com/charmbracelet/crush/internal/fsext" "github.com/charmbracelet/crush/internal/history" @@ -71,8 +72,7 @@ type sidebarCmp struct { lspClients map[string]*lsp.Client compactMode bool history history.Service - // Using a sync map here because we might receive file history events concurrently - files sync.Map + files *csync.Map[string, SessionFile] } func New(history history.Service, lspClients map[string]*lsp.Client, compact bool) Sidebar { @@ -80,6 +80,7 @@ func New(history history.Service, lspClients map[string]*lsp.Client, compact boo lspClients: lspClients, history: history, compactMode: compact, + files: csync.NewMap[string, SessionFile](), } } @@ -90,9 +91,9 @@ func (m *sidebarCmp) Init() tea.Cmd { func (m *sidebarCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { case SessionFilesMsg: - m.files = sync.Map{} + m.files = csync.NewMap[string, SessionFile]() for _, file := range msg.Files { - m.files.Store(file.FilePath, file) + m.files.Set(file.FilePath, file) } return m, nil @@ -178,31 +179,29 @@ func (m *sidebarCmp) handleFileHistoryEvent(event pubsub.Event[history.File]) te return func() tea.Msg { file := event.Payload found := false - m.files.Range(func(key, value any) bool { - existing := value.(SessionFile) - if existing.FilePath == file.Path { - if existing.History.latestVersion.Version < file.Version { - existing.History.latestVersion = file - } else if file.Version == 0 { - existing.History.initialVersion = file - } else { - // If the version is not greater than the latest, we ignore it - return true - } - before := existing.History.initialVersion.Content - after := existing.History.latestVersion.Content - path := existing.History.initialVersion.Path - cwd := config.Get().WorkingDir() - path = strings.TrimPrefix(path, cwd) - _, additions, deletions := diff.GenerateDiff(before, after, path) - existing.Additions = additions - existing.Deletions = deletions - m.files.Store(file.Path, existing) - found = true - return false + for existing := range m.files.Seq() { + if existing.FilePath != file.Path { + continue } - return true - }) + if existing.History.latestVersion.Version < file.Version { + existing.History.latestVersion = file + } else if file.Version == 0 { + existing.History.initialVersion = file + } else { + // If the version is not greater than the latest, we ignore it + continue + } + before := existing.History.initialVersion.Content + after := existing.History.latestVersion.Content + path := existing.History.initialVersion.Path + cwd := config.Get().WorkingDir() + path = strings.TrimPrefix(path, cwd) + _, additions, deletions := diff.GenerateDiff(before, after, path) + existing.Additions = additions + existing.Deletions = deletions + m.files.Set(file.Path, existing) + found = true + } if found { return nil } @@ -215,7 +214,7 @@ func (m *sidebarCmp) handleFileHistoryEvent(event pubsub.Event[history.File]) te Additions: 0, Deletions: 0, } - m.files.Store(file.Path, sf) + m.files.Set(file.Path, sf) return nil } } @@ -386,12 +385,7 @@ func (m *sidebarCmp) filesBlockCompact(maxWidth int) string { section := t.S().Subtle.Render("Modified Files") - files := make([]SessionFile, 0) - m.files.Range(func(key, value any) bool { - file := value.(SessionFile) - files = append(files, file) - return true - }) + files := slices.Collect(m.files.Seq()) if len(files) == 0 { content := lipgloss.JoinVertical( @@ -620,12 +614,7 @@ func (m *sidebarCmp) filesBlock() string { core.Section("Modified Files", m.getMaxWidth()), ) - files := make([]SessionFile, 0) - m.files.Range(func(key, value any) bool { - file := value.(SessionFile) - files = append(files, file) - return true // continue iterating - }) + files := slices.Collect(m.files.Seq()) if len(files) == 0 { return lipgloss.JoinVertical( lipgloss.Left,