1package lsp
2
3import (
4 "bufio"
5 "context"
6 "encoding/json"
7 "fmt"
8 "io"
9 "os"
10 "os/exec"
11 "strings"
12 "sync"
13 "sync/atomic"
14 "time"
15
16 "github.com/kujtimiihoxha/termai/internal/config"
17 "github.com/kujtimiihoxha/termai/internal/logging"
18 "github.com/kujtimiihoxha/termai/internal/lsp/protocol"
19)
20
// Client owns a language server child process and the state needed to talk
// to it over the process's stdio pipes. Create one with NewClient.
type Client struct {
	// Cmd is the language server process started by NewClient.
	Cmd *exec.Cmd
	// stdin is the write side of the server's standard input.
	stdin io.WriteCloser
	// stdout is a buffered reader over the server's standard output;
	// presumably consumed by handleMessages (defined elsewhere) — TODO confirm.
	stdout *bufio.Reader
	// stderr is the server's standard error; NewClient forwards it line by
	// line to this process's stderr.
	stderr io.ReadCloser

	// Request ID counter; presumably incremented per outgoing request by Call
	// (defined elsewhere) — TODO confirm.
	nextID atomic.Int32

	// Response handlers: one channel per pending request, presumably keyed by
	// request ID and guarded by handlersMu — verify against Call/handleMessages.
	handlers   map[int32]chan *Message
	handlersMu sync.RWMutex

	// Server request handlers keyed by JSON-RPC method name, guarded by
	// serverHandlersMu (see RegisterServerRequestHandler).
	serverRequestHandlers map[string]ServerRequestHandler
	serverHandlersMu      sync.RWMutex

	// Notification handlers keyed by method name, guarded by notificationMu
	// (see RegisterNotificationHandler).
	notificationHandlers map[string]NotificationHandler
	notificationMu       sync.RWMutex

	// Diagnostic cache keyed by document URI, guarded by diagnosticsMu. Read
	// by GetFileDiagnostics; presumably filled by HandleDiagnostics from
	// textDocument/publishDiagnostics notifications.
	diagnostics   map[protocol.DocumentUri][]protocol.Diagnostic
	diagnosticsMu sync.RWMutex

	// Files currently opened with the server (didOpen sent), keyed by their
	// "file://" URI string and guarded by openFilesMu.
	openFiles   map[string]*OpenFileInfo
	openFilesMu sync.RWMutex
}
50
51func NewClient(command string, args ...string) (*Client, error) {
52 cmd := exec.Command(command, args...)
53 // Copy env
54 cmd.Env = os.Environ()
55
56 stdin, err := cmd.StdinPipe()
57 if err != nil {
58 return nil, fmt.Errorf("failed to create stdin pipe: %w", err)
59 }
60
61 stdout, err := cmd.StdoutPipe()
62 if err != nil {
63 return nil, fmt.Errorf("failed to create stdout pipe: %w", err)
64 }
65
66 stderr, err := cmd.StderrPipe()
67 if err != nil {
68 return nil, fmt.Errorf("failed to create stderr pipe: %w", err)
69 }
70
71 client := &Client{
72 Cmd: cmd,
73 stdin: stdin,
74 stdout: bufio.NewReader(stdout),
75 stderr: stderr,
76 handlers: make(map[int32]chan *Message),
77 notificationHandlers: make(map[string]NotificationHandler),
78 serverRequestHandlers: make(map[string]ServerRequestHandler),
79 diagnostics: make(map[protocol.DocumentUri][]protocol.Diagnostic),
80 openFiles: make(map[string]*OpenFileInfo),
81 }
82
83 // Start the LSP server process
84 if err := cmd.Start(); err != nil {
85 return nil, fmt.Errorf("failed to start LSP server: %w", err)
86 }
87
88 // Handle stderr in a separate goroutine
89 go func() {
90 scanner := bufio.NewScanner(stderr)
91 for scanner.Scan() {
92 fmt.Fprintf(os.Stderr, "LSP Server: %s\n", scanner.Text())
93 }
94 if err := scanner.Err(); err != nil {
95 fmt.Fprintf(os.Stderr, "Error reading stderr: %v\n", err)
96 }
97 }()
98
99 // Start message handling loop
100 go client.handleMessages()
101
102 return client, nil
103}
104
105func (c *Client) RegisterNotificationHandler(method string, handler NotificationHandler) {
106 c.notificationMu.Lock()
107 defer c.notificationMu.Unlock()
108 c.notificationHandlers[method] = handler
109}
110
111func (c *Client) RegisterServerRequestHandler(method string, handler ServerRequestHandler) {
112 c.serverHandlersMu.Lock()
113 defer c.serverHandlersMu.Unlock()
114 c.serverRequestHandlers[method] = handler
115}
116
117func (c *Client) InitializeLSPClient(ctx context.Context, workspaceDir string) (*protocol.InitializeResult, error) {
118 initParams := &protocol.InitializeParams{
119 WorkspaceFoldersInitializeParams: protocol.WorkspaceFoldersInitializeParams{
120 WorkspaceFolders: []protocol.WorkspaceFolder{
121 {
122 URI: protocol.URI("file://" + workspaceDir),
123 Name: workspaceDir,
124 },
125 },
126 },
127
128 XInitializeParams: protocol.XInitializeParams{
129 ProcessID: int32(os.Getpid()),
130 ClientInfo: &protocol.ClientInfo{
131 Name: "mcp-language-server",
132 Version: "0.1.0",
133 },
134 RootPath: workspaceDir,
135 RootURI: protocol.DocumentUri("file://" + workspaceDir),
136 Capabilities: protocol.ClientCapabilities{
137 Workspace: protocol.WorkspaceClientCapabilities{
138 Configuration: true,
139 DidChangeConfiguration: protocol.DidChangeConfigurationClientCapabilities{
140 DynamicRegistration: true,
141 },
142 DidChangeWatchedFiles: protocol.DidChangeWatchedFilesClientCapabilities{
143 DynamicRegistration: true,
144 RelativePatternSupport: true,
145 },
146 },
147 TextDocument: protocol.TextDocumentClientCapabilities{
148 Synchronization: &protocol.TextDocumentSyncClientCapabilities{
149 DynamicRegistration: true,
150 DidSave: true,
151 },
152 Completion: protocol.CompletionClientCapabilities{
153 CompletionItem: protocol.ClientCompletionItemOptions{},
154 },
155 CodeLens: &protocol.CodeLensClientCapabilities{
156 DynamicRegistration: true,
157 },
158 DocumentSymbol: protocol.DocumentSymbolClientCapabilities{},
159 CodeAction: protocol.CodeActionClientCapabilities{
160 CodeActionLiteralSupport: protocol.ClientCodeActionLiteralOptions{
161 CodeActionKind: protocol.ClientCodeActionKindOptions{
162 ValueSet: []protocol.CodeActionKind{},
163 },
164 },
165 },
166 PublishDiagnostics: protocol.PublishDiagnosticsClientCapabilities{
167 VersionSupport: true,
168 },
169 SemanticTokens: protocol.SemanticTokensClientCapabilities{
170 Requests: protocol.ClientSemanticTokensRequestOptions{
171 Range: &protocol.Or_ClientSemanticTokensRequestOptions_range{},
172 Full: &protocol.Or_ClientSemanticTokensRequestOptions_full{},
173 },
174 TokenTypes: []string{},
175 TokenModifiers: []string{},
176 Formats: []protocol.TokenFormat{},
177 },
178 },
179 Window: protocol.WindowClientCapabilities{},
180 },
181 InitializationOptions: map[string]any{
182 "codelenses": map[string]bool{
183 "generate": true,
184 "regenerate_cgo": true,
185 "test": true,
186 "tidy": true,
187 "upgrade_dependency": true,
188 "vendor": true,
189 "vulncheck": false,
190 },
191 },
192 },
193 }
194
195 var result protocol.InitializeResult
196 if err := c.Call(ctx, "initialize", initParams, &result); err != nil {
197 return nil, fmt.Errorf("initialize failed: %w", err)
198 }
199
200 if err := c.Notify(ctx, "initialized", struct{}{}); err != nil {
201 return nil, fmt.Errorf("initialized notification failed: %w", err)
202 }
203
204 // Register handlers
205 c.RegisterServerRequestHandler("workspace/applyEdit", HandleApplyEdit)
206 c.RegisterServerRequestHandler("workspace/configuration", HandleWorkspaceConfiguration)
207 c.RegisterServerRequestHandler("client/registerCapability", HandleRegisterCapability)
208 c.RegisterNotificationHandler("window/showMessage", HandleServerMessage)
209 c.RegisterNotificationHandler("textDocument/publishDiagnostics",
210 func(params json.RawMessage) { HandleDiagnostics(c, params) })
211
212 // Notify the LSP server
213 err := c.Initialized(ctx, protocol.InitializedParams{})
214 if err != nil {
215 return nil, fmt.Errorf("initialization failed: %w", err)
216 }
217
218 // LSP sepecific Initialization
219 path := strings.ToLower(c.Cmd.Path)
220 switch {
221 case strings.Contains(path, "typescript-language-server"):
222 // err := initializeTypescriptLanguageServer(ctx, c, workspaceDir)
223 // if err != nil {
224 // return nil, err
225 // }
226 }
227
228 return &result, nil
229}
230
231func (c *Client) Close() error {
232 // Try to close all open files first
233 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
234 defer cancel()
235
236 // Attempt to close files but continue shutdown regardless
237 c.CloseAllFiles(ctx)
238
239 // Close stdin to signal the server
240 if err := c.stdin.Close(); err != nil {
241 return fmt.Errorf("failed to close stdin: %w", err)
242 }
243
244 // Use a channel to handle the Wait with timeout
245 done := make(chan error, 1)
246 go func() {
247 done <- c.Cmd.Wait()
248 }()
249
250 // Wait for process to exit with timeout
251 select {
252 case err := <-done:
253 return err
254 case <-time.After(2 * time.Second):
255 // If we timeout, try to kill the process
256 if err := c.Cmd.Process.Kill(); err != nil {
257 return fmt.Errorf("failed to kill process: %w", err)
258 }
259 return fmt.Errorf("process killed after timeout")
260 }
261}
262
// ServerState enumerates coarse lifecycle states of a language server.
// NOTE(review): these values are not referenced elsewhere in this file.
type ServerState int

