client.go

  1package lsp
  2
  3import (
  4	"bufio"
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8	"io"
  9	"log"
 10	"os"
 11	"os/exec"
 12	"strings"
 13	"sync"
 14	"sync/atomic"
 15	"time"
 16
 17	"github.com/kujtimiihoxha/termai/internal/lsp/protocol"
 18)
 19
// Client manages a language server subprocess. It owns the process, the pipes
// connected to the process's stdin/stdout/stderr, and the client-side
// bookkeeping for requests, handlers, cached diagnostics and open documents.
// Instances are created with NewClient; the maps below are allocated there.
type Client struct {
	// Cmd is the language server process started by NewClient.
	Cmd *exec.Cmd
	// stdin is the write end of a pipe to the server's standard input.
	stdin io.WriteCloser
	// stdout is a buffered reader over the server's standard output.
	stdout *bufio.Reader
	// stderr is the server's standard error; NewClient echoes it to os.Stderr.
	stderr io.ReadCloser

	// Request ID counter
	// NOTE(review): presumably used to allocate IDs for outgoing requests
	// (its type matches the handlers key below); verify in the RPC code.
	nextID atomic.Int32

	// Response handlers: channels awaiting a reply, keyed by request ID
	// (assumed; the send/receive side is not in this file section).
	handlers   map[int32]chan *Message
	handlersMu sync.RWMutex

	// Server request handlers, keyed by LSP method name
	// (populated via RegisterServerRequestHandler).
	serverRequestHandlers map[string]ServerRequestHandler
	serverHandlersMu      sync.RWMutex

	// Notification handlers, keyed by LSP method name
	// (populated via RegisterNotificationHandler).
	notificationHandlers map[string]NotificationHandler
	notificationMu       sync.RWMutex

	// Diagnostic cache, keyed by document URI
	diagnostics   map[protocol.DocumentUri][]protocol.Diagnostic
	diagnosticsMu sync.RWMutex

	// Files currently open on the server, keyed by their "file://" URI string
	// (see OpenFile / CloseFile).
	openFiles   map[string]*OpenFileInfo
	openFilesMu sync.RWMutex
}
 49
 50func NewClient(command string, args ...string) (*Client, error) {
 51	cmd := exec.Command(command, args...)
 52	// Copy env
 53	cmd.Env = os.Environ()
 54
 55	stdin, err := cmd.StdinPipe()
 56	if err != nil {
 57		return nil, fmt.Errorf("failed to create stdin pipe: %w", err)
 58	}
 59
 60	stdout, err := cmd.StdoutPipe()
 61	if err != nil {
 62		return nil, fmt.Errorf("failed to create stdout pipe: %w", err)
 63	}
 64
 65	stderr, err := cmd.StderrPipe()
 66	if err != nil {
 67		return nil, fmt.Errorf("failed to create stderr pipe: %w", err)
 68	}
 69
 70	client := &Client{
 71		Cmd:                   cmd,
 72		stdin:                 stdin,
 73		stdout:                bufio.NewReader(stdout),
 74		stderr:                stderr,
 75		handlers:              make(map[int32]chan *Message),
 76		notificationHandlers:  make(map[string]NotificationHandler),
 77		serverRequestHandlers: make(map[string]ServerRequestHandler),
 78		diagnostics:           make(map[protocol.DocumentUri][]protocol.Diagnostic),
 79		openFiles:             make(map[string]*OpenFileInfo),
 80	}
 81
 82	// Start the LSP server process
 83	if err := cmd.Start(); err != nil {
 84		return nil, fmt.Errorf("failed to start LSP server: %w", err)
 85	}
 86
 87	// Handle stderr in a separate goroutine
 88	go func() {
 89		scanner := bufio.NewScanner(stderr)
 90		for scanner.Scan() {
 91			fmt.Fprintf(os.Stderr, "LSP Server: %s\n", scanner.Text())
 92		}
 93		if err := scanner.Err(); err != nil {
 94			fmt.Fprintf(os.Stderr, "Error reading stderr: %v\n", err)
 95		}
 96	}()
 97
 98	// Start message handling loop
 99	go client.handleMessages()
100
101	return client, nil
102}
103
104func (c *Client) RegisterNotificationHandler(method string, handler NotificationHandler) {
105	c.notificationMu.Lock()
106	defer c.notificationMu.Unlock()
107	c.notificationHandlers[method] = handler
108}
109
110func (c *Client) RegisterServerRequestHandler(method string, handler ServerRequestHandler) {
111	c.serverHandlersMu.Lock()
112	defer c.serverHandlersMu.Unlock()
113	c.serverRequestHandlers[method] = handler
114}
115
116func (c *Client) InitializeLSPClient(ctx context.Context, workspaceDir string) (*protocol.InitializeResult, error) {
117	initParams := &protocol.InitializeParams{
118		WorkspaceFoldersInitializeParams: protocol.WorkspaceFoldersInitializeParams{
119			WorkspaceFolders: []protocol.WorkspaceFolder{
120				{
121					URI:  protocol.URI("file://" + workspaceDir),
122					Name: workspaceDir,
123				},
124			},
125		},
126
127		XInitializeParams: protocol.XInitializeParams{
128			ProcessID: int32(os.Getpid()),
129			ClientInfo: &protocol.ClientInfo{
130				Name:    "mcp-language-server",
131				Version: "0.1.0",
132			},
133			RootPath: workspaceDir,
134			RootURI:  protocol.DocumentUri("file://" + workspaceDir),
135			Capabilities: protocol.ClientCapabilities{
136				Workspace: protocol.WorkspaceClientCapabilities{
137					Configuration: true,
138					DidChangeConfiguration: protocol.DidChangeConfigurationClientCapabilities{
139						DynamicRegistration: true,
140					},
141					DidChangeWatchedFiles: protocol.DidChangeWatchedFilesClientCapabilities{
142						DynamicRegistration:    true,
143						RelativePatternSupport: true,
144					},
145				},
146				TextDocument: protocol.TextDocumentClientCapabilities{
147					Synchronization: &protocol.TextDocumentSyncClientCapabilities{
148						DynamicRegistration: true,
149						DidSave:             true,
150					},
151					Completion: protocol.CompletionClientCapabilities{
152						CompletionItem: protocol.ClientCompletionItemOptions{},
153					},
154					CodeLens: &protocol.CodeLensClientCapabilities{
155						DynamicRegistration: true,
156					},
157					DocumentSymbol: protocol.DocumentSymbolClientCapabilities{},
158					CodeAction: protocol.CodeActionClientCapabilities{
159						CodeActionLiteralSupport: protocol.ClientCodeActionLiteralOptions{
160							CodeActionKind: protocol.ClientCodeActionKindOptions{
161								ValueSet: []protocol.CodeActionKind{},
162							},
163						},
164					},
165					PublishDiagnostics: protocol.PublishDiagnosticsClientCapabilities{
166						VersionSupport: true,
167					},
168					SemanticTokens: protocol.SemanticTokensClientCapabilities{
169						Requests: protocol.ClientSemanticTokensRequestOptions{
170							Range: &protocol.Or_ClientSemanticTokensRequestOptions_range{},
171							Full:  &protocol.Or_ClientSemanticTokensRequestOptions_full{},
172						},
173						TokenTypes:     []string{},
174						TokenModifiers: []string{},
175						Formats:        []protocol.TokenFormat{},
176					},
177				},
178				Window: protocol.WindowClientCapabilities{},
179			},
180			InitializationOptions: map[string]any{
181				"codelenses": map[string]bool{
182					"generate":           true,
183					"regenerate_cgo":     true,
184					"test":               true,
185					"tidy":               true,
186					"upgrade_dependency": true,
187					"vendor":             true,
188					"vulncheck":          false,
189				},
190			},
191		},
192	}
193
194	var result protocol.InitializeResult
195	if err := c.Call(ctx, "initialize", initParams, &result); err != nil {
196		return nil, fmt.Errorf("initialize failed: %w", err)
197	}
198
199	if err := c.Notify(ctx, "initialized", struct{}{}); err != nil {
200		return nil, fmt.Errorf("initialized notification failed: %w", err)
201	}
202
203	// Register handlers
204	c.RegisterServerRequestHandler("workspace/applyEdit", HandleApplyEdit)
205	c.RegisterServerRequestHandler("workspace/configuration", HandleWorkspaceConfiguration)
206	c.RegisterServerRequestHandler("client/registerCapability", HandleRegisterCapability)
207	c.RegisterNotificationHandler("window/showMessage", HandleServerMessage)
208	c.RegisterNotificationHandler("textDocument/publishDiagnostics",
209		func(params json.RawMessage) { HandleDiagnostics(c, params) })
210
211	// Notify the LSP server
212	err := c.Initialized(ctx, protocol.InitializedParams{})
213	if err != nil {
214		return nil, fmt.Errorf("initialization failed: %w", err)
215	}
216
217	// LSP sepecific Initialization
218	path := strings.ToLower(c.Cmd.Path)
219	switch {
220	case strings.Contains(path, "typescript-language-server"):
221		// err := initializeTypescriptLanguageServer(ctx, c, workspaceDir)
222		// if err != nil {
223		// 	return nil, err
224		// }
225	}
226
227	return &result, nil
228}
229
230func (c *Client) Close() error {
231	// Try to close all open files first
232	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
233	defer cancel()
234
235	// Attempt to close files but continue shutdown regardless
236	c.CloseAllFiles(ctx)
237
238	// Close stdin to signal the server
239	if err := c.stdin.Close(); err != nil {
240		return fmt.Errorf("failed to close stdin: %w", err)
241	}
242
243	// Use a channel to handle the Wait with timeout
244	done := make(chan error, 1)
245	go func() {
246		done <- c.Cmd.Wait()
247	}()
248
249	// Wait for process to exit with timeout
250	select {
251	case err := <-done:
252		return err
253	case <-time.After(2 * time.Second):
254		// If we timeout, try to kill the process
255		if err := c.Cmd.Process.Kill(); err != nil {
256			return fmt.Errorf("failed to kill process: %w", err)
257		}
258		return fmt.Errorf("process killed after timeout")
259	}
260}
261
// ServerState enumerates the states a language server may be in: starting,
// ready, or errored. NOTE(review): nothing in the code shown here reads or
// sets it yet.
type ServerState int

