From 61ea243e489bb519de123feab56986f15347ce4c Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Fri, 25 Jul 2025 12:07:58 -0300 Subject: [PATCH] feat: slices as well --- internal/app/app.go | 10 +- internal/app/lsp.go | 4 +- internal/csync/maps.go | 6 +- internal/csync/slices.go | 127 +++++++++++++++++++ internal/csync/slices_test.go | 209 ++++++++++++++++++++++++++++++++ internal/llm/agent/mcp-tools.go | 19 ++- internal/llm/prompt/prompt.go | 24 +--- internal/lsp/watcher/watcher.go | 19 ++- 8 files changed, 363 insertions(+), 55 deletions(-) diff --git a/internal/app/app.go b/internal/app/app.go index 50e117ea1ae272156dbd11baa1a5f157a74333f1..67bfd8f5d7f8cd6b8b54a354a426b7fa3b0b01bb 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -12,6 +12,7 @@ import ( tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/db" "github.com/charmbracelet/crush/internal/format" "github.com/charmbracelet/crush/internal/history" @@ -37,8 +38,7 @@ type App struct { clientsMutex sync.RWMutex - watcherCancelFuncs []context.CancelFunc - cancelFuncsMutex sync.Mutex + watcherCancelFuncs *csync.Slice[context.CancelFunc] lspWatcherWG sync.WaitGroup config *config.Config @@ -76,6 +76,8 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { config: cfg, + watcherCancelFuncs: csync.NewSlice[context.CancelFunc](), + events: make(chan tea.Msg, 100), serviceEventsWG: &sync.WaitGroup{}, tuiWG: &sync.WaitGroup{}, @@ -305,11 +307,9 @@ func (app *App) Shutdown() { app.CoderAgent.CancelAll() } - app.cancelFuncsMutex.Lock() - for _, cancel := range app.watcherCancelFuncs { + for cancel := range app.watcherCancelFuncs.Seq() { cancel() } - app.cancelFuncsMutex.Unlock() // Wait for all LSP watchers to finish. app.lspWatcherWG.Wait() diff --git a/internal/app/lsp.go b/internal/app/lsp.go index 946a373e5a7a69dc78fcbcc894629ecf3e9485ac..afe76a68460d262a3f57f214ad3c0c153ddbd807 100644 --- a/internal/app/lsp.go +++ b/internal/app/lsp.go @@ -63,9 +63,7 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, comman workspaceWatcher := watcher.NewWorkspaceWatcher(name, lspClient) // Store the cancel function to be called during cleanup. - app.cancelFuncsMutex.Lock() - app.watcherCancelFuncs = append(app.watcherCancelFuncs, cancelFunc) - app.cancelFuncsMutex.Unlock() + app.watcherCancelFuncs.Append(cancelFunc) // Add to map with mutex protection before starting goroutine app.clientsMutex.Lock() diff --git a/internal/csync/maps.go b/internal/csync/maps.go index ddc735b2624500899ca25670f7934326dfd9bdf3..108c8a4cbb6f855687d6117b1764b85e27279bc9 100644 --- a/internal/csync/maps.go +++ b/internal/csync/maps.go @@ -82,12 +82,8 @@ func (m *Map[K, V]) Seq2() iter.Seq2[K, V] { // Seq returns an iter.Seq that yields values from the map. func (m *Map[K, V]) Seq() iter.Seq[V] { - dst := make(map[K]V) - m.mu.RLock() - maps.Copy(dst, m.inner) - m.mu.RUnlock() return func(yield func(V) bool) { - for _, v := range dst { + for _, v := range m.Seq2() { if !yield(v) { return } diff --git a/internal/csync/slices.go b/internal/csync/slices.go index be723655079ccc6b07f55c3237b706a17bb14d40..3913a054c166c2bd29b3fafb7e6a0fa1998463a8 100644 --- a/internal/csync/slices.go +++ b/internal/csync/slices.go @@ -2,6 +2,7 @@ package csync import ( "iter" + "slices" "sync" ) @@ -34,3 +35,129 @@ func (s *LazySlice[K]) Seq() iter.Seq[K] { } } } + +// Slice is a thread-safe slice implementation that provides concurrent access. +type Slice[T any] struct { + inner []T + mu sync.RWMutex +} + +// NewSlice creates a new thread-safe slice. +func NewSlice[T any]() *Slice[T] { + return &Slice[T]{ + inner: make([]T, 0), + } +} + +// NewSliceFrom creates a new thread-safe slice from an existing slice. +func NewSliceFrom[T any](s []T) *Slice[T] { + inner := make([]T, len(s)) + copy(inner, s) + return &Slice[T]{ + inner: inner, + } +} + +// Append adds an element to the end of the slice. +func (s *Slice[T]) Append(items ...T) { + s.mu.Lock() + defer s.mu.Unlock() + s.inner = append(s.inner, items...) +} + +// Prepend adds an element to the beginning of the slice. +func (s *Slice[T]) Prepend(item T) { + s.mu.Lock() + defer s.mu.Unlock() + s.inner = append([]T{item}, s.inner...) +} + +// Delete removes the element at the specified index. +func (s *Slice[T]) Delete(index int) bool { + s.mu.Lock() + defer s.mu.Unlock() + if index < 0 || index >= len(s.inner) { + return false + } + s.inner = slices.Delete(s.inner, index, index+1) + return true +} + +// Get returns the element at the specified index. +func (s *Slice[T]) Get(index int) (T, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + var zero T + if index < 0 || index >= len(s.inner) { + return zero, false + } + return s.inner[index], true +} + +// Set updates the element at the specified index. +func (s *Slice[T]) Set(index int, item T) bool { + s.mu.Lock() + defer s.mu.Unlock() + if index < 0 || index >= len(s.inner) { + return false + } + s.inner[index] = item + return true +} + +// Len returns the number of elements in the slice. +func (s *Slice[T]) Len() int { + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.inner) +} + +// Slice returns a copy of the underlying slice. +func (s *Slice[T]) Slice() []T { + s.mu.RLock() + defer s.mu.RUnlock() + result := make([]T, len(s.inner)) + copy(result, s.inner) + return result +} + +// SetSlice replaces the entire slice with a new one. +func (s *Slice[T]) SetSlice(items []T) { + s.mu.Lock() + defer s.mu.Unlock() + s.inner = make([]T, len(items)) + copy(s.inner, items) +} + +// Clear removes all elements from the slice. +func (s *Slice[T]) Clear() { + s.mu.Lock() + defer s.mu.Unlock() + s.inner = s.inner[:0] +} + +// Seq returns an iterator that yields elements from the slice. +func (s *Slice[T]) Seq() iter.Seq[T] { + return func(yield func(T) bool) { + for _, v := range s.Seq2() { + if !yield(v) { + return + } + } + } +} + +// Seq2 returns an iterator that yields index-value pairs from the slice. +func (s *Slice[T]) Seq2() iter.Seq2[int, T] { + s.mu.RLock() + items := make([]T, len(s.inner)) + copy(items, s.inner) + s.mu.RUnlock() + return func(yield func(int, T) bool) { + for i, v := range items { + if !yield(i, v) { + return + } + } + } +} diff --git a/internal/csync/slices_test.go b/internal/csync/slices_test.go index 731cb96f55dd24cae74f55c0ef8e97ebd28aacaa..d86b02537fe0534ed60de0e314e1e1d9b7b866b6 100644 --- a/internal/csync/slices_test.go +++ b/internal/csync/slices_test.go @@ -1,11 +1,13 @@ package csync import ( + "sync" "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestLazySlice_Seq(t *testing.T) { @@ -85,3 +87,210 @@ func TestLazySlice_EarlyBreak(t *testing.T) { assert.Equal(t, []string{"a", "b"}, result) } + +func TestSlice(t *testing.T) { + t.Run("NewSlice", func(t *testing.T) { + s := NewSlice[int]() + assert.Equal(t, 0, s.Len()) + }) + + t.Run("NewSliceFrom", func(t *testing.T) { + original := []int{1, 2, 3} + s := NewSliceFrom(original) + assert.Equal(t, 3, s.Len()) + + // Verify it's a copy, not a reference + original[0] = 999 + val, ok := s.Get(0) + require.True(t, ok) + assert.Equal(t, 1, val) + }) + + t.Run("Append", func(t *testing.T) { + s := NewSlice[string]() + s.Append("hello") + s.Append("world") + + assert.Equal(t, 2, s.Len()) + val, ok := s.Get(0) + require.True(t, ok) + assert.Equal(t, "hello", val) + + val, ok = s.Get(1) + require.True(t, ok) + assert.Equal(t, "world", val) + }) + + t.Run("Prepend", func(t *testing.T) { + s := NewSlice[string]() + s.Append("world") + s.Prepend("hello") + + assert.Equal(t, 2, s.Len()) + val, ok := s.Get(0) + require.True(t, ok) + assert.Equal(t, "hello", val) + + val, ok = s.Get(1) + require.True(t, ok) + assert.Equal(t, "world", val) + }) + + t.Run("Delete", func(t *testing.T) { + s := NewSliceFrom([]int{1, 2, 3, 4, 5}) + + // Delete middle element + ok := s.Delete(2) + assert.True(t, ok) + assert.Equal(t, 4, s.Len()) + + expected := []int{1, 2, 4, 5} + actual := s.Slice() + assert.Equal(t, expected, actual) + + // Delete out of bounds + ok = s.Delete(10) + assert.False(t, ok) + assert.Equal(t, 4, s.Len()) + + // Delete negative index + ok = s.Delete(-1) + assert.False(t, ok) + assert.Equal(t, 4, s.Len()) + }) + + t.Run("Get", func(t *testing.T) { + s := NewSliceFrom([]string{"a", "b", "c"}) + + val, ok := s.Get(1) + require.True(t, ok) + assert.Equal(t, "b", val) + + // Out of bounds + _, ok = s.Get(10) + assert.False(t, ok) + + // Negative index + _, ok = s.Get(-1) + assert.False(t, ok) + }) + + t.Run("Set", func(t *testing.T) { + s := NewSliceFrom([]string{"a", "b", "c"}) + + ok := s.Set(1, "modified") + assert.True(t, ok) + + val, ok := s.Get(1) + require.True(t, ok) + assert.Equal(t, "modified", val) + + // Out of bounds + ok = s.Set(10, "invalid") + assert.False(t, ok) + + // Negative index + ok = s.Set(-1, "invalid") + assert.False(t, ok) + }) + + t.Run("SetSlice", func(t *testing.T) { + s := NewSlice[int]() + s.Append(1) + s.Append(2) + + newItems := []int{10, 20, 30} + s.SetSlice(newItems) + + assert.Equal(t, 3, s.Len()) + assert.Equal(t, newItems, s.Slice()) + + // Verify it's a copy + newItems[0] = 999 + val, ok := s.Get(0) + require.True(t, ok) + assert.Equal(t, 10, val) + }) + + t.Run("Clear", func(t *testing.T) { + s := NewSliceFrom([]int{1, 2, 3}) + assert.Equal(t, 3, s.Len()) + + s.Clear() + assert.Equal(t, 0, s.Len()) + }) + + t.Run("Slice", func(t *testing.T) { + original := []int{1, 2, 3} + s := NewSliceFrom(original) + + copy := s.Slice() + assert.Equal(t, original, copy) + + // Verify it's a copy + copy[0] = 999 + val, ok := s.Get(0) + require.True(t, ok) + assert.Equal(t, 1, val) + }) + + t.Run("Seq", func(t *testing.T) { + s := NewSliceFrom([]int{1, 2, 3}) + + var result []int + for v := range s.Seq() { + result = append(result, v) + } + + assert.Equal(t, []int{1, 2, 3}, result) + }) + + t.Run("SeqWithIndex", func(t *testing.T) { + s := NewSliceFrom([]string{"a", "b", "c"}) + + var indices []int + var values []string + for i, v := range s.Seq2() { + indices = append(indices, i) + values = append(values, v) + } + + assert.Equal(t, []int{0, 1, 2}, indices) + assert.Equal(t, []string{"a", "b", "c"}, values) + }) + + t.Run("ConcurrentAccess", func(t *testing.T) { + s := NewSlice[int]() + const numGoroutines = 100 + const itemsPerGoroutine = 10 + + var wg sync.WaitGroup + + // Concurrent appends + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(start int) { + defer wg.Done() + for j := 0; j < itemsPerGoroutine; j++ { + s.Append(start*itemsPerGoroutine + j) + } + }(i) + } + + // Concurrent reads + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < itemsPerGoroutine; j++ { + s.Len() // Just read the length + } + }() + } + + wg.Wait() + + // Should have all items + assert.Equal(t, numGoroutines*itemsPerGoroutine, s.Len()) + }) +} diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index 0165b0f7194d029a6dee9113f82877820ce96c00..8a462eb1496bc6501f6f96d43307aec65eb40e97 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -5,9 +5,11 @@ import ( "encoding/json" "fmt" "log/slog" + "slices" "sync" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/permission" @@ -195,9 +197,8 @@ func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *confi } func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool { - var mu sync.Mutex var wg sync.WaitGroup - var result []tools.BaseTool + result := csync.NewSlice[tools.BaseTool]() for name, m := range cfg.MCP { if m.Disabled { slog.Debug("skipping disabled mcp", "name", name) @@ -218,9 +219,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con return } - mu.Lock() - result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...) - mu.Unlock() + result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...) case config.MCPHttp: c, err := client.NewStreamableHttpClient( m.URL, @@ -230,9 +229,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con slog.Error("error creating mcp client", "error", err) return } - mu.Lock() - result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...) - mu.Unlock() + result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...) case config.MCPSse: c, err := client.NewSSEMCPClient( m.URL, @@ -242,12 +239,10 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con slog.Error("error creating mcp client", "error", err) return } - mu.Lock() - result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...) - mu.Unlock() + result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...) } }(name, m) } wg.Wait() - return result + return slices.Collect(result.Seq()) } diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go index 4a2661bb9f663d9f93cf0371ac5d71dd513392c7..8c87482a71679f5bc682e6fdd8c1f5a03b89c184 100644 --- a/internal/llm/prompt/prompt.go +++ b/internal/llm/prompt/prompt.go @@ -7,6 +7,7 @@ import ( "sync" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" ) @@ -74,8 +75,7 @@ func processContextPaths(workDir string, paths []string) string { ) // Track processed files to avoid duplicates - processedFiles := make(map[string]bool) - var processedMutex sync.Mutex + processedFiles := csync.NewMap[string, bool]() for _, path := range paths { wg.Add(1) @@ -106,14 +106,8 @@ func processContextPaths(workDir string, paths []string) string { // Check if we've already processed this file (case-insensitive) lowerPath := strings.ToLower(path) - processedMutex.Lock() - alreadyProcessed := processedFiles[lowerPath] - if !alreadyProcessed { - processedFiles[lowerPath] = true - } - processedMutex.Unlock() - - if !alreadyProcessed { + if alreadyProcessed, _ := processedFiles.Get(lowerPath); !alreadyProcessed { + processedFiles.Set(lowerPath, true) if result := processFile(path); result != "" { resultCh <- result } @@ -126,14 +120,8 @@ func processContextPaths(workDir string, paths []string) string { // Check if we've already processed this file (case-insensitive) lowerPath := strings.ToLower(fullPath) - processedMutex.Lock() - alreadyProcessed := processedFiles[lowerPath] - if !alreadyProcessed { - processedFiles[lowerPath] = true - } - processedMutex.Unlock() - - if !alreadyProcessed { + if alreadyProcessed, _ := processedFiles.Get(lowerPath); !alreadyProcessed { + processedFiles.Set(lowerPath, true) result := processFile(fullPath) if result != "" { resultCh <- result diff --git a/internal/lsp/watcher/watcher.go b/internal/lsp/watcher/watcher.go index 080870937bee98be852979748dab456fa6a53b66..43cbd06dcc2136420385f40f97fb5e4e579c539b 100644 --- a/internal/lsp/watcher/watcher.go +++ b/internal/lsp/watcher/watcher.go @@ -12,6 +12,7 @@ import ( "github.com/bmatcuk/doublestar/v4" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/lsp/protocol" @@ -25,8 +26,7 @@ type WorkspaceWatcher struct { workspacePath string debounceTime time.Duration - debounceMap map[string]*time.Timer - debounceMu sync.Mutex + debounceMap *csync.Map[string, *time.Timer] // File watchers registered by the server registrations []protocol.FileSystemWatcher @@ -46,7 +46,7 @@ func NewWorkspaceWatcher(name string, client *lsp.Client) *WorkspaceWatcher { name: name, client: client, debounceTime: 300 * time.Millisecond, - debounceMap: make(map[string]*time.Timer), + debounceMap: csync.NewMap[string, *time.Timer](), registrations: []protocol.FileSystemWatcher{}, } } @@ -635,26 +635,21 @@ func (w *WorkspaceWatcher) matchesPattern(path string, pattern protocol.GlobPatt // debounceHandleFileEvent handles file events with debouncing to reduce notifications func (w *WorkspaceWatcher) debounceHandleFileEvent(ctx context.Context, uri string, changeType protocol.FileChangeType) { - w.debounceMu.Lock() - defer w.debounceMu.Unlock() - // Create a unique key based on URI and change type key := fmt.Sprintf("%s:%d", uri, changeType) // Cancel existing timer if any - if timer, exists := w.debounceMap[key]; exists { + if timer, exists := w.debounceMap.Get(key); exists { timer.Stop() } // Create new timer - w.debounceMap[key] = time.AfterFunc(w.debounceTime, func() { + w.debounceMap.Set(key, time.AfterFunc(w.debounceTime, func() { w.handleFileEvent(ctx, uri, changeType) // Cleanup timer after execution - w.debounceMu.Lock() - delete(w.debounceMap, key) - w.debounceMu.Unlock() - }) + w.debounceMap.Del(key) + })) } // handleFileEvent sends file change notifications