client.go

  1package lsp
  2
  3import (
  4	"bufio"
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8	"io"
  9	"os"
 10	"os/exec"
 11	"strings"
 12	"sync"
 13	"sync/atomic"
 14	"time"
 15
 16	"github.com/kujtimiihoxha/termai/internal/config"
 17	"github.com/kujtimiihoxha/termai/internal/logging"
 18	"github.com/kujtimiihoxha/termai/internal/lsp/protocol"
 19)
 20
// logger is the package-level logger obtained once from logging.Get. In this
// file it is used by CloseFile and CloseAllFiles for debug/warn output.
var logger = logging.Get()
 22
// Client owns a language server child process together with the bookkeeping
// used to exchange LSP messages with it over the process's standard streams.
//
// Every map field is followed by the mutex that the methods in this file take
// before touching it. Instances are created by NewClient, which allocates all
// of the maps; a zero-value Client has nil maps.
type Client struct {
	// Server process and its standard stream pipes (set up by NewClient).
	Cmd    *exec.Cmd
	stdin  io.WriteCloser
	stdout *bufio.Reader
	stderr io.ReadCloser

	// Request ID counter
	nextID atomic.Int32

	// Response handlers: channels keyed by an int32 ID (presumably the
	// request IDs drawn from nextID — confirm against Call/handleMessages).
	handlers   map[int32]chan *Message
	handlersMu sync.RWMutex

	// Server request handlers, keyed by method name
	// (see RegisterServerRequestHandler).
	serverRequestHandlers map[string]ServerRequestHandler
	serverHandlersMu      sync.RWMutex

	// Notification handlers, keyed by method name
	// (see RegisterNotificationHandler).
	notificationHandlers map[string]NotificationHandler
	notificationMu       sync.RWMutex

	// Diagnostic cache, keyed by document URI.
	diagnostics   map[protocol.DocumentUri][]protocol.Diagnostic
	diagnosticsMu sync.RWMutex

	// Files currently opened on the LSP server, keyed by their
	// "file://"-prefixed URI string.
	openFiles   map[string]*OpenFileInfo
	openFilesMu sync.RWMutex
}
 52
 53func NewClient(command string, args ...string) (*Client, error) {
 54	cmd := exec.Command(command, args...)
 55	// Copy env
 56	cmd.Env = os.Environ()
 57
 58	stdin, err := cmd.StdinPipe()
 59	if err != nil {
 60		return nil, fmt.Errorf("failed to create stdin pipe: %w", err)
 61	}
 62
 63	stdout, err := cmd.StdoutPipe()
 64	if err != nil {
 65		return nil, fmt.Errorf("failed to create stdout pipe: %w", err)
 66	}
 67
 68	stderr, err := cmd.StderrPipe()
 69	if err != nil {
 70		return nil, fmt.Errorf("failed to create stderr pipe: %w", err)
 71	}
 72
 73	client := &Client{
 74		Cmd:                   cmd,
 75		stdin:                 stdin,
 76		stdout:                bufio.NewReader(stdout),
 77		stderr:                stderr,
 78		handlers:              make(map[int32]chan *Message),
 79		notificationHandlers:  make(map[string]NotificationHandler),
 80		serverRequestHandlers: make(map[string]ServerRequestHandler),
 81		diagnostics:           make(map[protocol.DocumentUri][]protocol.Diagnostic),
 82		openFiles:             make(map[string]*OpenFileInfo),
 83	}
 84
 85	// Start the LSP server process
 86	if err := cmd.Start(); err != nil {
 87		return nil, fmt.Errorf("failed to start LSP server: %w", err)
 88	}
 89
 90	// Handle stderr in a separate goroutine
 91	go func() {
 92		scanner := bufio.NewScanner(stderr)
 93		for scanner.Scan() {
 94			fmt.Fprintf(os.Stderr, "LSP Server: %s\n", scanner.Text())
 95		}
 96		if err := scanner.Err(); err != nil {
 97			fmt.Fprintf(os.Stderr, "Error reading stderr: %v\n", err)
 98		}
 99	}()
100
101	// Start message handling loop
102	go client.handleMessages()
103
104	return client, nil
105}
106
107func (c *Client) RegisterNotificationHandler(method string, handler NotificationHandler) {
108	c.notificationMu.Lock()
109	defer c.notificationMu.Unlock()
110	c.notificationHandlers[method] = handler
111}
112
113func (c *Client) RegisterServerRequestHandler(method string, handler ServerRequestHandler) {
114	c.serverHandlersMu.Lock()
115	defer c.serverHandlersMu.Unlock()
116	c.serverRequestHandlers[method] = handler
117}
118
119func (c *Client) InitializeLSPClient(ctx context.Context, workspaceDir string) (*protocol.InitializeResult, error) {
120	initParams := &protocol.InitializeParams{
121		WorkspaceFoldersInitializeParams: protocol.WorkspaceFoldersInitializeParams{
122			WorkspaceFolders: []protocol.WorkspaceFolder{
123				{
124					URI:  protocol.URI("file://" + workspaceDir),
125					Name: workspaceDir,
126				},
127			},
128		},
129
130		XInitializeParams: protocol.XInitializeParams{
131			ProcessID: int32(os.Getpid()),
132			ClientInfo: &protocol.ClientInfo{
133				Name:    "mcp-language-server",
134				Version: "0.1.0",
135			},
136			RootPath: workspaceDir,
137			RootURI:  protocol.DocumentUri("file://" + workspaceDir),
138			Capabilities: protocol.ClientCapabilities{
139				Workspace: protocol.WorkspaceClientCapabilities{
140					Configuration: true,
141					DidChangeConfiguration: protocol.DidChangeConfigurationClientCapabilities{
142						DynamicRegistration: true,
143					},
144					DidChangeWatchedFiles: protocol.DidChangeWatchedFilesClientCapabilities{
145						DynamicRegistration:    true,
146						RelativePatternSupport: true,
147					},
148				},
149				TextDocument: protocol.TextDocumentClientCapabilities{
150					Synchronization: &protocol.TextDocumentSyncClientCapabilities{
151						DynamicRegistration: true,
152						DidSave:             true,
153					},
154					Completion: protocol.CompletionClientCapabilities{
155						CompletionItem: protocol.ClientCompletionItemOptions{},
156					},
157					CodeLens: &protocol.CodeLensClientCapabilities{
158						DynamicRegistration: true,
159					},
160					DocumentSymbol: protocol.DocumentSymbolClientCapabilities{},
161					CodeAction: protocol.CodeActionClientCapabilities{
162						CodeActionLiteralSupport: protocol.ClientCodeActionLiteralOptions{
163							CodeActionKind: protocol.ClientCodeActionKindOptions{
164								ValueSet: []protocol.CodeActionKind{},
165							},
166						},
167					},
168					PublishDiagnostics: protocol.PublishDiagnosticsClientCapabilities{
169						VersionSupport: true,
170					},
171					SemanticTokens: protocol.SemanticTokensClientCapabilities{
172						Requests: protocol.ClientSemanticTokensRequestOptions{
173							Range: &protocol.Or_ClientSemanticTokensRequestOptions_range{},
174							Full:  &protocol.Or_ClientSemanticTokensRequestOptions_full{},
175						},
176						TokenTypes:     []string{},
177						TokenModifiers: []string{},
178						Formats:        []protocol.TokenFormat{},
179					},
180				},
181				Window: protocol.WindowClientCapabilities{},
182			},
183			InitializationOptions: map[string]any{
184				"codelenses": map[string]bool{
185					"generate":           true,
186					"regenerate_cgo":     true,
187					"test":               true,
188					"tidy":               true,
189					"upgrade_dependency": true,
190					"vendor":             true,
191					"vulncheck":          false,
192				},
193			},
194		},
195	}
196
197	var result protocol.InitializeResult
198	if err := c.Call(ctx, "initialize", initParams, &result); err != nil {
199		return nil, fmt.Errorf("initialize failed: %w", err)
200	}
201
202	if err := c.Notify(ctx, "initialized", struct{}{}); err != nil {
203		return nil, fmt.Errorf("initialized notification failed: %w", err)
204	}
205
206	// Register handlers
207	c.RegisterServerRequestHandler("workspace/applyEdit", HandleApplyEdit)
208	c.RegisterServerRequestHandler("workspace/configuration", HandleWorkspaceConfiguration)
209	c.RegisterServerRequestHandler("client/registerCapability", HandleRegisterCapability)
210	c.RegisterNotificationHandler("window/showMessage", HandleServerMessage)
211	c.RegisterNotificationHandler("textDocument/publishDiagnostics",
212		func(params json.RawMessage) { HandleDiagnostics(c, params) })
213
214	// Notify the LSP server
215	err := c.Initialized(ctx, protocol.InitializedParams{})
216	if err != nil {
217		return nil, fmt.Errorf("initialization failed: %w", err)
218	}
219
220	// LSP sepecific Initialization
221	path := strings.ToLower(c.Cmd.Path)
222	switch {
223	case strings.Contains(path, "typescript-language-server"):
224		// err := initializeTypescriptLanguageServer(ctx, c, workspaceDir)
225		// if err != nil {
226		// 	return nil, err
227		// }
228	}
229
230	return &result, nil
231}
232
233func (c *Client) Close() error {
234	// Try to close all open files first
235	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
236	defer cancel()
237
238	// Attempt to close files but continue shutdown regardless
239	c.CloseAllFiles(ctx)
240
241	// Close stdin to signal the server
242	if err := c.stdin.Close(); err != nil {
243		return fmt.Errorf("failed to close stdin: %w", err)
244	}
245
246	// Use a channel to handle the Wait with timeout
247	done := make(chan error, 1)
248	go func() {
249		done <- c.Cmd.Wait()
250	}()
251
252	// Wait for process to exit with timeout
253	select {
254	case err := <-done:
255		return err
256	case <-time.After(2 * time.Second):
257		// If we timeout, try to kill the process
258		if err := c.Cmd.Process.Kill(); err != nil {
259			return fmt.Errorf("failed to kill process: %w", err)
260		}
261		return fmt.Errorf("process killed after timeout")
262	}
263}
264
// ServerState enumerates states of a language server. Nothing in the visible
// part of this file reads or assigns these values.
type ServerState int

