Detailed changes
@@ -270,7 +270,7 @@ func (c *Config) WorkingDir() string {
func (c *Config) EnabledProviders() []ProviderConfig {
var enabled []ProviderConfig
- for _, p := range c.Providers.Seq2() {
+ for p := range c.Providers.Seq() {
if !p.Disable {
enabled = append(enabled, p)
}
@@ -56,6 +56,15 @@ func (m *Map[K, V]) Len() int {
return len(m.inner)
}
+// Take gets an item and then deletes it.
+func (m *Map[K, V]) Take(key K) (V, bool) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ v, ok := m.inner[key]
+ delete(m.inner, key)
+ return v, ok
+}
+
// Seq2 returns an iter.Seq2 that yields key-value pairs from the map.
func (m *Map[K, V]) Seq2() iter.Seq2[K, V] {
dst := make(map[K]V)
@@ -71,6 +80,21 @@ func (m *Map[K, V]) Seq2() iter.Seq2[K, V] {
}
}
+// Seq returns an iter.Seq that yields values from the map.
+func (m *Map[K, V]) Seq() iter.Seq[V] {
+ dst := make(map[K]V)
+ m.mu.RLock()
+ maps.Copy(dst, m.inner)
+ m.mu.RUnlock()
+ return func(yield func(V) bool) {
+ for _, v := range dst {
+ if !yield(v) {
+ return
+ }
+ }
+ }
+}
+
var (
_ json.Unmarshaler = &Map[string, any]{}
_ json.Marshaler = &Map[string, any]{}
@@ -110,6 +110,72 @@ func TestMap_Len(t *testing.T) {
assert.Equal(t, 0, m.Len())
}
+func TestMap_Take(t *testing.T) {
+ t.Parallel()
+
+ m := NewMap[string, int]()
+ m.Set("key1", 42)
+ m.Set("key2", 100)
+
+ assert.Equal(t, 2, m.Len())
+
+ value, ok := m.Take("key1")
+ assert.True(t, ok)
+ assert.Equal(t, 42, value)
+ assert.Equal(t, 1, m.Len())
+
+ _, exists := m.Get("key1")
+ assert.False(t, exists)
+
+ value, ok = m.Get("key2")
+ assert.True(t, ok)
+ assert.Equal(t, 100, value)
+}
+
+func TestMap_Take_NonexistentKey(t *testing.T) {
+ t.Parallel()
+
+ m := NewMap[string, int]()
+ m.Set("key1", 42)
+
+ value, ok := m.Take("nonexistent")
+ assert.False(t, ok)
+ assert.Equal(t, 0, value)
+ assert.Equal(t, 1, m.Len())
+
+ value, ok = m.Get("key1")
+ assert.True(t, ok)
+ assert.Equal(t, 42, value)
+}
+
+func TestMap_Take_EmptyMap(t *testing.T) {
+ t.Parallel()
+
+ m := NewMap[string, int]()
+
+ value, ok := m.Take("key1")
+ assert.False(t, ok)
+ assert.Equal(t, 0, value)
+ assert.Equal(t, 0, m.Len())
+}
+
+func TestMap_Take_SameKeyTwice(t *testing.T) {
+ t.Parallel()
+
+ m := NewMap[string, int]()
+ m.Set("key1", 42)
+
+ value, ok := m.Take("key1")
+ assert.True(t, ok)
+ assert.Equal(t, 42, value)
+ assert.Equal(t, 0, m.Len())
+
+ value, ok = m.Take("key1")
+ assert.False(t, ok)
+ assert.Equal(t, 0, value)
+ assert.Equal(t, 0, m.Len())
+}
+
func TestMap_Seq2(t *testing.T) {
t.Parallel()
@@ -158,6 +224,57 @@ func TestMap_Seq2_EmptyMap(t *testing.T) {
assert.Equal(t, 0, count)
}
+func TestMap_Seq(t *testing.T) {
+ t.Parallel()
+
+ m := NewMap[string, int]()
+ m.Set("key1", 1)
+ m.Set("key2", 2)
+ m.Set("key3", 3)
+
+ collected := make([]int, 0)
+ for v := range m.Seq() {
+ collected = append(collected, v)
+ }
+
+ assert.Equal(t, 3, len(collected))
+ assert.Contains(t, collected, 1)
+ assert.Contains(t, collected, 2)
+ assert.Contains(t, collected, 3)
+}
+
+func TestMap_Seq_EarlyReturn(t *testing.T) {
+ t.Parallel()
+
+ m := NewMap[string, int]()
+ m.Set("key1", 1)
+ m.Set("key2", 2)
+ m.Set("key3", 3)
+
+ count := 0
+ for range m.Seq() {
+ count++
+ if count == 2 {
+ break
+ }
+ }
+
+ assert.Equal(t, 2, count)
+}
+
+func TestMap_Seq_EmptyMap(t *testing.T) {
+ t.Parallel()
+
+ m := NewMap[string, int]()
+
+ count := 0
+ for range m.Seq() {
+ count++
+ }
+
+ assert.Equal(t, 0, count)
+}
+
func TestMap_MarshalJSON(t *testing.T) {
t.Parallel()
@@ -371,6 +488,82 @@ func TestMap_ConcurrentSeq2(t *testing.T) {
wg.Wait()
}
+func TestMap_ConcurrentSeq(t *testing.T) {
+ t.Parallel()
+
+ m := NewMap[int, int]()
+ for i := range 100 {
+ m.Set(i, i*2)
+ }
+
+ var wg sync.WaitGroup
+ const numIterators = 10
+
+ wg.Add(numIterators)
+ for range numIterators {
+ go func() {
+ defer wg.Done()
+ count := 0
+ values := make(map[int]bool)
+ for v := range m.Seq() {
+ values[v] = true
+ count++
+ }
+ assert.Equal(t, 100, count)
+ for i := range 100 {
+ assert.True(t, values[i*2])
+ }
+ }()
+ }
+
+ wg.Wait()
+}
+
+func TestMap_ConcurrentTake(t *testing.T) {
+ t.Parallel()
+
+ m := NewMap[int, int]()
+ const numItems = 1000
+
+ for i := range numItems {
+ m.Set(i, i*2)
+ }
+
+ var wg sync.WaitGroup
+ const numWorkers = 10
+ taken := make([][]int, numWorkers)
+
+ wg.Add(numWorkers)
+ for i := range numWorkers {
+ go func(workerID int) {
+ defer wg.Done()
+ taken[workerID] = make([]int, 0)
+ for j := workerID; j < numItems; j += numWorkers {
+ if value, ok := m.Take(j); ok {
+ taken[workerID] = append(taken[workerID], value)
+ }
+ }
+ }(i)
+ }
+
+ wg.Wait()
+
+ assert.Equal(t, 0, m.Len())
+
+ allTaken := make(map[int]bool)
+ for _, workerTaken := range taken {
+ for _, value := range workerTaken {
+ assert.False(t, allTaken[value], "Value %d was taken multiple times", value)
+ allTaken[value] = true
+ }
+ }
+
+ assert.Equal(t, numItems, len(allTaken))
+ for i := range numItems {
+ assert.True(t, allTaken[i*2], "Expected value %d to be taken", i*2)
+ }
+}
+
func TestMap_TypeSafety(t *testing.T) {
t.Parallel()
@@ -431,6 +624,38 @@ func BenchmarkMap_Seq2(b *testing.B) {
}
}
+func BenchmarkMap_Seq(b *testing.B) {
+ m := NewMap[int, int]()
+ for i := range 1000 {
+ m.Set(i, i*2)
+ }
+
+ for b.Loop() {
+ for range m.Seq() {
+ }
+ }
+}
+
+func BenchmarkMap_Take(b *testing.B) {
+ m := NewMap[int, int]()
+ for i := range 1000 {
+ m.Set(i, i*2)
+ }
+
+ b.ResetTimer()
+ for i := 0; b.Loop(); i++ {
+ key := i % 1000
+ m.Take(key)
+ if i%1000 == 999 {
+ b.StopTimer()
+ for j := range 1000 {
+ m.Set(j, j*2)
+ }
+ b.StartTimer()
+ }
+ }
+}
+
func BenchmarkMap_ConcurrentReadWrite(b *testing.B) {
m := NewMap[int, int]()
for i := range 1000 {
@@ -7,7 +7,6 @@ import (
"log/slog"
"slices"
"strings"
- "sync"
"time"
"github.com/charmbracelet/catwalk/pkg/catwalk"
@@ -78,7 +77,7 @@ type agent struct {
summarizeProvider provider.Provider
summarizeProviderID string
- activeRequests sync.Map
+ activeRequests *csync.Map[string, context.CancelFunc]
}
var agentPromptMap = map[string]prompt.PromptID{
@@ -222,7 +221,7 @@ func NewAgent(
titleProvider: titleProvider,
summarizeProvider: summarizeProvider,
summarizeProviderID: string(smallModelProviderCfg.ID),
- activeRequests: sync.Map{},
+ activeRequests: csync.NewMap[string, context.CancelFunc](),
tools: csync.NewLazySlice(toolFn),
}, nil
}
@@ -233,38 +232,30 @@ func (a *agent) Model() catwalk.Model {
func (a *agent) Cancel(sessionID string) {
// Cancel regular requests
- if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
- if cancel, ok := cancelFunc.(context.CancelFunc); ok {
- slog.Info("Request cancellation initiated", "session_id", sessionID)
- cancel()
- }
+ if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
+ slog.Info("Request cancellation initiated", "session_id", sessionID)
+ cancel()
}
// Also check for summarize requests
- if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
- if cancel, ok := cancelFunc.(context.CancelFunc); ok {
- slog.Info("Summarize cancellation initiated", "session_id", sessionID)
- cancel()
- }
+ if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
+ slog.Info("Summarize cancellation initiated", "session_id", sessionID)
+ cancel()
}
}
func (a *agent) IsBusy() bool {
- busy := false
- a.activeRequests.Range(func(key, value any) bool {
- if cancelFunc, ok := value.(context.CancelFunc); ok {
- if cancelFunc != nil {
- busy = true
- return false
- }
+ var busy bool
+ for cancelFunc := range a.activeRequests.Seq() {
+ if cancelFunc != nil {
+ busy = true
}
- return true
- })
+ }
return busy
}
func (a *agent) IsSessionBusy(sessionID string) bool {
- _, busy := a.activeRequests.Load(sessionID)
+ _, busy := a.activeRequests.Get(sessionID)
return busy
}
@@ -335,7 +326,7 @@ func (a *agent) Run(ctx context.Context, sessionID string, content string, attac
genCtx, cancel := context.WithCancel(ctx)
- a.activeRequests.Store(sessionID, cancel)
+ a.activeRequests.Set(sessionID, cancel)
go func() {
slog.Debug("Request started", "sessionID", sessionID)
defer log.RecoverPanic("agent.Run", func() {
@@ -350,7 +341,7 @@ func (a *agent) Run(ctx context.Context, sessionID string, content string, attac
slog.Error(result.Error.Error())
}
slog.Debug("Request completed", "sessionID", sessionID)
- a.activeRequests.Delete(sessionID)
+ a.activeRequests.Del(sessionID)
cancel()
a.Publish(pubsub.CreatedEvent, result)
events <- result
@@ -682,10 +673,10 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error {
summarizeCtx, cancel := context.WithCancel(ctx)
// Store the cancel function in activeRequests to allow cancellation
- a.activeRequests.Store(sessionID+"-summarize", cancel)
+ a.activeRequests.Set(sessionID+"-summarize", cancel)
go func() {
- defer a.activeRequests.Delete(sessionID + "-summarize")
+ defer a.activeRequests.Del(sessionID + "-summarize")
defer cancel()
event := AgentEvent{
Type: AgentEventTypeSummarize,
@@ -850,10 +841,9 @@ func (a *agent) CancelAll() {
if !a.IsBusy() {
return
}
- a.activeRequests.Range(func(key, value any) bool {
- a.Cancel(key.(string)) // key is sessionID
- return true
- })
+ for key := range a.activeRequests.Seq2() {
+ a.Cancel(key) // key is sessionID
+ }
timeout := time.After(5 * time.Second)
for a.IsBusy() {
@@ -907,7 +897,7 @@ func (a *agent) UpdateModel() error {
smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
var smallModelProviderCfg config.ProviderConfig
- for _, p := range cfg.Providers.Seq2() {
+ for p := range cfg.Providers.Seq() {
if p.ID == smallModelCfg.Provider {
smallModelProviderCfg = p
break
@@ -6,6 +6,7 @@ import (
"slices"
"sync"
+ "github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/pubsub"
"github.com/google/uuid"
)
@@ -46,7 +47,7 @@ type permissionService struct {
workingDir string
sessionPermissions []PermissionRequest
sessionPermissionsMu sync.RWMutex
- pendingRequests sync.Map
+ pendingRequests *csync.Map[string, chan bool]
autoApproveSessions []string
autoApproveSessionsMu sync.RWMutex
skip bool
@@ -54,9 +55,9 @@ type permissionService struct {
}
func (s *permissionService) GrantPersistent(permission PermissionRequest) {
- respCh, ok := s.pendingRequests.Load(permission.ID)
+ respCh, ok := s.pendingRequests.Get(permission.ID)
if ok {
- respCh.(chan bool) <- true
+ respCh <- true
}
s.sessionPermissionsMu.Lock()
@@ -65,16 +66,16 @@ func (s *permissionService) GrantPersistent(permission PermissionRequest) {
}
func (s *permissionService) Grant(permission PermissionRequest) {
- respCh, ok := s.pendingRequests.Load(permission.ID)
+ respCh, ok := s.pendingRequests.Get(permission.ID)
if ok {
- respCh.(chan bool) <- true
+ respCh <- true
}
}
func (s *permissionService) Deny(permission PermissionRequest) {
- respCh, ok := s.pendingRequests.Load(permission.ID)
+ respCh, ok := s.pendingRequests.Get(permission.ID)
if ok {
- respCh.(chan bool) <- false
+ respCh <- false
}
}
@@ -122,8 +123,8 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool {
respCh := make(chan bool, 1)
- s.pendingRequests.Store(permission.ID, respCh)
- defer s.pendingRequests.Delete(permission.ID)
+ s.pendingRequests.Set(permission.ID, respCh)
+ defer s.pendingRequests.Del(permission.ID)
s.Publish(pubsub.CreatedEvent, permission)
@@ -144,5 +145,6 @@ func NewPermissionService(workingDir string, skip bool, allowedTools []string) S
sessionPermissions: make([]PermissionRequest, 0),
skip: skip,
allowedTools: allowedTools,
+ pendingRequests: csync.NewMap[string, chan bool](),
}
}
@@ -4,13 +4,14 @@ import (
"context"
"fmt"
"os"
+ "slices"
"sort"
"strings"
- "sync"
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/diff"
"github.com/charmbracelet/crush/internal/fsext"
"github.com/charmbracelet/crush/internal/history"
@@ -71,8 +72,7 @@ type sidebarCmp struct {
lspClients map[string]*lsp.Client
compactMode bool
history history.Service
- // Using a sync map here because we might receive file history events concurrently
- files sync.Map
+ files *csync.Map[string, SessionFile]
}
func New(history history.Service, lspClients map[string]*lsp.Client, compact bool) Sidebar {
@@ -80,6 +80,7 @@ func New(history history.Service, lspClients map[string]*lsp.Client, compact boo
lspClients: lspClients,
history: history,
compactMode: compact,
+ files: csync.NewMap[string, SessionFile](),
}
}
@@ -90,9 +91,9 @@ func (m *sidebarCmp) Init() tea.Cmd {
func (m *sidebarCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case SessionFilesMsg:
- m.files = sync.Map{}
+ m.files = csync.NewMap[string, SessionFile]()
for _, file := range msg.Files {
- m.files.Store(file.FilePath, file)
+ m.files.Set(file.FilePath, file)
}
return m, nil
@@ -178,31 +179,29 @@ func (m *sidebarCmp) handleFileHistoryEvent(event pubsub.Event[history.File]) te
return func() tea.Msg {
file := event.Payload
found := false
- m.files.Range(func(key, value any) bool {
- existing := value.(SessionFile)
- if existing.FilePath == file.Path {
- if existing.History.latestVersion.Version < file.Version {
- existing.History.latestVersion = file
- } else if file.Version == 0 {
- existing.History.initialVersion = file
- } else {
- // If the version is not greater than the latest, we ignore it
- return true
- }
- before := existing.History.initialVersion.Content
- after := existing.History.latestVersion.Content
- path := existing.History.initialVersion.Path
- cwd := config.Get().WorkingDir()
- path = strings.TrimPrefix(path, cwd)
- _, additions, deletions := diff.GenerateDiff(before, after, path)
- existing.Additions = additions
- existing.Deletions = deletions
- m.files.Store(file.Path, existing)
- found = true
- return false
+ for existing := range m.files.Seq() {
+ if existing.FilePath != file.Path {
+ continue
}
- return true
- })
+ if existing.History.latestVersion.Version < file.Version {
+ existing.History.latestVersion = file
+ } else if file.Version == 0 {
+ existing.History.initialVersion = file
+ } else {
+ // If the version is not greater than the latest, we ignore it
+ continue
+ }
+ before := existing.History.initialVersion.Content
+ after := existing.History.latestVersion.Content
+ path := existing.History.initialVersion.Path
+ cwd := config.Get().WorkingDir()
+ path = strings.TrimPrefix(path, cwd)
+ _, additions, deletions := diff.GenerateDiff(before, after, path)
+ existing.Additions = additions
+ existing.Deletions = deletions
+ m.files.Set(file.Path, existing)
+ found = true
+ }
if found {
return nil
}
@@ -215,7 +214,7 @@ func (m *sidebarCmp) handleFileHistoryEvent(event pubsub.Event[history.File]) te
Additions: 0,
Deletions: 0,
}
- m.files.Store(file.Path, sf)
+ m.files.Set(file.Path, sf)
return nil
}
}
@@ -386,12 +385,7 @@ func (m *sidebarCmp) filesBlockCompact(maxWidth int) string {
section := t.S().Subtle.Render("Modified Files")
- files := make([]SessionFile, 0)
- m.files.Range(func(key, value any) bool {
- file := value.(SessionFile)
- files = append(files, file)
- return true
- })
+ files := slices.Collect(m.files.Seq())
if len(files) == 0 {
content := lipgloss.JoinVertical(
@@ -620,12 +614,7 @@ func (m *sidebarCmp) filesBlock() string {
core.Section("Modified Files", m.getMaxWidth()),
)
- files := make([]SessionFile, 0)
- m.files.Range(func(key, value any) bool {
- file := value.(SessionFile)
- files = append(files, file)
- return true // continue iterating
- })
+ files := slices.Collect(m.files.Seq())
if len(files) == 0 {
return lipgloss.JoinVertical(
lipgloss.Left,