lsp_restart.go

 1package tools
 2
import (
	"context"
	_ "embed"
	"fmt"
	"log/slog"
	"maps"
	"slices"
	"strings"
	"sync"

	"charm.land/fantasy"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/lsp"
)
16
17const LSPRestartToolName = "lsp_restart"
18
19//go:embed lsp_restart.md
20var lspRestartDescription []byte
21
// LSPRestartParams is the argument payload decoded for each lsp_restart tool
// invocation.
type LSPRestartParams struct {
	// Name picks one LSP client, by name, to restart. When left empty (and
	// therefore omitted from JSON), every LSP client is restarted instead.
	Name string `json:"name,omitempty"`
}
27
28func NewLSPRestartTool(lspClients *csync.Map[string, *lsp.Client]) fantasy.AgentTool {
29	return fantasy.NewAgentTool(
30		LSPRestartToolName,
31		string(lspRestartDescription),
32		func(ctx context.Context, params LSPRestartParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
33			if lspClients.Len() == 0 {
34				return fantasy.NewTextErrorResponse("no LSP clients available to restart"), nil
35			}
36
37			clientsToRestart := make(map[string]*lsp.Client)
38			if params.Name == "" {
39				maps.Insert(clientsToRestart, lspClients.Seq2())
40			} else {
41				client, exists := lspClients.Get(params.Name)
42				if !exists {
43					return fantasy.NewTextErrorResponse(fmt.Sprintf("LSP client '%s' not found", params.Name)), nil
44				}
45				clientsToRestart[params.Name] = client
46			}
47
48			var restarted []string
49			var failed []string
50			var mu sync.Mutex
51			var wg sync.WaitGroup
52			for name, client := range clientsToRestart {
53				wg.Go(func() {
54					if err := client.Restart(); err != nil {
55						slog.Error("Failed to restart LSP client", "name", name, "error", err)
56						mu.Lock()
57						failed = append(failed, name)
58						mu.Unlock()
59						return
60					}
61					mu.Lock()
62					restarted = append(restarted, name)
63					mu.Unlock()
64				})
65			}
66
67			wg.Wait()
68
69			var output string
70			if len(restarted) > 0 {
71				output = fmt.Sprintf("Successfully restarted %d LSP client(s): %s\n", len(restarted), strings.Join(restarted, ", "))
72			}
73			if len(failed) > 0 {
74				output += fmt.Sprintf("Failed to restart %d LSP client(s): %s\n", len(failed), strings.Join(failed, ", "))
75				return fantasy.NewTextErrorResponse(output), nil
76			}
77
78			return fantasy.NewTextResponse(output), nil
79		})
80}