client.go

  1package lsp
  2
  3import (
  4	"bufio"
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8	"io"
  9	"os"
 10	"os/exec"
 11	"strings"
 12	"sync"
 13	"sync/atomic"
 14	"time"
 15
 16	"github.com/kujtimiihoxha/termai/internal/config"
 17	"github.com/kujtimiihoxha/termai/internal/logging"
 18	"github.com/kujtimiihoxha/termai/internal/lsp/protocol"
 19)
 20
 21type Client struct {
 22	Cmd    *exec.Cmd
 23	stdin  io.WriteCloser
 24	stdout *bufio.Reader
 25	stderr io.ReadCloser
 26
 27	// Request ID counter
 28	nextID atomic.Int32
 29
 30	// Response handlers
 31	handlers   map[int32]chan *Message
 32	handlersMu sync.RWMutex
 33
 34	// Server request handlers
 35	serverRequestHandlers map[string]ServerRequestHandler
 36	serverHandlersMu      sync.RWMutex
 37
 38	// Notification handlers
 39	notificationHandlers map[string]NotificationHandler
 40	notificationMu       sync.RWMutex
 41
 42	// Diagnostic cache
 43	diagnostics   map[protocol.DocumentUri][]protocol.Diagnostic
 44	diagnosticsMu sync.RWMutex
 45
 46	// Files are currently opened by the LSP
 47	openFiles   map[string]*OpenFileInfo
 48	openFilesMu sync.RWMutex
 49}
 50
 51func NewClient(ctx context.Context, command string, args ...string) (*Client, error) {
 52	cmd := exec.CommandContext(ctx, command, args...)
 53	// Copy env
 54	cmd.Env = os.Environ()
 55
 56	stdin, err := cmd.StdinPipe()
 57	if err != nil {
 58		return nil, fmt.Errorf("failed to create stdin pipe: %w", err)
 59	}
 60
 61	stdout, err := cmd.StdoutPipe()
 62	if err != nil {
 63		return nil, fmt.Errorf("failed to create stdout pipe: %w", err)
 64	}
 65
 66	stderr, err := cmd.StderrPipe()
 67	if err != nil {
 68		return nil, fmt.Errorf("failed to create stderr pipe: %w", err)
 69	}
 70
 71	client := &Client{
 72		Cmd:                   cmd,
 73		stdin:                 stdin,
 74		stdout:                bufio.NewReader(stdout),
 75		stderr:                stderr,
 76		handlers:              make(map[int32]chan *Message),
 77		notificationHandlers:  make(map[string]NotificationHandler),
 78		serverRequestHandlers: make(map[string]ServerRequestHandler),
 79		diagnostics:           make(map[protocol.DocumentUri][]protocol.Diagnostic),
 80		openFiles:             make(map[string]*OpenFileInfo),
 81	}
 82
 83	// Start the LSP server process
 84	if err := cmd.Start(); err != nil {
 85		return nil, fmt.Errorf("failed to start LSP server: %w", err)
 86	}
 87
 88	// Handle stderr in a separate goroutine
 89	go func() {
 90		scanner := bufio.NewScanner(stderr)
 91		for scanner.Scan() {
 92			fmt.Fprintf(os.Stderr, "LSP Server: %s\n", scanner.Text())
 93		}
 94		if err := scanner.Err(); err != nil {
 95			fmt.Fprintf(os.Stderr, "Error reading stderr: %v\n", err)
 96		}
 97	}()
 98
 99	// Start message handling loop
100	go func() {
101		defer logging.RecoverPanic("LSP-message-handler", func() {
102			logging.ErrorPersist("LSP message handler crashed, LSP functionality may be impaired")
103		})
104		client.handleMessages()
105	}()
106
107	return client, nil
108}
109
110func (c *Client) RegisterNotificationHandler(method string, handler NotificationHandler) {
111	c.notificationMu.Lock()
112	defer c.notificationMu.Unlock()
113	c.notificationHandlers[method] = handler
114}
115
116func (c *Client) RegisterServerRequestHandler(method string, handler ServerRequestHandler) {
117	c.serverHandlersMu.Lock()
118	defer c.serverHandlersMu.Unlock()
119	c.serverRequestHandlers[method] = handler
120}
121
122func (c *Client) InitializeLSPClient(ctx context.Context, workspaceDir string) (*protocol.InitializeResult, error) {
123	initParams := &protocol.InitializeParams{
124		WorkspaceFoldersInitializeParams: protocol.WorkspaceFoldersInitializeParams{
125			WorkspaceFolders: []protocol.WorkspaceFolder{
126				{
127					URI:  protocol.URI("file://" + workspaceDir),
128					Name: workspaceDir,
129				},
130			},
131		},
132
133		XInitializeParams: protocol.XInitializeParams{
134			ProcessID: int32(os.Getpid()),
135			ClientInfo: &protocol.ClientInfo{
136				Name:    "mcp-language-server",
137				Version: "0.1.0",
138			},
139			RootPath: workspaceDir,
140			RootURI:  protocol.DocumentUri("file://" + workspaceDir),
141			Capabilities: protocol.ClientCapabilities{
142				Workspace: protocol.WorkspaceClientCapabilities{
143					Configuration: true,
144					DidChangeConfiguration: protocol.DidChangeConfigurationClientCapabilities{
145						DynamicRegistration: true,
146					},
147					DidChangeWatchedFiles: protocol.DidChangeWatchedFilesClientCapabilities{
148						DynamicRegistration:    true,
149						RelativePatternSupport: true,
150					},
151				},
152				TextDocument: protocol.TextDocumentClientCapabilities{
153					Synchronization: &protocol.TextDocumentSyncClientCapabilities{
154						DynamicRegistration: true,
155						DidSave:             true,
156					},
157					Completion: protocol.CompletionClientCapabilities{
158						CompletionItem: protocol.ClientCompletionItemOptions{},
159					},
160					CodeLens: &protocol.CodeLensClientCapabilities{
161						DynamicRegistration: true,
162					},
163					DocumentSymbol: protocol.DocumentSymbolClientCapabilities{},
164					CodeAction: protocol.CodeActionClientCapabilities{
165						CodeActionLiteralSupport: protocol.ClientCodeActionLiteralOptions{
166							CodeActionKind: protocol.ClientCodeActionKindOptions{
167								ValueSet: []protocol.CodeActionKind{},
168							},
169						},
170					},
171					PublishDiagnostics: protocol.PublishDiagnosticsClientCapabilities{
172						VersionSupport: true,
173					},
174					SemanticTokens: protocol.SemanticTokensClientCapabilities{
175						Requests: protocol.ClientSemanticTokensRequestOptions{
176							Range: &protocol.Or_ClientSemanticTokensRequestOptions_range{},
177							Full:  &protocol.Or_ClientSemanticTokensRequestOptions_full{},
178						},
179						TokenTypes:     []string{},
180						TokenModifiers: []string{},
181						Formats:        []protocol.TokenFormat{},
182					},
183				},
184				Window: protocol.WindowClientCapabilities{},
185			},
186			InitializationOptions: map[string]any{
187				"codelenses": map[string]bool{
188					"generate":           true,
189					"regenerate_cgo":     true,
190					"test":               true,
191					"tidy":               true,
192					"upgrade_dependency": true,
193					"vendor":             true,
194					"vulncheck":          false,
195				},
196			},
197		},
198	}
199
200	var result protocol.InitializeResult
201	if err := c.Call(ctx, "initialize", initParams, &result); err != nil {
202		return nil, fmt.Errorf("initialize failed: %w", err)
203	}
204
205	if err := c.Notify(ctx, "initialized", struct{}{}); err != nil {
206		return nil, fmt.Errorf("initialized notification failed: %w", err)
207	}
208
209	// Register handlers
210	c.RegisterServerRequestHandler("workspace/applyEdit", HandleApplyEdit)
211	c.RegisterServerRequestHandler("workspace/configuration", HandleWorkspaceConfiguration)
212	c.RegisterServerRequestHandler("client/registerCapability", HandleRegisterCapability)
213	c.RegisterNotificationHandler("window/showMessage", HandleServerMessage)
214	c.RegisterNotificationHandler("textDocument/publishDiagnostics",
215		func(params json.RawMessage) { HandleDiagnostics(c, params) })
216
217	// Notify the LSP server
218	err := c.Initialized(ctx, protocol.InitializedParams{})
219	if err != nil {
220		return nil, fmt.Errorf("initialization failed: %w", err)
221	}
222
223	// LSP sepecific Initialization
224	path := strings.ToLower(c.Cmd.Path)
225	switch {
226	case strings.Contains(path, "typescript-language-server"):
227		// err := initializeTypescriptLanguageServer(ctx, c, workspaceDir)
228		// if err != nil {
229		// 	return nil, err
230		// }
231	}
232
233	return &result, nil
234}
235
236func (c *Client) Close() error {
237	// Try to close all open files first
238	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
239	defer cancel()
240
241	// Attempt to close files but continue shutdown regardless
242	c.CloseAllFiles(ctx)
243
244	// Close stdin to signal the server
245	if err := c.stdin.Close(); err != nil {
246		return fmt.Errorf("failed to close stdin: %w", err)
247	}
248
249	// Use a channel to handle the Wait with timeout
250	done := make(chan error, 1)
251	go func() {
252		done <- c.Cmd.Wait()
253	}()
254
255	// Wait for process to exit with timeout
256	select {
257	case err := <-done:
258		return err
259	case <-time.After(2 * time.Second):
260		// If we timeout, try to kill the process
261		if err := c.Cmd.Process.Kill(); err != nil {
262			return fmt.Errorf("failed to kill process: %w", err)
263		}
264		return fmt.Errorf("process killed after timeout")
265	}
266}
267
// ServerState enumerates coarse lifecycle states of a language server.
// NOTE(review): not referenced elsewhere in this file; confirm where it is used.
type ServerState int