// Known ServerState values, numbered explicitly.
const (
	StateStarting ServerState = 0
	StateReady    ServerState = 1
	StateError    ServerState = 2
)
270
271func (c *Client) WaitForServerReady(ctx context.Context) error {
272 // TODO: wait for specific messages or poll workspace/symbol
273 time.Sleep(time.Second * 1)
274 return nil
275}
276
// OpenFileInfo is the client-side record of a document that has been sent to
// the server with textDocument/didOpen.
type OpenFileInfo struct {
	// Version is the last document version sent to the server: OpenFile sets
	// it to 1 and NotifyChange increments it before each didChange.
	Version int32
	// URI is the document's "file://" URI.
	URI protocol.DocumentUri
}
281
282func (c *Client) OpenFile(ctx context.Context, filepath string) error {
283 uri := fmt.Sprintf("file://%s", filepath)
284
285 c.openFilesMu.Lock()
286 if _, exists := c.openFiles[uri]; exists {
287 c.openFilesMu.Unlock()
288 return nil // Already open
289 }
290 c.openFilesMu.Unlock()
291
292 // Skip files that do not exist or cannot be read
293 content, err := os.ReadFile(filepath)
294 if err != nil {
295 return fmt.Errorf("error reading file: %w", err)
296 }
297
298 params := protocol.DidOpenTextDocumentParams{
299 TextDocument: protocol.TextDocumentItem{
300 URI: protocol.DocumentUri(uri),
301 LanguageID: DetectLanguageID(uri),
302 Version: 1,
303 Text: string(content),
304 },
305 }
306
307 if err := c.Notify(ctx, "textDocument/didOpen", params); err != nil {
308 return err
309 }
310
311 c.openFilesMu.Lock()
312 c.openFiles[uri] = &OpenFileInfo{
313 Version: 1,
314 URI: protocol.DocumentUri(uri),
315 }
316 c.openFilesMu.Unlock()
317
318 return nil
319}
320
321func (c *Client) NotifyChange(ctx context.Context, filepath string) error {
322 uri := fmt.Sprintf("file://%s", filepath)
323
324 content, err := os.ReadFile(filepath)
325 if err != nil {
326 return fmt.Errorf("error reading file: %w", err)
327 }
328
329 c.openFilesMu.Lock()
330 fileInfo, isOpen := c.openFiles[uri]
331 if !isOpen {
332 c.openFilesMu.Unlock()
333 return fmt.Errorf("cannot notify change for unopened file: %s", filepath)
334 }
335
336 // Increment version
337 fileInfo.Version++
338 version := fileInfo.Version
339 c.openFilesMu.Unlock()
340
341 params := protocol.DidChangeTextDocumentParams{
342 TextDocument: protocol.VersionedTextDocumentIdentifier{
343 TextDocumentIdentifier: protocol.TextDocumentIdentifier{
344 URI: protocol.DocumentUri(uri),
345 },
346 Version: version,
347 },
348 ContentChanges: []protocol.TextDocumentContentChangeEvent{
349 {
350 Value: protocol.TextDocumentContentChangeWholeDocument{
351 Text: string(content),
352 },
353 },
354 },
355 }
356
357 return c.Notify(ctx, "textDocument/didChange", params)
358}
359
360func (c *Client) CloseFile(ctx context.Context, filepath string) error {
361 cnf := config.Get()
362 uri := fmt.Sprintf("file://%s", filepath)
363
364 c.openFilesMu.Lock()
365 if _, exists := c.openFiles[uri]; !exists {
366 c.openFilesMu.Unlock()
367 return nil // Already closed
368 }
369 c.openFilesMu.Unlock()
370
371 params := protocol.DidCloseTextDocumentParams{
372 TextDocument: protocol.TextDocumentIdentifier{
373 URI: protocol.DocumentUri(uri),
374 },
375 }
376
377 if cnf.Debug {
378 logging.Debug("Closing file", "file", filepath)
379 }
380 if err := c.Notify(ctx, "textDocument/didClose", params); err != nil {
381 return err
382 }
383
384 c.openFilesMu.Lock()
385 delete(c.openFiles, uri)
386 c.openFilesMu.Unlock()
387
388 return nil
389}
390
391func (c *Client) IsFileOpen(filepath string) bool {
392 uri := fmt.Sprintf("file://%s", filepath)
393 c.openFilesMu.RLock()
394 defer c.openFilesMu.RUnlock()
395 _, exists := c.openFiles[uri]
396 return exists
397}
398
399// CloseAllFiles closes all currently open files
400func (c *Client) CloseAllFiles(ctx context.Context) {
401 cnf := config.Get()
402 c.openFilesMu.Lock()
403 filesToClose := make([]string, 0, len(c.openFiles))
404
405 // First collect all URIs that need to be closed
406 for uri := range c.openFiles {
407 // Convert URI back to file path by trimming "file://" prefix
408 filePath := strings.TrimPrefix(uri, "file://")
409 filesToClose = append(filesToClose, filePath)
410 }
411 c.openFilesMu.Unlock()
412
413 // Then close them all
414 for _, filePath := range filesToClose {
415 err := c.CloseFile(ctx, filePath)
416 if err != nil && cnf.Debug {
417 logging.Warn("Error closing file", "file", filePath, "error", err)
418 }
419 }
420
421 if cnf.Debug {
422 logging.Debug("Closed all files", "files", filesToClose)
423 }
424}
425
426func (c *Client) GetFileDiagnostics(uri protocol.DocumentUri) []protocol.Diagnostic {
427 c.diagnosticsMu.RLock()
428 defer c.diagnosticsMu.RUnlock()
429
430 return c.diagnostics[uri]
431}
432
433func (c *Client) GetDiagnostics() map[protocol.DocumentUri][]protocol.Diagnostic {
434 return c.diagnostics
435}