1package lsp
2
3import (
4 "bufio"
5 "context"
6 "encoding/json"
7 "fmt"
8 "io"
9 "os"
10 "os/exec"
11 "strings"
12 "sync"
13 "sync/atomic"
14 "time"
15
16 "github.com/kujtimiihoxha/termai/internal/config"
17 "github.com/kujtimiihoxha/termai/internal/logging"
18 "github.com/kujtimiihoxha/termai/internal/lsp/protocol"
19)
20
// logger is the package-level logger, obtained from logging.Get() when the
// package's variables are initialized.
var logger = logging.Get()
22
// Client holds the state for one language server child process: the process
// handle and its stdio pipes, the handler registries, a diagnostic cache and
// the set of documents currently open on the server. Each map field is declared
// next to the RWMutex that accompanies it.
type Client struct {
	// Cmd is the language server process; stdin, stdout and stderr are the
	// pipes connected to it (stdout is wrapped in a buffered reader).
	Cmd    *exec.Cmd
	stdin  io.WriteCloser
	stdout *bufio.Reader
	stderr io.ReadCloser

	// Request ID counter
	nextID atomic.Int32

	// Response handlers: one channel per int32 key (presumably the ID of an
	// in-flight request — confirm against Call).
	handlers   map[int32]chan *Message
	handlersMu sync.RWMutex

	// Server request handlers, keyed by method name
	// (see RegisterServerRequestHandler).
	serverRequestHandlers map[string]ServerRequestHandler
	serverHandlersMu      sync.RWMutex

	// Notification handlers, keyed by method name
	// (see RegisterNotificationHandler).
	notificationHandlers map[string]NotificationHandler
	notificationMu       sync.RWMutex

	// Diagnostic cache, keyed by document URI.
	diagnostics   map[protocol.DocumentUri][]protocol.Diagnostic
	diagnosticsMu sync.RWMutex

	// Files currently opened on the LSP server, keyed by their "file://" URI
	// string (see OpenFile / CloseFile).
	openFiles   map[string]*OpenFileInfo
	openFilesMu sync.RWMutex
}
52
53func NewClient(command string, args ...string) (*Client, error) {
54 cmd := exec.Command(command, args...)
55 // Copy env
56 cmd.Env = os.Environ()
57
58 stdin, err := cmd.StdinPipe()
59 if err != nil {
60 return nil, fmt.Errorf("failed to create stdin pipe: %w", err)
61 }
62
63 stdout, err := cmd.StdoutPipe()
64 if err != nil {
65 return nil, fmt.Errorf("failed to create stdout pipe: %w", err)
66 }
67
68 stderr, err := cmd.StderrPipe()
69 if err != nil {
70 return nil, fmt.Errorf("failed to create stderr pipe: %w", err)
71 }
72
73 client := &Client{
74 Cmd: cmd,
75 stdin: stdin,
76 stdout: bufio.NewReader(stdout),
77 stderr: stderr,
78 handlers: make(map[int32]chan *Message),
79 notificationHandlers: make(map[string]NotificationHandler),
80 serverRequestHandlers: make(map[string]ServerRequestHandler),
81 diagnostics: make(map[protocol.DocumentUri][]protocol.Diagnostic),
82 openFiles: make(map[string]*OpenFileInfo),
83 }
84
85 // Start the LSP server process
86 if err := cmd.Start(); err != nil {
87 return nil, fmt.Errorf("failed to start LSP server: %w", err)
88 }
89
90 // Handle stderr in a separate goroutine
91 go func() {
92 scanner := bufio.NewScanner(stderr)
93 for scanner.Scan() {
94 fmt.Fprintf(os.Stderr, "LSP Server: %s\n", scanner.Text())
95 }
96 if err := scanner.Err(); err != nil {
97 fmt.Fprintf(os.Stderr, "Error reading stderr: %v\n", err)
98 }
99 }()
100
101 // Start message handling loop
102 go client.handleMessages()
103
104 return client, nil
105}
106
107func (c *Client) RegisterNotificationHandler(method string, handler NotificationHandler) {
108 c.notificationMu.Lock()
109 defer c.notificationMu.Unlock()
110 c.notificationHandlers[method] = handler
111}
112
113func (c *Client) RegisterServerRequestHandler(method string, handler ServerRequestHandler) {
114 c.serverHandlersMu.Lock()
115 defer c.serverHandlersMu.Unlock()
116 c.serverRequestHandlers[method] = handler
117}
118
119func (c *Client) InitializeLSPClient(ctx context.Context, workspaceDir string) (*protocol.InitializeResult, error) {
120 initParams := &protocol.InitializeParams{
121 WorkspaceFoldersInitializeParams: protocol.WorkspaceFoldersInitializeParams{
122 WorkspaceFolders: []protocol.WorkspaceFolder{
123 {
124 URI: protocol.URI("file://" + workspaceDir),
125 Name: workspaceDir,
126 },
127 },
128 },
129
130 XInitializeParams: protocol.XInitializeParams{
131 ProcessID: int32(os.Getpid()),
132 ClientInfo: &protocol.ClientInfo{
133 Name: "mcp-language-server",
134 Version: "0.1.0",
135 },
136 RootPath: workspaceDir,
137 RootURI: protocol.DocumentUri("file://" + workspaceDir),
138 Capabilities: protocol.ClientCapabilities{
139 Workspace: protocol.WorkspaceClientCapabilities{
140 Configuration: true,
141 DidChangeConfiguration: protocol.DidChangeConfigurationClientCapabilities{
142 DynamicRegistration: true,
143 },
144 DidChangeWatchedFiles: protocol.DidChangeWatchedFilesClientCapabilities{
145 DynamicRegistration: true,
146 RelativePatternSupport: true,
147 },
148 },
149 TextDocument: protocol.TextDocumentClientCapabilities{
150 Synchronization: &protocol.TextDocumentSyncClientCapabilities{
151 DynamicRegistration: true,
152 DidSave: true,
153 },
154 Completion: protocol.CompletionClientCapabilities{
155 CompletionItem: protocol.ClientCompletionItemOptions{},
156 },
157 CodeLens: &protocol.CodeLensClientCapabilities{
158 DynamicRegistration: true,
159 },
160 DocumentSymbol: protocol.DocumentSymbolClientCapabilities{},
161 CodeAction: protocol.CodeActionClientCapabilities{
162 CodeActionLiteralSupport: protocol.ClientCodeActionLiteralOptions{
163 CodeActionKind: protocol.ClientCodeActionKindOptions{
164 ValueSet: []protocol.CodeActionKind{},
165 },
166 },
167 },
168 PublishDiagnostics: protocol.PublishDiagnosticsClientCapabilities{
169 VersionSupport: true,
170 },
171 SemanticTokens: protocol.SemanticTokensClientCapabilities{
172 Requests: protocol.ClientSemanticTokensRequestOptions{
173 Range: &protocol.Or_ClientSemanticTokensRequestOptions_range{},
174 Full: &protocol.Or_ClientSemanticTokensRequestOptions_full{},
175 },
176 TokenTypes: []string{},
177 TokenModifiers: []string{},
178 Formats: []protocol.TokenFormat{},
179 },
180 },
181 Window: protocol.WindowClientCapabilities{},
182 },
183 InitializationOptions: map[string]any{
184 "codelenses": map[string]bool{
185 "generate": true,
186 "regenerate_cgo": true,
187 "test": true,
188 "tidy": true,
189 "upgrade_dependency": true,
190 "vendor": true,
191 "vulncheck": false,
192 },
193 },
194 },
195 }
196
197 var result protocol.InitializeResult
198 if err := c.Call(ctx, "initialize", initParams, &result); err != nil {
199 return nil, fmt.Errorf("initialize failed: %w", err)
200 }
201
202 if err := c.Notify(ctx, "initialized", struct{}{}); err != nil {
203 return nil, fmt.Errorf("initialized notification failed: %w", err)
204 }
205
206 // Register handlers
207 c.RegisterServerRequestHandler("workspace/applyEdit", HandleApplyEdit)
208 c.RegisterServerRequestHandler("workspace/configuration", HandleWorkspaceConfiguration)
209 c.RegisterServerRequestHandler("client/registerCapability", HandleRegisterCapability)
210 c.RegisterNotificationHandler("window/showMessage", HandleServerMessage)
211 c.RegisterNotificationHandler("textDocument/publishDiagnostics",
212 func(params json.RawMessage) { HandleDiagnostics(c, params) })
213
214 // Notify the LSP server
215 err := c.Initialized(ctx, protocol.InitializedParams{})
216 if err != nil {
217 return nil, fmt.Errorf("initialization failed: %w", err)
218 }
219
220 // LSP sepecific Initialization
221 path := strings.ToLower(c.Cmd.Path)
222 switch {
223 case strings.Contains(path, "typescript-language-server"):
224 // err := initializeTypescriptLanguageServer(ctx, c, workspaceDir)
225 // if err != nil {
226 // return nil, err
227 // }
228 }
229
230 return &result, nil
231}
232
233func (c *Client) Close() error {
234 // Try to close all open files first
235 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
236 defer cancel()
237
238 // Attempt to close files but continue shutdown regardless
239 c.CloseAllFiles(ctx)
240
241 // Close stdin to signal the server
242 if err := c.stdin.Close(); err != nil {
243 return fmt.Errorf("failed to close stdin: %w", err)
244 }
245
246 // Use a channel to handle the Wait with timeout
247 done := make(chan error, 1)
248 go func() {
249 done <- c.Cmd.Wait()
250 }()
251
252 // Wait for process to exit with timeout
253 select {
254 case err := <-done:
255 return err
256 case <-time.After(2 * time.Second):
257 // If we timeout, try to kill the process
258 if err := c.Cmd.Process.Kill(); err != nil {
259 return fmt.Errorf("failed to kill process: %w", err)
260 }
261 return fmt.Errorf("process killed after timeout")
262 }
263}
264
// ServerState enumerates lifecycle states of a language server. Nothing in the
// visible part of this file reads or sets these values; the meanings below are
// taken from the names only — confirm against the code that uses them.
type ServerState int