// The values are spelled out explicitly so the numbering is visible at a
// glance. StateStarting is the zero value of ServerState.
const (
	StateStarting ServerState = 0
	StateReady    ServerState = 1
	StateError    ServerState = 2
)
275
276func (c *Client) WaitForServerReady(ctx context.Context) error {
277	// TODO: wait for specific messages or poll workspace/symbol
278	time.Sleep(time.Second * 1)
279	return nil
280}
281
282type OpenFileInfo struct {
283	Version int32
284	URI     protocol.DocumentUri
285}
286
287func (c *Client) OpenFile(ctx context.Context, filepath string) error {
288	uri := fmt.Sprintf("file://%s", filepath)
289
290	c.openFilesMu.Lock()
291	if _, exists := c.openFiles[uri]; exists {
292		c.openFilesMu.Unlock()
293		return nil // Already open
294	}
295	c.openFilesMu.Unlock()
296
297	// Skip files that do not exist or cannot be read
298	content, err := os.ReadFile(filepath)
299	if err != nil {
300		return fmt.Errorf("error reading file: %w", err)
301	}
302
303	params := protocol.DidOpenTextDocumentParams{
304		TextDocument: protocol.TextDocumentItem{
305			URI:        protocol.DocumentUri(uri),
306			LanguageID: DetectLanguageID(uri),
307			Version:    1,
308			Text:       string(content),
309		},
310	}
311
312	if err := c.Notify(ctx, "textDocument/didOpen", params); err != nil {
313		return err
314	}
315
316	c.openFilesMu.Lock()
317	c.openFiles[uri] = &OpenFileInfo{
318		Version: 1,
319		URI:     protocol.DocumentUri(uri),
320	}
321	c.openFilesMu.Unlock()
322
323	return nil
324}
325
326func (c *Client) NotifyChange(ctx context.Context, filepath string) error {
327	uri := fmt.Sprintf("file://%s", filepath)
328
329	content, err := os.ReadFile(filepath)
330	if err != nil {
331		return fmt.Errorf("error reading file: %w", err)
332	}
333
334	c.openFilesMu.Lock()
335	fileInfo, isOpen := c.openFiles[uri]
336	if !isOpen {
337		c.openFilesMu.Unlock()
338		return fmt.Errorf("cannot notify change for unopened file: %s", filepath)
339	}
340
341	// Increment version
342	fileInfo.Version++
343	version := fileInfo.Version
344	c.openFilesMu.Unlock()
345
346	params := protocol.DidChangeTextDocumentParams{
347		TextDocument: protocol.VersionedTextDocumentIdentifier{
348			TextDocumentIdentifier: protocol.TextDocumentIdentifier{
349				URI: protocol.DocumentUri(uri),
350			},
351			Version: version,
352		},
353		ContentChanges: []protocol.TextDocumentContentChangeEvent{
354			{
355				Value: protocol.TextDocumentContentChangeWholeDocument{
356					Text: string(content),
357				},
358			},
359		},
360	}
361
362	return c.Notify(ctx, "textDocument/didChange", params)
363}
364
365func (c *Client) CloseFile(ctx context.Context, filepath string) error {
366	cnf := config.Get()
367	uri := fmt.Sprintf("file://%s", filepath)
368
369	c.openFilesMu.Lock()
370	if _, exists := c.openFiles[uri]; !exists {
371		c.openFilesMu.Unlock()
372		return nil // Already closed
373	}
374	c.openFilesMu.Unlock()
375
376	params := protocol.DidCloseTextDocumentParams{
377		TextDocument: protocol.TextDocumentIdentifier{
378			URI: protocol.DocumentUri(uri),
379		},
380	}
381
382	if cnf.DebugLSP {
383		logging.Debug("Closing file", "file", filepath)
384	}
385	if err := c.Notify(ctx, "textDocument/didClose", params); err != nil {
386		return err
387	}
388
389	c.openFilesMu.Lock()
390	delete(c.openFiles, uri)
391	c.openFilesMu.Unlock()
392
393	return nil
394}
395
396func (c *Client) IsFileOpen(filepath string) bool {
397	uri := fmt.Sprintf("file://%s", filepath)
398	c.openFilesMu.RLock()
399	defer c.openFilesMu.RUnlock()
400	_, exists := c.openFiles[uri]
401	return exists
402}
403
404// CloseAllFiles closes all currently open files
405func (c *Client) CloseAllFiles(ctx context.Context) {
406	cnf := config.Get()
407	c.openFilesMu.Lock()
408	filesToClose := make([]string, 0, len(c.openFiles))
409
410	// First collect all URIs that need to be closed
411	for uri := range c.openFiles {
412		// Convert URI back to file path by trimming "file://" prefix
413		filePath := strings.TrimPrefix(uri, "file://")
414		filesToClose = append(filesToClose, filePath)
415	}
416	c.openFilesMu.Unlock()
417
418	// Then close them all
419	for _, filePath := range filesToClose {
420		err := c.CloseFile(ctx, filePath)
421		if err != nil && cnf.DebugLSP {
422			logging.Warn("Error closing file", "file", filePath, "error", err)
423		}
424	}
425
426	if cnf.DebugLSP {
427		logging.Debug("Closed all files", "files", filesToClose)
428	}
429}
430
431func (c *Client) GetFileDiagnostics(uri protocol.DocumentUri) []protocol.Diagnostic {
432	c.diagnosticsMu.RLock()
433	defer c.diagnosticsMu.RUnlock()
434
435	return c.diagnostics[uri]
436}
437
438func (c *Client) GetDiagnostics() map[protocol.DocumentUri][]protocol.Diagnostic {
439	return c.diagnostics
440}