diff --git a/internal/app/app.go b/internal/app/app.go index b93ba2cc9cbb16569c7c4739192dad1517581f57..2b3d81fb58acdeb2570a765c0a25ec53b65121da 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -6,12 +6,12 @@ import ( "errors" "fmt" "log/slog" - "maps" "sync" "time" tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/db" "github.com/charmbracelet/crush/internal/format" "github.com/charmbracelet/crush/internal/history" @@ -33,9 +33,7 @@ type App struct { CoderAgent agent.Service - LSPClients map[string]*lsp.Client - - clientsMutex sync.RWMutex + LSPClients *csync.Map[string, *lsp.Client] config *config.Config @@ -66,7 +64,7 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { Messages: messages, History: files, Permissions: permission.NewPermissionService(cfg.WorkingDir(), skipPermissionsRequests, allowedTools), - LSPClients: make(map[string]*lsp.Client), + LSPClients: csync.NewMap[string, *lsp.Client](), globalCtx: ctx, @@ -324,14 +322,8 @@ func (app *App) Shutdown() { app.CoderAgent.CancelAll() } - // Get all LSP clients. - app.clientsMutex.RLock() - clients := make(map[string]*lsp.Client, len(app.LSPClients)) - maps.Copy(clients, app.LSPClients) - app.clientsMutex.RUnlock() - // Shutdown all LSP clients. - for name, client := range clients { + for name, client := range app.LSPClients.Seq2() { shutdownCtx, cancel := context.WithTimeout(app.globalCtx, 5*time.Second) if err := client.Close(shutdownCtx); err != nil { slog.Error("Failed to shutdown LSP client", "name", name, "error", err) diff --git a/internal/app/lsp.go b/internal/app/lsp.go index 4a6932f275564139bd91e83467d6e5224083e5b5..057e9ce39363f3fd68c8c980ce22e3e8b0e78154 100644 --- a/internal/app/lsp.go +++ b/internal/app/lsp.go @@ -76,7 +76,5 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, config slog.Info("LSP client initialized", "name", name) // Add to map with mutex protection before starting goroutine - app.clientsMutex.Lock() - app.LSPClients[name] = lspClient - app.clientsMutex.Unlock() + app.LSPClients.Set(name, lspClient) } diff --git a/internal/csync/maps.go b/internal/csync/maps.go index 14e8b36c9c37ae2d93c9771e424579051f5181c8..b7a1f3109f6c15e7e5592cb538943a2d9e340819 100644 --- a/internal/csync/maps.go +++ b/internal/csync/maps.go @@ -70,10 +70,10 @@ func (m *Map[K, V]) GetOrSet(key K, fn func() V) V { // Take gets an item and then deletes it. func (m *Map[K, V]) Take(key K) (V, bool) { - v, ok := m.Get(key) - if ok { - m.Del(key) - } + m.mu.Lock() + defer m.mu.Unlock() + v, ok := m.inner[key] + delete(m.inner, key) return v, ok } diff --git a/internal/csync/versionedmap.go b/internal/csync/versionedmap.go index dfe2d6f5e893f73cc34cfd99fab984dcc273cd9a..f0f4e0249c3b0102976840bd82400e18c1703c47 100644 --- a/internal/csync/versionedmap.go +++ b/internal/csync/versionedmap.go @@ -1,34 +1,50 @@ package csync import ( + "iter" "sync/atomic" ) // NewVersionedMap creates a new versioned, thread-safe map. func NewVersionedMap[K comparable, V any]() *VersionedMap[K, V] { return &VersionedMap[K, V]{ - Map: NewMap[K, V](), + m: NewMap[K, V](), } } // VersionedMap is a thread-safe map that keeps track of its version. type VersionedMap[K comparable, V any] struct { - *Map[K, V] + m *Map[K, V] v atomic.Uint64 } +// Get gets the value for the specified key from the map. +func (m *VersionedMap[K, V]) Get(key K) (V, bool) { + return m.m.Get(key) +} + // Set sets the value for the specified key in the map and increments the version. func (m *VersionedMap[K, V]) Set(key K, value V) { - m.Map.Set(key, value) + m.m.Set(key, value) m.v.Add(1) } // Del deletes the specified key from the map and increments the version. func (m *VersionedMap[K, V]) Del(key K) { - m.Map.Del(key) + m.m.Del(key) m.v.Add(1) } +// Seq2 returns an iter.Seq2 that yields key-value pairs from the map. +func (m *VersionedMap[K, V]) Seq2() iter.Seq2[K, V] { + return m.m.Seq2() +} + +// Len returns the number of items in the map. +func (m *VersionedMap[K, V]) Len() int { + return m.m.Len() +} + // Version returns the current version of the map. func (m *VersionedMap[K, V]) Version() uint64 { return m.v.Load() diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 7c09a0be621485962df43e82484b0add4ea63513..864188113168948c2e59a221c62c6cdad99f75ce 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -83,8 +83,7 @@ type agent struct { summarizeProviderID string activeRequests *csync.Map[string, context.CancelFunc] - - promptQueue *csync.Map[string, []string] + promptQueue *csync.Map[string, []string] } var agentPromptMap = map[string]prompt.PromptID{ @@ -100,7 +99,7 @@ func NewAgent( sessions session.Service, messages message.Service, history history.Service, - lspClients map[string]*lsp.Client, + lspClients *csync.Map[string, *lsp.Client], ) (Service, error) { cfg := config.Get() @@ -204,7 +203,7 @@ func NewAgent( withCoderTools := func(t []tools.BaseTool) []tools.BaseTool { if agentCfg.ID == "coder" { t = append(t, mcpTools...) - if len(lspClients) > 0 { + if lspClients.Len() > 0 { t = append(t, tools.NewDiagnosticsTool(lspClients)) } } diff --git a/internal/llm/tools/diagnostics.go b/internal/llm/tools/diagnostics.go index 527e2f786895230db41784d0cb1b643b0f40f71c..17b93ab07cae29f2a274c0e289be02ac10827af2 100644 --- a/internal/llm/tools/diagnostics.go +++ b/internal/llm/tools/diagnostics.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/x/powernap/pkg/lsp/protocol" ) @@ -18,7 +19,7 @@ type DiagnosticsParams struct { } type diagnosticsTool struct { - lspClients map[string]*lsp.Client + lspClients *csync.Map[string, *lsp.Client] } const ( @@ -46,7 +47,7 @@ TIPS: ` ) -func NewDiagnosticsTool(lspClients map[string]*lsp.Client) BaseTool { +func NewDiagnosticsTool(lspClients *csync.Map[string, *lsp.Client]) BaseTool { return &diagnosticsTool{ lspClients, } @@ -76,20 +77,19 @@ func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil } - lsps := b.lspClients - if len(lsps) == 0 { + if b.lspClients.Len() == 0 { return NewTextErrorResponse("no LSP clients available"), nil } - notifyLSPs(ctx, lsps, params.FilePath) - output := getDiagnostics(params.FilePath, lsps) + notifyLSPs(ctx, b.lspClients, params.FilePath) + output := getDiagnostics(params.FilePath, b.lspClients) return NewTextResponse(output), nil } -func notifyLSPs(ctx context.Context, lsps map[string]*lsp.Client, filepath string) { +func notifyLSPs(ctx context.Context, lsps *csync.Map[string, *lsp.Client], filepath string) { if filepath == "" { return } - for _, client := range lsps { + for client := range lsps.Seq() { if !client.HandlesFile(filepath) { continue } @@ -99,11 +99,11 @@ func notifyLSPs(ctx context.Context, lsps map[string]*lsp.Client, filepath strin } } -func getDiagnostics(filePath string, lsps map[string]*lsp.Client) string { +func getDiagnostics(filePath string, lsps *csync.Map[string, *lsp.Client]) string { fileDiagnostics := []string{} projectDiagnostics := []string{} - for lspName, client := range lsps { + for lspName, client := range lsps.Seq2() { for location, diags := range client.GetDiagnostics() { path, err := location.Path() if err != nil { diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index 1afa03a427c36c7fe6ad448f4183f7ff4636ef85..d819ceb0af54b5682aecda703850a7b5a795e97c 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -10,6 +10,7 @@ import ( "strings" "time" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/diff" "github.com/charmbracelet/crush/internal/fsext" "github.com/charmbracelet/crush/internal/history" @@ -39,7 +40,7 @@ type EditResponseMetadata struct { } type editTool struct { - lspClients map[string]*lsp.Client + lspClients *csync.Map[string, *lsp.Client] permissions permission.Service files history.Service workingDir string @@ -104,7 +105,7 @@ WINDOWS NOTES: Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.` ) -func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service, workingDir string) BaseTool { +func NewEditTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) BaseTool { return &editTool{ lspClients: lspClients, permissions: permissions, diff --git a/internal/llm/tools/multiedit.go b/internal/llm/tools/multiedit.go index 2e08e973ba9eb46910fd39e98207b2f5e7bcca1f..4f99070b1a030e9c8f741f0671a6b2254899f276 100644 --- a/internal/llm/tools/multiedit.go +++ b/internal/llm/tools/multiedit.go @@ -10,6 +10,7 @@ import ( "strings" "time" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/diff" "github.com/charmbracelet/crush/internal/fsext" "github.com/charmbracelet/crush/internal/history" @@ -43,7 +44,7 @@ type MultiEditResponseMetadata struct { } type multiEditTool struct { - lspClients map[string]*lsp.Client + lspClients *csync.Map[string, *lsp.Client] permissions permission.Service files history.Service workingDir string @@ -95,7 +96,7 @@ If you want to create a new file, use: - Subsequent edits: normal edit operations on the created content` ) -func NewMultiEditTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service, workingDir string) BaseTool { +func NewMultiEditTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) BaseTool { return &multiEditTool{ lspClients: lspClients, permissions: permissions, diff --git a/internal/llm/tools/view.go b/internal/llm/tools/view.go index 5664edf0baf01f448f1b92ffed6c3e213ee608c2..7e48a91d380a693295a130b0b39e47c685aab142 100644 --- a/internal/llm/tools/view.go +++ b/internal/llm/tools/view.go @@ -11,6 +11,7 @@ import ( "strings" "unicode/utf8" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/permission" ) @@ -28,7 +29,7 @@ type ViewPermissionsParams struct { } type viewTool struct { - lspClients map[string]*lsp.Client + lspClients *csync.Map[string, *lsp.Client] workingDir string permissions permission.Service } @@ -81,7 +82,7 @@ TIPS: - When viewing large files, use the offset parameter to read specific sections` ) -func NewViewTool(lspClients map[string]*lsp.Client, permissions permission.Service, workingDir string) BaseTool { +func NewViewTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, workingDir string) BaseTool { return &viewTool{ lspClients: lspClients, workingDir: workingDir, diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index 6bbabba93d1dcf7064789bddd9fe4bc69e9f9182..cb256eb3d5c016797635796c8a8cf706810161af 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -10,6 +10,7 @@ import ( "strings" "time" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/diff" "github.com/charmbracelet/crush/internal/fsext" "github.com/charmbracelet/crush/internal/history" @@ -30,7 +31,7 @@ type WritePermissionsParams struct { } type writeTool struct { - lspClients map[string]*lsp.Client + lspClients *csync.Map[string, *lsp.Client] permissions permission.Service files history.Service workingDir string @@ -78,7 +79,7 @@ TIPS: - Always include descriptive comments when making changes to existing code` ) -func NewWriteTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service, workingDir string) BaseTool { +func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) BaseTool { return &writeTool{ lspClients: lspClients, permissions: permissions, diff --git a/internal/tui/components/chat/header/header.go b/internal/tui/components/chat/header/header.go index 5e5a68b5290187cea95b7cf8c0aada6cb46b4415..21861a4a2eda1340f6e01c0748f24cb713f15398 100644 --- a/internal/tui/components/chat/header/header.go +++ b/internal/tui/components/chat/header/header.go @@ -6,6 +6,7 @@ import ( tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/fsext" "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/pubsub" @@ -28,11 +29,11 @@ type Header interface { type header struct { width int session session.Session - lspClients map[string]*lsp.Client + lspClients *csync.Map[string, *lsp.Client] detailsOpen bool } -func New(lspClients map[string]*lsp.Client) Header { +func New(lspClients *csync.Map[string, *lsp.Client]) Header { return &header{ lspClients: lspClients, width: 0, @@ -104,7 +105,7 @@ func (h *header) details(availWidth int) string { var parts []string errorCount := 0 - for _, l := range h.lspClients { + for l := range h.lspClients.Seq() { for _, diagnostics := range l.GetDiagnostics() { for _, diagnostic := range diagnostics { if diagnostic.Severity == protocol.SeverityError { diff --git a/internal/tui/components/chat/sidebar/sidebar.go b/internal/tui/components/chat/sidebar/sidebar.go index 236c5d2e31c6e7f81482757ff750f572e23cc3fb..b50a78c7f8697e4f4db19649a01794cfe7a23bac 100644 --- a/internal/tui/components/chat/sidebar/sidebar.go +++ b/internal/tui/components/chat/sidebar/sidebar.go @@ -69,13 +69,13 @@ type sidebarCmp struct { session session.Session logo string cwd string - lspClients map[string]*lsp.Client + lspClients *csync.Map[string, *lsp.Client] compactMode bool history history.Service files *csync.Map[string, SessionFile] } -func New(history history.Service, lspClients map[string]*lsp.Client, compact bool) Sidebar { +func New(history history.Service, lspClients *csync.Map[string, *lsp.Client], compact bool) Sidebar { return &sidebarCmp{ lspClients: lspClients, history: history, diff --git a/internal/tui/components/lsp/lsp.go b/internal/tui/components/lsp/lsp.go index 53daeb0a65c43a1e4ae80ff6567c7daa32a800b8..f5f4061045901c91ecb8bce1f47eab3ac1f7abcf 100644 --- a/internal/tui/components/lsp/lsp.go +++ b/internal/tui/components/lsp/lsp.go @@ -6,6 +6,7 @@ import ( "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/styles" @@ -22,7 +23,7 @@ type RenderOptions struct { } // RenderLSPList renders a list of LSP status items with the given options. -func RenderLSPList(lspClients map[string]*lsp.Client, opts RenderOptions) []string { +func RenderLSPList(lspClients *csync.Map[string, *lsp.Client], opts RenderOptions) []string { t := styles.CurrentTheme() lspList := []string{} @@ -91,7 +92,7 @@ func RenderLSPList(lspClients map[string]*lsp.Client, opts RenderOptions) []stri protocol.SeverityHint: 0, protocol.SeverityInformation: 0, } - if client, ok := lspClients[l.Name]; ok { + if client, ok := lspClients.Get(l.Name); ok { for _, diagnostics := range client.GetDiagnostics() { for _, diagnostic := range diagnostics { if severity, ok := lspErrs[diagnostic.Severity]; ok { @@ -134,7 +135,7 @@ func RenderLSPList(lspClients map[string]*lsp.Client, opts RenderOptions) []stri } // RenderLSPBlock renders a complete LSP block with optional truncation indicator. -func RenderLSPBlock(lspClients map[string]*lsp.Client, opts RenderOptions, showTruncationIndicator bool) string { +func RenderLSPBlock(lspClients *csync.Map[string, *lsp.Client], opts RenderOptions, showTruncationIndicator bool) string { t := styles.CurrentTheme() lspList := RenderLSPList(lspClients, opts)