// Known server states. The explicit values equal the original iota
// sequence: 0, 1, 2.
const (
	StateStarting ServerState = 0
	StateReady    ServerState = 1
	StateError    ServerState = 2
)
269
270func (c *Client) WaitForServerReady(ctx context.Context) error {
271	// TODO: wait for specific messages or poll workspace/symbol
272	time.Sleep(time.Second * 1)
273	return nil
274}
275
// OpenFileInfo tracks a document the client has opened on the language server.
type OpenFileInfo struct {
	// Version is the document version last sent to the server. OpenFile sets
	// it to 1, and NotifyChange increments it before each didChange.
	Version int32
	// URI is the document's "file://" URI.
	URI protocol.DocumentUri
}
280
281func (c *Client) OpenFile(ctx context.Context, filepath string) error {
282	uri := fmt.Sprintf("file://%s", filepath)
283
284	c.openFilesMu.Lock()
285	if _, exists := c.openFiles[uri]; exists {
286		c.openFilesMu.Unlock()
287		return nil // Already open
288	}
289	c.openFilesMu.Unlock()
290
291	// Skip files that do not exist or cannot be read
292	content, err := os.ReadFile(filepath)
293	if err != nil {
294		return fmt.Errorf("error reading file: %w", err)
295	}
296
297	params := protocol.DidOpenTextDocumentParams{
298		TextDocument: protocol.TextDocumentItem{
299			URI:        protocol.DocumentUri(uri),
300			LanguageID: DetectLanguageID(uri),
301			Version:    1,
302			Text:       string(content),
303		},
304	}
305
306	if err := c.Notify(ctx, "textDocument/didOpen", params); err != nil {
307		return err
308	}
309
310	c.openFilesMu.Lock()
311	c.openFiles[uri] = &OpenFileInfo{
312		Version: 1,
313		URI:     protocol.DocumentUri(uri),
314	}
315	c.openFilesMu.Unlock()
316
317	return nil
318}
319
320func (c *Client) NotifyChange(ctx context.Context, filepath string) error {
321	uri := fmt.Sprintf("file://%s", filepath)
322
323	content, err := os.ReadFile(filepath)
324	if err != nil {
325		return fmt.Errorf("error reading file: %w", err)
326	}
327
328	c.openFilesMu.Lock()
329	fileInfo, isOpen := c.openFiles[uri]
330	if !isOpen {
331		c.openFilesMu.Unlock()
332		return fmt.Errorf("cannot notify change for unopened file: %s", filepath)
333	}
334
335	// Increment version
336	fileInfo.Version++
337	version := fileInfo.Version
338	c.openFilesMu.Unlock()
339
340	params := protocol.DidChangeTextDocumentParams{
341		TextDocument: protocol.VersionedTextDocumentIdentifier{
342			TextDocumentIdentifier: protocol.TextDocumentIdentifier{
343				URI: protocol.DocumentUri(uri),
344			},
345			Version: version,
346		},
347		ContentChanges: []protocol.TextDocumentContentChangeEvent{
348			{
349				Value: protocol.TextDocumentContentChangeWholeDocument{
350					Text: string(content),
351				},
352			},
353		},
354	}
355
356	return c.Notify(ctx, "textDocument/didChange", params)
357}
358
359func (c *Client) CloseFile(ctx context.Context, filepath string) error {
360	uri := fmt.Sprintf("file://%s", filepath)
361
362	c.openFilesMu.Lock()
363	if _, exists := c.openFiles[uri]; !exists {
364		c.openFilesMu.Unlock()
365		return nil // Already closed
366	}
367	c.openFilesMu.Unlock()
368
369	params := protocol.DidCloseTextDocumentParams{
370		TextDocument: protocol.TextDocumentIdentifier{
371			URI: protocol.DocumentUri(uri),
372		},
373	}
374	log.Println("Closing", params.TextDocument.URI.Dir())
375	if err := c.Notify(ctx, "textDocument/didClose", params); err != nil {
376		return err
377	}
378
379	c.openFilesMu.Lock()
380	delete(c.openFiles, uri)
381	c.openFilesMu.Unlock()
382
383	return nil
384}
385
386func (c *Client) IsFileOpen(filepath string) bool {
387	uri := fmt.Sprintf("file://%s", filepath)
388	c.openFilesMu.RLock()
389	defer c.openFilesMu.RUnlock()
390	_, exists := c.openFiles[uri]
391	return exists
392}
393
394// CloseAllFiles closes all currently open files
395func (c *Client) CloseAllFiles(ctx context.Context) {
396	c.openFilesMu.Lock()
397	filesToClose := make([]string, 0, len(c.openFiles))
398
399	// First collect all URIs that need to be closed
400	for uri := range c.openFiles {
401		// Convert URI back to file path by trimming "file://" prefix
402		filePath := strings.TrimPrefix(uri, "file://")
403		filesToClose = append(filesToClose, filePath)
404	}
405	c.openFilesMu.Unlock()
406
407	// Then close them all
408	for _, filePath := range filesToClose {
409		err := c.CloseFile(ctx, filePath)
410		if err != nil && debug {
411			log.Printf("Error closing file %s: %v", filePath, err)
412		}
413	}
414
415	if debug {
416		log.Printf("Closed %d files", len(filesToClose))
417	}
418}
419
420func (c *Client) GetFileDiagnostics(uri protocol.DocumentUri) []protocol.Diagnostic {
421	c.diagnosticsMu.RLock()
422	defer c.diagnosticsMu.RUnlock()
423
424	return c.diagnostics[uri]
425}
426
427func (c *Client) GetDiagnostics() map[protocol.DocumentUri][]protocol.Diagnostic {
428	return c.diagnostics
429}