fix(lsp): use csync for lsp clients (#1073)

Carlos Alexandro Becker created

The map was being passed down everywhere, but the locking mechanism only
ever lived in `app.go`, which might cause concurrent access issues.

This changes it to a `*csync.Map`.

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

Change summary

internal/app/app.go                             | 16 +++---------
internal/app/lsp.go                             |  4 --
internal/csync/maps.go                          |  8 +++---
internal/csync/versionedmap.go                  | 24 +++++++++++++++---
internal/llm/agent/agent.go                     |  7 ++---
internal/llm/tools/diagnostics.go               | 20 +++++++-------
internal/llm/tools/edit.go                      |  5 ++-
internal/llm/tools/multiedit.go                 |  5 ++-
internal/llm/tools/view.go                      |  5 ++-
internal/llm/tools/write.go                     |  5 ++-
internal/tui/components/chat/header/header.go   |  7 +++--
internal/tui/components/chat/sidebar/sidebar.go |  4 +-
internal/tui/components/lsp/lsp.go              |  7 +++--
13 files changed, 64 insertions(+), 53 deletions(-)

Detailed changes

internal/app/app.go 🔗

@@ -6,12 +6,12 @@ import (
 	"errors"
 	"fmt"
 	"log/slog"
-	"maps"
 	"sync"
 	"time"
 
 	tea "github.com/charmbracelet/bubbletea/v2"
 	"github.com/charmbracelet/crush/internal/config"
+	"github.com/charmbracelet/crush/internal/csync"
 	"github.com/charmbracelet/crush/internal/db"
 	"github.com/charmbracelet/crush/internal/format"
 	"github.com/charmbracelet/crush/internal/history"
@@ -33,9 +33,7 @@ type App struct {
 
 	CoderAgent agent.Service
 
-	LSPClients map[string]*lsp.Client
-
-	clientsMutex sync.RWMutex
+	LSPClients *csync.Map[string, *lsp.Client]
 
 	config *config.Config
 
@@ -66,7 +64,7 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
 		Messages:    messages,
 		History:     files,
 		Permissions: permission.NewPermissionService(cfg.WorkingDir(), skipPermissionsRequests, allowedTools),
-		LSPClients:  make(map[string]*lsp.Client),
+		LSPClients:  csync.NewMap[string, *lsp.Client](),
 
 		globalCtx: ctx,
 
@@ -324,14 +322,8 @@ func (app *App) Shutdown() {
 		app.CoderAgent.CancelAll()
 	}
 
-	// Get all LSP clients.
-	app.clientsMutex.RLock()
-	clients := make(map[string]*lsp.Client, len(app.LSPClients))
-	maps.Copy(clients, app.LSPClients)
-	app.clientsMutex.RUnlock()
-
 	// Shutdown all LSP clients.
-	for name, client := range clients {
+	for name, client := range app.LSPClients.Seq2() {
 		shutdownCtx, cancel := context.WithTimeout(app.globalCtx, 5*time.Second)
 		if err := client.Close(shutdownCtx); err != nil {
 			slog.Error("Failed to shutdown LSP client", "name", name, "error", err)

internal/app/lsp.go 🔗

@@ -76,7 +76,5 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, config
 	slog.Info("LSP client initialized", "name", name)
 
 	// Add to map with mutex protection before starting goroutine
-	app.clientsMutex.Lock()
-	app.LSPClients[name] = lspClient
-	app.clientsMutex.Unlock()
+	app.LSPClients.Set(name, lspClient)
 }

internal/csync/maps.go 🔗

@@ -70,10 +70,10 @@ func (m *Map[K, V]) GetOrSet(key K, fn func() V) V {
 
 // Take gets an item and then deletes it.
 func (m *Map[K, V]) Take(key K) (V, bool) {
-	v, ok := m.Get(key)
-	if ok {
-		m.Del(key)
-	}
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	v, ok := m.inner[key]
+	delete(m.inner, key)
 	return v, ok
 }
 

internal/csync/versionedmap.go 🔗

@@ -1,34 +1,50 @@
 package csync
 
 import (
+	"iter"
 	"sync/atomic"
 )
 
 // NewVersionedMap creates a new versioned, thread-safe map.
 func NewVersionedMap[K comparable, V any]() *VersionedMap[K, V] {
 	return &VersionedMap[K, V]{
-		Map: NewMap[K, V](),
+		m: NewMap[K, V](),
 	}
 }
 
 // VersionedMap is a thread-safe map that keeps track of its version.
 type VersionedMap[K comparable, V any] struct {
-	*Map[K, V]
+	m *Map[K, V]
 	v atomic.Uint64
 }
 
+// Get gets the value for the specified key from the map.
+func (m *VersionedMap[K, V]) Get(key K) (V, bool) {
+	return m.m.Get(key)
+}
+
 // Set sets the value for the specified key in the map and increments the version.
 func (m *VersionedMap[K, V]) Set(key K, value V) {
-	m.Map.Set(key, value)
+	m.m.Set(key, value)
 	m.v.Add(1)
 }
 
 // Del deletes the specified key from the map and increments the version.
 func (m *VersionedMap[K, V]) Del(key K) {
-	m.Map.Del(key)
+	m.m.Del(key)
 	m.v.Add(1)
 }
 
+// Seq2 returns an iter.Seq2 that yields key-value pairs from the map.
+func (m *VersionedMap[K, V]) Seq2() iter.Seq2[K, V] {
+	return m.m.Seq2()
+}
+
+// Len returns the number of items in the map.
+func (m *VersionedMap[K, V]) Len() int {
+	return m.m.Len()
+}
+
 // Version returns the current version of the map.
 func (m *VersionedMap[K, V]) Version() uint64 {
 	return m.v.Load()

internal/llm/agent/agent.go 🔗

@@ -83,8 +83,7 @@ type agent struct {
 	summarizeProviderID string
 
 	activeRequests *csync.Map[string, context.CancelFunc]
-
-	promptQueue *csync.Map[string, []string]
+	promptQueue    *csync.Map[string, []string]
 }
 
 var agentPromptMap = map[string]prompt.PromptID{
@@ -100,7 +99,7 @@ func NewAgent(
 	sessions session.Service,
 	messages message.Service,
 	history history.Service,
-	lspClients map[string]*lsp.Client,
+	lspClients *csync.Map[string, *lsp.Client],
 ) (Service, error) {
 	cfg := config.Get()
 
@@ -204,7 +203,7 @@ func NewAgent(
 		withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
 			if agentCfg.ID == "coder" {
 				t = append(t, mcpTools...)
-				if len(lspClients) > 0 {
+				if lspClients.Len() > 0 {
 					t = append(t, tools.NewDiagnosticsTool(lspClients))
 				}
 			}

internal/llm/tools/diagnostics.go 🔗

@@ -9,6 +9,7 @@ import (
 	"strings"
 	"time"
 
+	"github.com/charmbracelet/crush/internal/csync"
 	"github.com/charmbracelet/crush/internal/lsp"
 	"github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
 )
@@ -18,7 +19,7 @@ type DiagnosticsParams struct {
 }
 
 type diagnosticsTool struct {
-	lspClients map[string]*lsp.Client
+	lspClients *csync.Map[string, *lsp.Client]
 }
 
 const (
@@ -46,7 +47,7 @@ TIPS:
 `
 )
 
-func NewDiagnosticsTool(lspClients map[string]*lsp.Client) BaseTool {
+func NewDiagnosticsTool(lspClients *csync.Map[string, *lsp.Client]) BaseTool {
 	return &diagnosticsTool{
 		lspClients,
 	}
@@ -76,20 +77,19 @@ func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
 		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 	}
 
-	lsps := b.lspClients
-	if len(lsps) == 0 {
+	if b.lspClients.Len() == 0 {
 		return NewTextErrorResponse("no LSP clients available"), nil
 	}
-	notifyLSPs(ctx, lsps, params.FilePath)
-	output := getDiagnostics(params.FilePath, lsps)
+	notifyLSPs(ctx, b.lspClients, params.FilePath)
+	output := getDiagnostics(params.FilePath, b.lspClients)
 	return NewTextResponse(output), nil
 }
 
-func notifyLSPs(ctx context.Context, lsps map[string]*lsp.Client, filepath string) {
+func notifyLSPs(ctx context.Context, lsps *csync.Map[string, *lsp.Client], filepath string) {
 	if filepath == "" {
 		return
 	}
-	for _, client := range lsps {
+	for client := range lsps.Seq() {
 		if !client.HandlesFile(filepath) {
 			continue
 		}
@@ -99,11 +99,11 @@ func notifyLSPs(ctx context.Context, lsps map[string]*lsp.Client, filepath strin
 	}
 }
 
-func getDiagnostics(filePath string, lsps map[string]*lsp.Client) string {
+func getDiagnostics(filePath string, lsps *csync.Map[string, *lsp.Client]) string {
 	fileDiagnostics := []string{}
 	projectDiagnostics := []string{}
 
-	for lspName, client := range lsps {
+	for lspName, client := range lsps.Seq2() {
 		for location, diags := range client.GetDiagnostics() {
 			path, err := location.Path()
 			if err != nil {

internal/llm/tools/edit.go 🔗

@@ -10,6 +10,7 @@ import (
 	"strings"
 	"time"
 
+	"github.com/charmbracelet/crush/internal/csync"
 	"github.com/charmbracelet/crush/internal/diff"
 	"github.com/charmbracelet/crush/internal/fsext"
 	"github.com/charmbracelet/crush/internal/history"
@@ -39,7 +40,7 @@ type EditResponseMetadata struct {
 }
 
 type editTool struct {
-	lspClients  map[string]*lsp.Client
+	lspClients  *csync.Map[string, *lsp.Client]
 	permissions permission.Service
 	files       history.Service
 	workingDir  string
@@ -104,7 +105,7 @@ WINDOWS NOTES:
 Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.`
 )
 
-func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service, workingDir string) BaseTool {
+func NewEditTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) BaseTool {
 	return &editTool{
 		lspClients:  lspClients,
 		permissions: permissions,

internal/llm/tools/multiedit.go 🔗

@@ -10,6 +10,7 @@ import (
 	"strings"
 	"time"
 
+	"github.com/charmbracelet/crush/internal/csync"
 	"github.com/charmbracelet/crush/internal/diff"
 	"github.com/charmbracelet/crush/internal/fsext"
 	"github.com/charmbracelet/crush/internal/history"
@@ -43,7 +44,7 @@ type MultiEditResponseMetadata struct {
 }
 
 type multiEditTool struct {
-	lspClients  map[string]*lsp.Client
+	lspClients  *csync.Map[string, *lsp.Client]
 	permissions permission.Service
 	files       history.Service
 	workingDir  string
@@ -95,7 +96,7 @@ If you want to create a new file, use:
 - Subsequent edits: normal edit operations on the created content`
 )
 
-func NewMultiEditTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service, workingDir string) BaseTool {
+func NewMultiEditTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) BaseTool {
 	return &multiEditTool{
 		lspClients:  lspClients,
 		permissions: permissions,

internal/llm/tools/view.go 🔗

@@ -11,6 +11,7 @@ import (
 	"strings"
 	"unicode/utf8"
 
+	"github.com/charmbracelet/crush/internal/csync"
 	"github.com/charmbracelet/crush/internal/lsp"
 	"github.com/charmbracelet/crush/internal/permission"
 )
@@ -28,7 +29,7 @@ type ViewPermissionsParams struct {
 }
 
 type viewTool struct {
-	lspClients  map[string]*lsp.Client
+	lspClients  *csync.Map[string, *lsp.Client]
 	workingDir  string
 	permissions permission.Service
 }
@@ -81,7 +82,7 @@ TIPS:
 - When viewing large files, use the offset parameter to read specific sections`
 )
 
-func NewViewTool(lspClients map[string]*lsp.Client, permissions permission.Service, workingDir string) BaseTool {
+func NewViewTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, workingDir string) BaseTool {
 	return &viewTool{
 		lspClients:  lspClients,
 		workingDir:  workingDir,

internal/llm/tools/write.go 🔗

@@ -10,6 +10,7 @@ import (
 	"strings"
 	"time"
 
+	"github.com/charmbracelet/crush/internal/csync"
 	"github.com/charmbracelet/crush/internal/diff"
 	"github.com/charmbracelet/crush/internal/fsext"
 	"github.com/charmbracelet/crush/internal/history"
@@ -30,7 +31,7 @@ type WritePermissionsParams struct {
 }
 
 type writeTool struct {
-	lspClients  map[string]*lsp.Client
+	lspClients  *csync.Map[string, *lsp.Client]
 	permissions permission.Service
 	files       history.Service
 	workingDir  string
@@ -78,7 +79,7 @@ TIPS:
 - Always include descriptive comments when making changes to existing code`
 )
 
-func NewWriteTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service, workingDir string) BaseTool {
+func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) BaseTool {
 	return &writeTool{
 		lspClients:  lspClients,
 		permissions: permissions,

internal/tui/components/chat/header/header.go 🔗

@@ -6,6 +6,7 @@ import (
 
 	tea "github.com/charmbracelet/bubbletea/v2"
 	"github.com/charmbracelet/crush/internal/config"
+	"github.com/charmbracelet/crush/internal/csync"
 	"github.com/charmbracelet/crush/internal/fsext"
 	"github.com/charmbracelet/crush/internal/lsp"
 	"github.com/charmbracelet/crush/internal/pubsub"
@@ -28,11 +29,11 @@ type Header interface {
 type header struct {
 	width       int
 	session     session.Session
-	lspClients  map[string]*lsp.Client
+	lspClients  *csync.Map[string, *lsp.Client]
 	detailsOpen bool
 }
 
-func New(lspClients map[string]*lsp.Client) Header {
+func New(lspClients *csync.Map[string, *lsp.Client]) Header {
 	return &header{
 		lspClients: lspClients,
 		width:      0,
@@ -104,7 +105,7 @@ func (h *header) details(availWidth int) string {
 	var parts []string
 
 	errorCount := 0
-	for _, l := range h.lspClients {
+	for l := range h.lspClients.Seq() {
 		for _, diagnostics := range l.GetDiagnostics() {
 			for _, diagnostic := range diagnostics {
 				if diagnostic.Severity == protocol.SeverityError {

internal/tui/components/chat/sidebar/sidebar.go 🔗

@@ -69,13 +69,13 @@ type sidebarCmp struct {
 	session       session.Session
 	logo          string
 	cwd           string
-	lspClients    map[string]*lsp.Client
+	lspClients    *csync.Map[string, *lsp.Client]
 	compactMode   bool
 	history       history.Service
 	files         *csync.Map[string, SessionFile]
 }
 
-func New(history history.Service, lspClients map[string]*lsp.Client, compact bool) Sidebar {
+func New(history history.Service, lspClients *csync.Map[string, *lsp.Client], compact bool) Sidebar {
 	return &sidebarCmp{
 		lspClients:  lspClients,
 		history:     history,

internal/tui/components/lsp/lsp.go 🔗

@@ -6,6 +6,7 @@ import (
 
 	"github.com/charmbracelet/crush/internal/app"
 	"github.com/charmbracelet/crush/internal/config"
+	"github.com/charmbracelet/crush/internal/csync"
 	"github.com/charmbracelet/crush/internal/lsp"
 	"github.com/charmbracelet/crush/internal/tui/components/core"
 	"github.com/charmbracelet/crush/internal/tui/styles"
@@ -22,7 +23,7 @@ type RenderOptions struct {
 }
 
 // RenderLSPList renders a list of LSP status items with the given options.
-func RenderLSPList(lspClients map[string]*lsp.Client, opts RenderOptions) []string {
+func RenderLSPList(lspClients *csync.Map[string, *lsp.Client], opts RenderOptions) []string {
 	t := styles.CurrentTheme()
 	lspList := []string{}
 
@@ -91,7 +92,7 @@ func RenderLSPList(lspClients map[string]*lsp.Client, opts RenderOptions) []stri
 				protocol.SeverityHint:        0,
 				protocol.SeverityInformation: 0,
 			}
-			if client, ok := lspClients[l.Name]; ok {
+			if client, ok := lspClients.Get(l.Name); ok {
 				for _, diagnostics := range client.GetDiagnostics() {
 					for _, diagnostic := range diagnostics {
 						if severity, ok := lspErrs[diagnostic.Severity]; ok {
@@ -134,7 +135,7 @@ func RenderLSPList(lspClients map[string]*lsp.Client, opts RenderOptions) []stri
 }
 
 // RenderLSPBlock renders a complete LSP block with optional truncation indicator.
-func RenderLSPBlock(lspClients map[string]*lsp.Client, opts RenderOptions, showTruncationIndicator bool) string {
+func RenderLSPBlock(lspClients *csync.Map[string, *lsp.Client], opts RenderOptions, showTruncationIndicator bool) string {
 	t := styles.CurrentTheme()
 	lspList := RenderLSPList(lspClients, opts)