const (
	// StateStarting is the zero value of ServerState.
	StateStarting ServerState = iota
	// StateReady presumably marks a server that has finished starting —
	// TODO confirm against the code that uses this type.
	StateReady
	// StateError presumably marks a server that failed — TODO confirm.
	StateError
)
272
273func (c *Client) WaitForServerReady(ctx context.Context) error {
274	// TODO: wait for specific messages or poll workspace/symbol
275	time.Sleep(time.Second * 1)
276	return nil
277}
278
// OpenFileInfo is the per-document record the client keeps in its openFiles
// map for a file it has opened on the server.
type OpenFileInfo struct {
	// Version is the document version: OpenFile stores 1 and NotifyChange
	// increments it before each textDocument/didChange notification.
	Version int32
	// URI is the "file://"-prefixed document URI; OpenFile stores the same
	// string it uses as the openFiles key.
	URI     protocol.DocumentUri
}
283
284func (c *Client) OpenFile(ctx context.Context, filepath string) error {
285	uri := fmt.Sprintf("file://%s", filepath)
286
287	c.openFilesMu.Lock()
288	if _, exists := c.openFiles[uri]; exists {
289		c.openFilesMu.Unlock()
290		return nil // Already open
291	}
292	c.openFilesMu.Unlock()
293
294	// Skip files that do not exist or cannot be read
295	content, err := os.ReadFile(filepath)
296	if err != nil {
297		return fmt.Errorf("error reading file: %w", err)
298	}
299
300	params := protocol.DidOpenTextDocumentParams{
301		TextDocument: protocol.TextDocumentItem{
302			URI:        protocol.DocumentUri(uri),
303			LanguageID: DetectLanguageID(uri),
304			Version:    1,
305			Text:       string(content),
306		},
307	}
308
309	if err := c.Notify(ctx, "textDocument/didOpen", params); err != nil {
310		return err
311	}
312
313	c.openFilesMu.Lock()
314	c.openFiles[uri] = &OpenFileInfo{
315		Version: 1,
316		URI:     protocol.DocumentUri(uri),
317	}
318	c.openFilesMu.Unlock()
319
320	return nil
321}
322
323func (c *Client) NotifyChange(ctx context.Context, filepath string) error {
324	uri := fmt.Sprintf("file://%s", filepath)
325
326	content, err := os.ReadFile(filepath)
327	if err != nil {
328		return fmt.Errorf("error reading file: %w", err)
329	}
330
331	c.openFilesMu.Lock()
332	fileInfo, isOpen := c.openFiles[uri]
333	if !isOpen {
334		c.openFilesMu.Unlock()
335		return fmt.Errorf("cannot notify change for unopened file: %s", filepath)
336	}
337
338	// Increment version
339	fileInfo.Version++
340	version := fileInfo.Version
341	c.openFilesMu.Unlock()
342
343	params := protocol.DidChangeTextDocumentParams{
344		TextDocument: protocol.VersionedTextDocumentIdentifier{
345			TextDocumentIdentifier: protocol.TextDocumentIdentifier{
346				URI: protocol.DocumentUri(uri),
347			},
348			Version: version,
349		},
350		ContentChanges: []protocol.TextDocumentContentChangeEvent{
351			{
352				Value: protocol.TextDocumentContentChangeWholeDocument{
353					Text: string(content),
354				},
355			},
356		},
357	}
358
359	return c.Notify(ctx, "textDocument/didChange", params)
360}
361
362func (c *Client) CloseFile(ctx context.Context, filepath string) error {
363	cnf := config.Get()
364	uri := fmt.Sprintf("file://%s", filepath)
365
366	c.openFilesMu.Lock()
367	if _, exists := c.openFiles[uri]; !exists {
368		c.openFilesMu.Unlock()
369		return nil // Already closed
370	}
371	c.openFilesMu.Unlock()
372
373	params := protocol.DidCloseTextDocumentParams{
374		TextDocument: protocol.TextDocumentIdentifier{
375			URI: protocol.DocumentUri(uri),
376		},
377	}
378
379	if cnf.Debug {
380		logger.Debug("Closing file", "file", filepath)
381	}
382	if err := c.Notify(ctx, "textDocument/didClose", params); err != nil {
383		return err
384	}
385
386	c.openFilesMu.Lock()
387	delete(c.openFiles, uri)
388	c.openFilesMu.Unlock()
389
390	return nil
391}
392
393func (c *Client) IsFileOpen(filepath string) bool {
394	uri := fmt.Sprintf("file://%s", filepath)
395	c.openFilesMu.RLock()
396	defer c.openFilesMu.RUnlock()
397	_, exists := c.openFiles[uri]
398	return exists
399}
400
401// CloseAllFiles closes all currently open files
402func (c *Client) CloseAllFiles(ctx context.Context) {
403	cnf := config.Get()
404	c.openFilesMu.Lock()
405	filesToClose := make([]string, 0, len(c.openFiles))
406
407	// First collect all URIs that need to be closed
408	for uri := range c.openFiles {
409		// Convert URI back to file path by trimming "file://" prefix
410		filePath := strings.TrimPrefix(uri, "file://")
411		filesToClose = append(filesToClose, filePath)
412	}
413	c.openFilesMu.Unlock()
414
415	// Then close them all
416	for _, filePath := range filesToClose {
417		err := c.CloseFile(ctx, filePath)
418		if err != nil && cnf.Debug {
419			logger.Warn("Error closing file", "file", filePath, "error", err)
420		}
421	}
422
423	if cnf.Debug {
424		logger.Debug("Closed all files", "files", filesToClose)
425	}
426}
427
428func (c *Client) GetFileDiagnostics(uri protocol.DocumentUri) []protocol.Diagnostic {
429	c.diagnosticsMu.RLock()
430	defer c.diagnosticsMu.RUnlock()
431
432	return c.diagnostics[uri]
433}
434
435func (c *Client) GetDiagnostics() map[protocol.DocumentUri][]protocol.Diagnostic {
436	return c.diagnostics
437}