const (
	// StateStarting is the zero value of ServerState.
	StateStarting ServerState = iota
	// StateReady presumably marks a server that can accept requests.
	StateReady
	// StateError presumably marks a server that failed.
	StateError
)
272
273func (c *Client) WaitForServerReady(ctx context.Context) error {
274 // TODO: wait for specific messages or poll workspace/symbol
275 time.Sleep(time.Second * 1)
276 return nil
277}
278
// OpenFileInfo is the record kept in Client.openFiles for a document that has
// been opened on the server. Version is the document version: OpenFile stores 1
// and NotifyChange increments it before each textDocument/didChange. URI is the
// "file://"-prefixed document URI built by OpenFile.
type OpenFileInfo struct {
	Version int32
	URI     protocol.DocumentUri
}
283
284func (c *Client) OpenFile(ctx context.Context, filepath string) error {
285 uri := fmt.Sprintf("file://%s", filepath)
286
287 c.openFilesMu.Lock()
288 if _, exists := c.openFiles[uri]; exists {
289 c.openFilesMu.Unlock()
290 return nil // Already open
291 }
292 c.openFilesMu.Unlock()
293
294 // Skip files that do not exist or cannot be read
295 content, err := os.ReadFile(filepath)
296 if err != nil {
297 return fmt.Errorf("error reading file: %w", err)
298 }
299
300 params := protocol.DidOpenTextDocumentParams{
301 TextDocument: protocol.TextDocumentItem{
302 URI: protocol.DocumentUri(uri),
303 LanguageID: DetectLanguageID(uri),
304 Version: 1,
305 Text: string(content),
306 },
307 }
308
309 if err := c.Notify(ctx, "textDocument/didOpen", params); err != nil {
310 return err
311 }
312
313 c.openFilesMu.Lock()
314 c.openFiles[uri] = &OpenFileInfo{
315 Version: 1,
316 URI: protocol.DocumentUri(uri),
317 }
318 c.openFilesMu.Unlock()
319
320 return nil
321}
322
323func (c *Client) NotifyChange(ctx context.Context, filepath string) error {
324 uri := fmt.Sprintf("file://%s", filepath)
325
326 content, err := os.ReadFile(filepath)
327 if err != nil {
328 return fmt.Errorf("error reading file: %w", err)
329 }
330
331 c.openFilesMu.Lock()
332 fileInfo, isOpen := c.openFiles[uri]
333 if !isOpen {
334 c.openFilesMu.Unlock()
335 return fmt.Errorf("cannot notify change for unopened file: %s", filepath)
336 }
337
338 // Increment version
339 fileInfo.Version++
340 version := fileInfo.Version
341 c.openFilesMu.Unlock()
342
343 params := protocol.DidChangeTextDocumentParams{
344 TextDocument: protocol.VersionedTextDocumentIdentifier{
345 TextDocumentIdentifier: protocol.TextDocumentIdentifier{
346 URI: protocol.DocumentUri(uri),
347 },
348 Version: version,
349 },
350 ContentChanges: []protocol.TextDocumentContentChangeEvent{
351 {
352 Value: protocol.TextDocumentContentChangeWholeDocument{
353 Text: string(content),
354 },
355 },
356 },
357 }
358
359 return c.Notify(ctx, "textDocument/didChange", params)
360}
361
362func (c *Client) CloseFile(ctx context.Context, filepath string) error {
363 cnf := config.Get()
364 uri := fmt.Sprintf("file://%s", filepath)
365
366 c.openFilesMu.Lock()
367 if _, exists := c.openFiles[uri]; !exists {
368 c.openFilesMu.Unlock()
369 return nil // Already closed
370 }
371 c.openFilesMu.Unlock()
372
373 params := protocol.DidCloseTextDocumentParams{
374 TextDocument: protocol.TextDocumentIdentifier{
375 URI: protocol.DocumentUri(uri),
376 },
377 }
378
379 if cnf.Debug {
380 logger.Debug("Closing file", "file", filepath)
381 }
382 if err := c.Notify(ctx, "textDocument/didClose", params); err != nil {
383 return err
384 }
385
386 c.openFilesMu.Lock()
387 delete(c.openFiles, uri)
388 c.openFilesMu.Unlock()
389
390 return nil
391}
392
393func (c *Client) IsFileOpen(filepath string) bool {
394 uri := fmt.Sprintf("file://%s", filepath)
395 c.openFilesMu.RLock()
396 defer c.openFilesMu.RUnlock()
397 _, exists := c.openFiles[uri]
398 return exists
399}
400
401// CloseAllFiles closes all currently open files
402func (c *Client) CloseAllFiles(ctx context.Context) {
403 cnf := config.Get()
404 c.openFilesMu.Lock()
405 filesToClose := make([]string, 0, len(c.openFiles))
406
407 // First collect all URIs that need to be closed
408 for uri := range c.openFiles {
409 // Convert URI back to file path by trimming "file://" prefix
410 filePath := strings.TrimPrefix(uri, "file://")
411 filesToClose = append(filesToClose, filePath)
412 }
413 c.openFilesMu.Unlock()
414
415 // Then close them all
416 for _, filePath := range filesToClose {
417 err := c.CloseFile(ctx, filePath)
418 if err != nil && cnf.Debug {
419 logger.Warn("Error closing file", "file", filePath, "error", err)
420 }
421 }
422
423 if cnf.Debug {
424 logger.Debug("Closed all files", "files", filesToClose)
425 }
426}
427
428func (c *Client) GetFileDiagnostics(uri protocol.DocumentUri) []protocol.Diagnostic {
429 c.diagnosticsMu.RLock()
430 defer c.diagnosticsMu.RUnlock()
431
432 return c.diagnostics[uri]
433}
434
435func (c *Client) GetDiagnostics() map[protocol.DocumentUri][]protocol.Diagnostic {
436 return c.diagnostics
437}