Detailed changes
@@ -12,6 +12,7 @@ import (
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/db"
"github.com/charmbracelet/crush/internal/format"
"github.com/charmbracelet/crush/internal/history"
@@ -37,8 +38,7 @@ type App struct {
clientsMutex sync.RWMutex
- watcherCancelFuncs []context.CancelFunc
- cancelFuncsMutex sync.Mutex
+ watcherCancelFuncs *csync.Slice[context.CancelFunc]
lspWatcherWG sync.WaitGroup
config *config.Config
@@ -76,6 +76,8 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
config: cfg,
+ watcherCancelFuncs: csync.NewSlice[context.CancelFunc](),
+
events: make(chan tea.Msg, 100),
serviceEventsWG: &sync.WaitGroup{},
tuiWG: &sync.WaitGroup{},
@@ -305,11 +307,9 @@ func (app *App) Shutdown() {
app.CoderAgent.CancelAll()
}
- app.cancelFuncsMutex.Lock()
- for _, cancel := range app.watcherCancelFuncs {
+ for cancel := range app.watcherCancelFuncs.Seq() {
cancel()
}
- app.cancelFuncsMutex.Unlock()
// Wait for all LSP watchers to finish.
app.lspWatcherWG.Wait()
@@ -63,9 +63,7 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, comman
workspaceWatcher := watcher.NewWorkspaceWatcher(name, lspClient)
// Store the cancel function to be called during cleanup.
- app.cancelFuncsMutex.Lock()
- app.watcherCancelFuncs = append(app.watcherCancelFuncs, cancelFunc)
- app.cancelFuncsMutex.Unlock()
+ app.watcherCancelFuncs.Append(cancelFunc)
// Add to map with mutex protection before starting goroutine
app.clientsMutex.Lock()
@@ -82,12 +82,8 @@ func (m *Map[K, V]) Seq2() iter.Seq2[K, V] {
// Seq returns an iter.Seq that yields values from the map.
func (m *Map[K, V]) Seq() iter.Seq[V] {
- dst := make(map[K]V)
- m.mu.RLock()
- maps.Copy(dst, m.inner)
- m.mu.RUnlock()
return func(yield func(V) bool) {
- for _, v := range dst {
+ for _, v := range m.Seq2() {
if !yield(v) {
return
}
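With this refactor, `Map.Seq` delegates to `Seq2`, which takes a snapshot under the read lock (as the removed inline copy did), so the lock is never held while the caller's loop body runs. A minimal sketch of what that buys — illustrative only, since `csync` is an internal package:

```go
package main

import (
	"fmt"

	"github.com/charmbracelet/crush/internal/csync"
)

func main() {
	m := csync.NewMap[string, int]()
	m.Set("a", 1)
	m.Set("b", 2)

	// The loop ranges over a snapshot, so mutating the map from
	// inside the body neither deadlocks nor races.
	for v := range m.Seq() {
		m.Del("b")
		fmt.Println(v)
	}
}
```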
@@ -2,6 +2,7 @@ package csync
import (
"iter"
+ "slices"
"sync"
)
@@ -34,3 +35,129 @@ func (s *LazySlice[K]) Seq() iter.Seq[K] {
}
}
}
+
+// Slice is a generic slice that is safe for concurrent use by multiple goroutines.
+type Slice[T any] struct {
+ inner []T
+ mu sync.RWMutex
+}
+
+// NewSlice creates a new thread-safe slice.
+func NewSlice[T any]() *Slice[T] {
+ return &Slice[T]{
+ inner: make([]T, 0),
+ }
+}
+
+// NewSliceFrom creates a new thread-safe slice from an existing slice.
+func NewSliceFrom[T any](s []T) *Slice[T] {
+ inner := make([]T, len(s))
+ copy(inner, s)
+ return &Slice[T]{
+ inner: inner,
+ }
+}
+
+// Append adds one or more elements to the end of the slice.
+func (s *Slice[T]) Append(items ...T) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.inner = append(s.inner, items...)
+}
+
+// Prepend adds an element to the beginning of the slice.
+func (s *Slice[T]) Prepend(item T) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.inner = append([]T{item}, s.inner...)
+}
+
+// Delete removes the element at the specified index.
+func (s *Slice[T]) Delete(index int) bool {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if index < 0 || index >= len(s.inner) {
+ return false
+ }
+ s.inner = slices.Delete(s.inner, index, index+1)
+ return true
+}
+
+// Get returns the element at the specified index.
+func (s *Slice[T]) Get(index int) (T, bool) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ var zero T
+ if index < 0 || index >= len(s.inner) {
+ return zero, false
+ }
+ return s.inner[index], true
+}
+
+// Set updates the element at the specified index.
+func (s *Slice[T]) Set(index int, item T) bool {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if index < 0 || index >= len(s.inner) {
+ return false
+ }
+ s.inner[index] = item
+ return true
+}
+
+// Len returns the number of elements in the slice.
+func (s *Slice[T]) Len() int {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ return len(s.inner)
+}
+
+// Slice returns a copy of the underlying slice.
+func (s *Slice[T]) Slice() []T {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ result := make([]T, len(s.inner))
+ copy(result, s.inner)
+ return result
+}
+
+// SetSlice replaces the contents of the slice with a copy of items.
+func (s *Slice[T]) SetSlice(items []T) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.inner = make([]T, len(items))
+ copy(s.inner, items)
+}
+
+// Clear removes all elements from the slice.
+func (s *Slice[T]) Clear() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.inner = s.inner[:0]
+}
+
+// Seq returns an iterator that yields elements from a snapshot of the slice.
+func (s *Slice[T]) Seq() iter.Seq[T] {
+ return func(yield func(T) bool) {
+ for _, v := range s.Seq2() {
+ if !yield(v) {
+ return
+ }
+ }
+ }
+}
+
+// Seq2 returns an iterator that yields index-value pairs from a snapshot of the slice.
+func (s *Slice[T]) Seq2() iter.Seq2[int, T] {
+ s.mu.RLock()
+ items := make([]T, len(s.inner))
+ copy(items, s.inner)
+ s.mu.RUnlock()
+ return func(yield func(int, T) bool) {
+ for i, v := range items {
+ if !yield(i, v) {
+ return
+ }
+ }
+ }
+}
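`Seq` and `Seq2` follow the same snapshot discipline as `Map`: the elements are copied under `RLock` up front, so yielded indices refer to that snapshot rather than the live slice. A short usage sketch, again illustrative since the package is internal:

```go
package main

import (
	"fmt"

	"github.com/charmbracelet/crush/internal/csync"
)

func main() {
	s := csync.NewSliceFrom([]string{"a", "b", "c"})

	// Mutating the slice mid-iteration is safe; the loop keeps
	// walking the snapshot taken when Seq2 was called.
	for i, v := range s.Seq2() {
		if v == "b" {
			s.Delete(i)
		}
		fmt.Println(i, v)
	}

	fmt.Println(s.Slice()) // [a c]
}
```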
@@ -1,11 +1,13 @@
package csync
import (
+ "sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
func TestLazySlice_Seq(t *testing.T) {
@@ -85,3 +87,210 @@ func TestLazySlice_EarlyBreak(t *testing.T) {
assert.Equal(t, []string{"a", "b"}, result)
}
+
+func TestSlice(t *testing.T) {
+ t.Run("NewSlice", func(t *testing.T) {
+ s := NewSlice[int]()
+ assert.Equal(t, 0, s.Len())
+ })
+
+ t.Run("NewSliceFrom", func(t *testing.T) {
+ original := []int{1, 2, 3}
+ s := NewSliceFrom(original)
+ assert.Equal(t, 3, s.Len())
+
+ // Verify it's a copy, not a reference
+ original[0] = 999
+ val, ok := s.Get(0)
+ require.True(t, ok)
+ assert.Equal(t, 1, val)
+ })
+
+ t.Run("Append", func(t *testing.T) {
+ s := NewSlice[string]()
+ s.Append("hello")
+ s.Append("world")
+
+ assert.Equal(t, 2, s.Len())
+ val, ok := s.Get(0)
+ require.True(t, ok)
+ assert.Equal(t, "hello", val)
+
+ val, ok = s.Get(1)
+ require.True(t, ok)
+ assert.Equal(t, "world", val)
+ })
+
+ t.Run("Prepend", func(t *testing.T) {
+ s := NewSlice[string]()
+ s.Append("world")
+ s.Prepend("hello")
+
+ assert.Equal(t, 2, s.Len())
+ val, ok := s.Get(0)
+ require.True(t, ok)
+ assert.Equal(t, "hello", val)
+
+ val, ok = s.Get(1)
+ require.True(t, ok)
+ assert.Equal(t, "world", val)
+ })
+
+ t.Run("Delete", func(t *testing.T) {
+ s := NewSliceFrom([]int{1, 2, 3, 4, 5})
+
+ // Delete middle element
+ ok := s.Delete(2)
+ assert.True(t, ok)
+ assert.Equal(t, 4, s.Len())
+
+ expected := []int{1, 2, 4, 5}
+ actual := s.Slice()
+ assert.Equal(t, expected, actual)
+
+ // Delete out of bounds
+ ok = s.Delete(10)
+ assert.False(t, ok)
+ assert.Equal(t, 4, s.Len())
+
+ // Delete negative index
+ ok = s.Delete(-1)
+ assert.False(t, ok)
+ assert.Equal(t, 4, s.Len())
+ })
+
+ t.Run("Get", func(t *testing.T) {
+ s := NewSliceFrom([]string{"a", "b", "c"})
+
+ val, ok := s.Get(1)
+ require.True(t, ok)
+ assert.Equal(t, "b", val)
+
+ // Out of bounds
+ _, ok = s.Get(10)
+ assert.False(t, ok)
+
+ // Negative index
+ _, ok = s.Get(-1)
+ assert.False(t, ok)
+ })
+
+ t.Run("Set", func(t *testing.T) {
+ s := NewSliceFrom([]string{"a", "b", "c"})
+
+ ok := s.Set(1, "modified")
+ assert.True(t, ok)
+
+ val, ok := s.Get(1)
+ require.True(t, ok)
+ assert.Equal(t, "modified", val)
+
+ // Out of bounds
+ ok = s.Set(10, "invalid")
+ assert.False(t, ok)
+
+ // Negative index
+ ok = s.Set(-1, "invalid")
+ assert.False(t, ok)
+ })
+
+ t.Run("SetSlice", func(t *testing.T) {
+ s := NewSlice[int]()
+ s.Append(1)
+ s.Append(2)
+
+ newItems := []int{10, 20, 30}
+ s.SetSlice(newItems)
+
+ assert.Equal(t, 3, s.Len())
+ assert.Equal(t, newItems, s.Slice())
+
+ // Verify it's a copy
+ newItems[0] = 999
+ val, ok := s.Get(0)
+ require.True(t, ok)
+ assert.Equal(t, 10, val)
+ })
+
+ t.Run("Clear", func(t *testing.T) {
+ s := NewSliceFrom([]int{1, 2, 3})
+ assert.Equal(t, 3, s.Len())
+
+ s.Clear()
+ assert.Equal(t, 0, s.Len())
+ })
+
+ t.Run("Slice", func(t *testing.T) {
+ original := []int{1, 2, 3}
+ s := NewSliceFrom(original)
+
+ got := s.Slice()
+ assert.Equal(t, original, got)
+
+ // Verify it's a copy
+ got[0] = 999
+ val, ok := s.Get(0)
+ require.True(t, ok)
+ assert.Equal(t, 1, val)
+ })
+
+ t.Run("Seq", func(t *testing.T) {
+ s := NewSliceFrom([]int{1, 2, 3})
+
+ var result []int
+ for v := range s.Seq() {
+ result = append(result, v)
+ }
+
+ assert.Equal(t, []int{1, 2, 3}, result)
+ })
+
+ t.Run("SeqWithIndex", func(t *testing.T) {
+ s := NewSliceFrom([]string{"a", "b", "c"})
+
+ var indices []int
+ var values []string
+ for i, v := range s.Seq2() {
+ indices = append(indices, i)
+ values = append(values, v)
+ }
+
+ assert.Equal(t, []int{0, 1, 2}, indices)
+ assert.Equal(t, []string{"a", "b", "c"}, values)
+ })
+
+ t.Run("ConcurrentAccess", func(t *testing.T) {
+ s := NewSlice[int]()
+ const numGoroutines = 100
+ const itemsPerGoroutine = 10
+
+ var wg sync.WaitGroup
+
+ // Concurrent appends
+ for i := 0; i < numGoroutines; i++ {
+ wg.Add(1)
+ go func(start int) {
+ defer wg.Done()
+ for j := 0; j < itemsPerGoroutine; j++ {
+ s.Append(start*itemsPerGoroutine + j)
+ }
+ }(i)
+ }
+
+ // Concurrent reads
+ for i := 0; i < numGoroutines; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for j := 0; j < itemsPerGoroutine; j++ {
+ s.Len() // Just read the length
+ }
+ }()
+ }
+
+ wg.Wait()
+
+ // Should have all items
+ assert.Equal(t, numGoroutines*itemsPerGoroutine, s.Len())
+ })
+}
@@ -5,9 +5,11 @@ import (
"encoding/json"
"fmt"
"log/slog"
+ "slices"
"sync"
"github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/permission"
@@ -195,9 +197,8 @@ func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *confi
}
func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
- var mu sync.Mutex
var wg sync.WaitGroup
- var result []tools.BaseTool
+ result := csync.NewSlice[tools.BaseTool]()
for name, m := range cfg.MCP {
if m.Disabled {
slog.Debug("skipping disabled mcp", "name", name)
@@ -218,9 +219,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
return
}
- mu.Lock()
- result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
- mu.Unlock()
+ result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
case config.MCPHttp:
c, err := client.NewStreamableHttpClient(
m.URL,
@@ -230,9 +229,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
slog.Error("error creating mcp client", "error", err)
return
}
- mu.Lock()
- result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
- mu.Unlock()
+ result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
case config.MCPSse:
c, err := client.NewSSEMCPClient(
m.URL,
@@ -242,12 +239,10 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
slog.Error("error creating mcp client", "error", err)
return
}
- mu.Lock()
- result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
- mu.Unlock()
+ result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
}
}(name, m)
}
wg.Wait()
- return result
+ return slices.Collect(result.Seq())
}
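The hunk above is the fan-out pattern this change enables: each MCP goroutine appends through the slice's own lock, and the result is materialized once every goroutine has finished. A distilled sketch of the shape; `gatherTools` and `fetchers` are hypothetical stand-ins for the per-MCP closures:

```go
// gatherTools is a hypothetical distillation of the pattern above.
func gatherTools(fetchers []func() []tools.BaseTool) []tools.BaseTool {
	result := csync.NewSlice[tools.BaseTool]()
	var wg sync.WaitGroup
	for _, fetch := range fetchers {
		wg.Add(1)
		go func(fetch func() []tools.BaseTool) {
			defer wg.Done()
			// Append is internally locked, replacing the old
			// mu.Lock()/append/mu.Unlock() dance.
			result.Append(fetch()...)
		}(fetch)
	}
	wg.Wait()
	// slices.Collect drains the snapshot iterator into a plain slice.
	return slices.Collect(result.Seq())
}
```

As with the old mutex-guarded append, the resulting order is nondeterministic.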
@@ -7,6 +7,7 @@ import (
"sync"
"github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/env"
)
@@ -74,8 +75,7 @@ func processContextPaths(workDir string, paths []string) string {
)
// Track processed files to avoid duplicates
- processedFiles := make(map[string]bool)
- var processedMutex sync.Mutex
+ processedFiles := csync.NewMap[string, bool]()
for _, path := range paths {
wg.Add(1)
@@ -106,14 +106,8 @@ func processContextPaths(workDir string, paths []string) string {
// Check if we've already processed this file (case-insensitive)
lowerPath := strings.ToLower(path)
- processedMutex.Lock()
- alreadyProcessed := processedFiles[lowerPath]
- if !alreadyProcessed {
- processedFiles[lowerPath] = true
- }
- processedMutex.Unlock()
-
- if !alreadyProcessed {
+ if alreadyProcessed, _ := processedFiles.Get(lowerPath); !alreadyProcessed {
+ processedFiles.Set(lowerPath, true)
if result := processFile(path); result != "" {
resultCh <- result
}
@@ -126,14 +120,8 @@ func processContextPaths(workDir string, paths []string) string {
// Check if we've already processed this file (case-insensitive)
lowerPath := strings.ToLower(fullPath)
- processedMutex.Lock()
- alreadyProcessed := processedFiles[lowerPath]
- if !alreadyProcessed {
- processedFiles[lowerPath] = true
- }
- processedMutex.Unlock()
-
- if !alreadyProcessed {
+ if alreadyProcessed, _ := processedFiles.Get(lowerPath); !alreadyProcessed {
+ processedFiles.Set(lowerPath, true)
result := processFile(fullPath)
if result != "" {
resultCh <- result
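One behavioral nuance: the old mutex held the check and the write in a single critical section, whereas `Get` followed by `Set` is two separate lock acquisitions, so two goroutines racing on the same lowerPath could both observe false and process the file twice. Here the worst case is a duplicate context entry, which is likely acceptable; if exactly-once semantics mattered, a compare-and-set helper on the map would restore them. A hypothetical sketch — not part of csync.Map as shown in this diff — using the mu/inner fields visible in the Seq hunk above:

```go
// GetOrSet returns the existing value for key if present; otherwise it
// stores value. loaded reports whether the key was already there. The
// check and the write happen under one lock acquisition.
func (m *Map[K, V]) GetOrSet(key K, value V) (actual V, loaded bool) {
	m.mu.Lock()
	defer m.mu.Unlock()
	if existing, ok := m.inner[key]; ok {
		return existing, true
	}
	m.inner[key] = value
	return value, false
}
```

The call sites would then collapse to `if _, loaded := processedFiles.GetOrSet(lowerPath, true); !loaded { ... }`.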
@@ -12,6 +12,7 @@ import (
"github.com/bmatcuk/doublestar/v4"
"github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/lsp/protocol"
@@ -25,8 +26,7 @@ type WorkspaceWatcher struct {
workspacePath string
debounceTime time.Duration
- debounceMap map[string]*time.Timer
- debounceMu sync.Mutex
+ debounceMap *csync.Map[string, *time.Timer]
// File watchers registered by the server
registrations []protocol.FileSystemWatcher
@@ -46,7 +46,7 @@ func NewWorkspaceWatcher(name string, client *lsp.Client) *WorkspaceWatcher {
name: name,
client: client,
debounceTime: 300 * time.Millisecond,
- debounceMap: make(map[string]*time.Timer),
+ debounceMap: csync.NewMap[string, *time.Timer](),
registrations: []protocol.FileSystemWatcher{},
}
}
@@ -635,26 +635,21 @@ func (w *WorkspaceWatcher) matchesPattern(path string, pattern protocol.GlobPatt
// debounceHandleFileEvent handles file events with debouncing to reduce notifications
func (w *WorkspaceWatcher) debounceHandleFileEvent(ctx context.Context, uri string, changeType protocol.FileChangeType) {
- w.debounceMu.Lock()
- defer w.debounceMu.Unlock()
-
// Create a unique key based on URI and change type
key := fmt.Sprintf("%s:%d", uri, changeType)
// Cancel existing timer if any
- if timer, exists := w.debounceMap[key]; exists {
+ if timer, exists := w.debounceMap.Get(key); exists {
timer.Stop()
}
// Create new timer
- w.debounceMap[key] = time.AfterFunc(w.debounceTime, func() {
+ w.debounceMap.Set(key, time.AfterFunc(w.debounceTime, func() {
w.handleFileEvent(ctx, uri, changeType)
// Cleanup timer after execution
- w.debounceMu.Lock()
- delete(w.debounceMap, key)
- w.debounceMu.Unlock()
- })
+ w.debounceMap.Del(key)
+ }))
}
// handleFileEvent sends file change notifications
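The same nuance applies here: debounceMu used to cover the whole Get/Stop/Set sequence, so two events racing on one key can now each install a timer and both fire. For debounced file notifications an occasional duplicate is harmless, which is presumably why the simpler map suffices. For reference, the whole pattern reduces to a small per-key debouncer — a sketch using only the Get/Set/Del methods this diff relies on:

```go
// debouncer coalesces bursts of calls per key into one callback after
// a quiet period, mirroring the WorkspaceWatcher logic above.
type debouncer struct {
	timers *csync.Map[string, *time.Timer]
	delay  time.Duration
}

func (d *debouncer) trigger(key string, fn func()) {
	// Restart the quiet window for this key.
	if t, ok := d.timers.Get(key); ok {
		t.Stop()
	}
	d.timers.Set(key, time.AfterFunc(d.delay, func() {
		fn()
		d.timers.Del(key) // drop the timer once it has fired
	}))
}
```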