definitions.go

  1package tools
  2
import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"

	"github.com/charmbracelet/crush/internal/lsp"
	"github.com/charmbracelet/crush/internal/lsp/protocol"
)
 14
 15type DefinitionsTool struct {
 16	lspClients map[string]*lsp.Client
 17}
 18
 19const (
 20	DefinitionsToolName = "definitions"
 21)
 22
 23func NewDefinitionsTool(lspClients map[string]*lsp.Client) BaseTool {
 24	return &DefinitionsTool{
 25		lspClients: lspClients,
 26	}
 27}
 28
 29func (t *DefinitionsTool) Name() string {
 30	return DefinitionsToolName
 31}
 32
 33func (t *DefinitionsTool) Info() ToolInfo {
 34	return ToolInfo{
 35		Name: DefinitionsToolName,
 36		Description: `Gets all symbol definitions from a file using the appropriate LSP server.
 37
 38WHEN TO USE THIS TOOL:
 39- Use when you need to understand the structure and symbols in a code file
 40- Helpful for exploring unfamiliar codebases and understanding what's defined in a file
 41- Good for finding functions, classes, variables, types, interfaces, and other symbols
 42- Useful before making changes to understand existing code structure
 43
 44HOW TO USE:
 45- Provide the path to a source code file
 46- The tool will automatically select the appropriate LSP server based on file extension
 47- Results show symbol names, types, locations, and hierarchical relationships
 48
 49FEATURES:
 50- Supports multiple programming languages (Go, TypeScript, JavaScript, Rust, Python, etc.)
 51- Shows hierarchical symbol relationships (classes with their methods, etc.)
 52- Provides precise line number locations for each symbol
 53- Includes symbol details and documentation when available
 54
 55LIMITATIONS:
 56- Requires an active LSP server for the file's language
 57- Only works with files that the LSP server can parse
 58- Results depend on LSP server capabilities and file syntax correctness
 59
 60TIPS:
 61- Use this tool to get an overview of a file's structure before editing
 62- Combine with other tools like View to see the actual code implementation
 63- Helpful for understanding large files with many symbols`,
 64		Parameters: map[string]any{
 65			"file_path": map[string]any{
 66				"type":        "string",
 67				"description": "The path to the file to get definitions from",
 68			},
 69		},
 70		Required: []string{"file_path"},
 71	}
 72}
 73
 74type DefinitionsParams struct {
 75	FilePath string `json:"file_path"`
 76}
 77
 78type DefinitionResult struct {
 79	Name     string             `json:"name"`
 80	Kind     string             `json:"kind"`
 81	Detail   string             `json:"detail,omitempty"`
 82	Range    string             `json:"range"`
 83	Children []DefinitionResult `json:"children,omitempty"`
 84}
 85
 86func (t *DefinitionsTool) Run(ctx context.Context, params ToolCall) (ToolResponse, error) {
 87	var definitionsParams DefinitionsParams
 88	if err := json.Unmarshal([]byte(params.Input), &definitionsParams); err != nil {
 89		return NewTextErrorResponse(fmt.Sprintf("Invalid parameters: %v", err)), nil
 90	}
 91
 92	if definitionsParams.FilePath == "" {
 93		return NewTextErrorResponse("file_path parameter is required"), nil
 94	}
 95
 96	// Check if file exists
 97	if _, err := os.Stat(definitionsParams.FilePath); err != nil {
 98		return NewTextErrorResponse(fmt.Sprintf("File does not exist: %s", definitionsParams.FilePath)), nil
 99	}
100
101	// Find the appropriate LSP client for this file
102	client, clientName, err := t.findLSPClient(definitionsParams.FilePath)
103	if err != nil {
104		return NewTextErrorResponse(fmt.Sprintf("No suitable LSP client found for file %s: %v\n\nAvailable LSP clients: %s",
105			definitionsParams.FilePath, err, t.getAvailableClients())), nil
106	}
107
108	// Check if the server is ready
109	if client.GetServerState() != lsp.StateReady {
110		return NewTextErrorResponse(fmt.Sprintf("LSP server %s is not ready (state: %v). Please wait for the server to initialize or check the server configuration.",
111			clientName, client.GetServerState())), nil
112	}
113
114	// Ensure the file is open in the LSP
115	if err := client.OpenFileOnDemand(ctx, definitionsParams.FilePath); err != nil {
116		return NewTextErrorResponse(fmt.Sprintf("Failed to open file in LSP: %v", err)), nil
117	}
118
119	// Get document symbols
120	documentURI := protocol.URIFromPath(definitionsParams.FilePath)
121	symbolParams := protocol.DocumentSymbolParams{
122		TextDocument: protocol.TextDocumentIdentifier{
123			URI: documentURI,
124		},
125	}
126
127	result, err := client.DocumentSymbol(ctx, symbolParams)
128	if err != nil {
129		return NewTextErrorResponse(fmt.Sprintf("Failed to get document symbols from %s LSP: %v\n\nThis might happen if:\n- The file has syntax errors\n- The LSP server doesn't support this file type\n- The file is not part of a recognized project structure",
130			clientName, err)), nil
131	}
132
133	// Parse the result and format it
134	definitions, err := t.parseSymbolResult(result)
135	if err != nil {
136		return NewTextErrorResponse(fmt.Sprintf("Failed to parse symbols: %v", err)), nil
137	}
138
139	if len(definitions) == 0 {
140		return NewTextResponse(fmt.Sprintf("No definitions found in file %s.\n\nThis might happen if:\n- The file is empty or contains only comments\n- The file has syntax errors preventing symbol extraction\n- The LSP server doesn't recognize symbols in this file type",
141			definitionsParams.FilePath)), nil
142	}
143
144	// Format the output
145	output := t.formatDefinitions(definitions, definitionsParams.FilePath, clientName)
146	return NewTextResponse(output), nil
147}
148
149// findLSPClient finds the most appropriate LSP client for the given file
150func (t *DefinitionsTool) findLSPClient(filePath string) (*lsp.Client, string, error) {
151	ext := strings.ToLower(filepath.Ext(filePath))
152
153	// Map file extensions to preferred LSP client types
154	preferredClients := map[string][]string{
155		".go":   {"gopls", "go"},
156		".ts":   {"typescript", "vtsls", "tsserver"},
157		".tsx":  {"typescript", "vtsls", "tsserver"},
158		".js":   {"typescript", "vtsls", "tsserver"},
159		".jsx":  {"typescript", "vtsls", "tsserver"},
160		".rs":   {"rust-analyzer", "rust"},
161		".py":   {"pyright", "pylsp", "python"},
162		".java": {"jdtls", "java"},
163		".c":    {"clangd", "ccls", "c"},
164		".cpp":  {"clangd", "ccls", "cpp", "c++"},
165		".cc":   {"clangd", "ccls", "cpp", "c++"},
166		".cxx":  {"clangd", "ccls", "cpp", "c++"},
167		".h":    {"clangd", "ccls", "c", "cpp"},
168		".hpp":  {"clangd", "ccls", "cpp", "c++"},
169		".cs":   {"omnisharp", "csharp"},
170		".php":  {"intelephense", "php"},
171		".rb":   {"solargraph", "ruby"},
172		".lua":  {"lua-language-server", "lua"},
173		".sh":   {"bash-language-server", "bash"},
174		".bash": {"bash-language-server", "bash"},
175		".zsh":  {"bash-language-server", "bash"},
176	}
177
178	// First, try to find a client that matches the preferred types for this file extension
179	if preferred, exists := preferredClients[ext]; exists {
180		for _, clientType := range preferred {
181			for name, client := range t.lspClients {
182				if strings.Contains(strings.ToLower(name), clientType) && client.GetServerState() == lsp.StateReady {
183					return client, name, nil
184				}
185			}
186		}
187	}
188
189	// If no preferred client found, try any available ready client
190	// This is a fallback for generic LSP servers that might support multiple languages
191	for name, client := range t.lspClients {
192		if client.GetServerState() == lsp.StateReady {
193			return client, name, nil
194		}
195	}
196
197	return nil, "", fmt.Errorf("no suitable LSP client found for file extension %s", ext)
198}
199
200// getAvailableClients returns a string listing all available LSP clients
201func (t *DefinitionsTool) getAvailableClients() string {
202	if len(t.lspClients) == 0 {
203		return "none"
204	}
205
206	var clients []string
207	for name := range t.lspClients {
208		clients = append(clients, name)
209	}
210	return strings.Join(clients, ", ")
211}
212
213// parseSymbolResult parses the LSP symbol result into our format
214func (t *DefinitionsTool) parseSymbolResult(result protocol.Or_Result_textDocument_documentSymbol) ([]DefinitionResult, error) {
215	var definitions []DefinitionResult
216
217	// The result can be either []DocumentSymbol or []SymbolInformation
218	// Try to unmarshal as DocumentSymbol first (newer format with hierarchy)
219	if result.Value != nil {
220		// Convert interface{} to []byte for unmarshaling
221		resultBytes, err := json.Marshal(result.Value)
222		if err != nil {
223			return definitions, fmt.Errorf("failed to marshal result: %v", err)
224		}
225
226		// Try DocumentSymbol format first (hierarchical)
227		var docSymbols []protocol.DocumentSymbol
228		if err := json.Unmarshal(resultBytes, &docSymbols); err == nil && len(docSymbols) > 0 {
229			for _, symbol := range docSymbols {
230				definitions = append(definitions, t.convertDocumentSymbol(symbol))
231			}
232			return definitions, nil
233		}
234
235		// Try SymbolInformation format (flat list)
236		var symbolInfos []protocol.SymbolInformation
237		if err := json.Unmarshal(resultBytes, &symbolInfos); err == nil && len(symbolInfos) > 0 {
238			for _, symbol := range symbolInfos {
239				definitions = append(definitions, t.convertSymbolInformation(symbol))
240			}
241			return definitions, nil
242		}
243
244		// If both fail, try to handle as a single symbol (some servers return single objects)
245		var singleDocSymbol protocol.DocumentSymbol
246		if err := json.Unmarshal(resultBytes, &singleDocSymbol); err == nil && singleDocSymbol.Name != "" {
247			definitions = append(definitions, t.convertDocumentSymbol(singleDocSymbol))
248			return definitions, nil
249		}
250
251		var singleSymbolInfo protocol.SymbolInformation
252		if err := json.Unmarshal(resultBytes, &singleSymbolInfo); err == nil && singleSymbolInfo.Name != "" {
253			definitions = append(definitions, t.convertSymbolInformation(singleSymbolInfo))
254			return definitions, nil
255		}
256	}
257
258	return definitions, nil
259}
260
261// convertDocumentSymbol converts a DocumentSymbol to our format
262func (t *DefinitionsTool) convertDocumentSymbol(symbol protocol.DocumentSymbol) DefinitionResult {
263	result := DefinitionResult{
264		Name:   symbol.Name,
265		Kind:   t.symbolKindToString(symbol.Kind),
266		Detail: symbol.Detail,
267		Range:  t.formatRange(symbol.Range),
268	}
269
270	// Convert children recursively
271	for _, child := range symbol.Children {
272		result.Children = append(result.Children, t.convertDocumentSymbol(child))
273	}
274
275	return result
276}
277
278// convertSymbolInformation converts a SymbolInformation to our format
279func (t *DefinitionsTool) convertSymbolInformation(symbol protocol.SymbolInformation) DefinitionResult {
280	return DefinitionResult{
281		Name:  symbol.Name,
282		Kind:  t.symbolKindToString(symbol.Kind),
283		Range: t.formatLocation(symbol.Location),
284	}
285}
286
287// symbolKindToString converts SymbolKind to a readable string
288func (t *DefinitionsTool) symbolKindToString(kind protocol.SymbolKind) string {
289	switch kind {
290	case protocol.File:
291		return "File"
292	case protocol.Module:
293		return "Module"
294	case protocol.Namespace:
295		return "Namespace"
296	case protocol.Package:
297		return "Package"
298	case protocol.Class:
299		return "Class"
300	case protocol.Method:
301		return "Method"
302	case protocol.Property:
303		return "Property"
304	case protocol.Field:
305		return "Field"
306	case protocol.Constructor:
307		return "Constructor"
308	case protocol.Enum:
309		return "Enum"
310	case protocol.Interface:
311		return "Interface"
312	case protocol.Function:
313		return "Function"
314	case protocol.Variable:
315		return "Variable"
316	case protocol.Constant:
317		return "Constant"
318	case protocol.String:
319		return "String"
320	case protocol.Number:
321		return "Number"
322	case protocol.Boolean:
323		return "Boolean"
324	case protocol.Array:
325		return "Array"
326	case protocol.Object:
327		return "Object"
328	case protocol.Key:
329		return "Key"
330	case protocol.Null:
331		return "Null"
332	case protocol.EnumMember:
333		return "EnumMember"
334	case protocol.Struct:
335		return "Struct"
336	case protocol.Event:
337		return "Event"
338	case protocol.Operator:
339		return "Operator"
340	case protocol.TypeParameter:
341		return "TypeParameter"
342	default:
343		return fmt.Sprintf("Unknown(%d)", kind)
344	}
345}
346
347// formatRange formats a Range to a readable string
348func (t *DefinitionsTool) formatRange(r protocol.Range) string {
349	startLine := r.Start.Line + 1
350	endLine := r.End.Line + 1
351	
352	if startLine == endLine {
353		return fmt.Sprintf("line %d", startLine)
354	}
355	return fmt.Sprintf("lines %d-%d", startLine, endLine)
356}
357
358// formatLocation formats a Location to a readable string
359func (t *DefinitionsTool) formatLocation(loc protocol.Location) string {
360	startLine := loc.Range.Start.Line + 1
361	endLine := loc.Range.End.Line + 1
362	
363	if startLine == endLine {
364		return fmt.Sprintf("line %d", startLine)
365	}
366	return fmt.Sprintf("lines %d-%d", startLine, endLine)
367}
368
369// formatDefinitions formats the definitions into a readable output
370func (t *DefinitionsTool) formatDefinitions(definitions []DefinitionResult, filePath, clientName string) string {
371	var output strings.Builder
372
373	output.WriteString(fmt.Sprintf("# Definitions in %s\n", filePath))
374	output.WriteString(fmt.Sprintf("*Using %s LSP server*\n\n", clientName))
375
376	if len(definitions) == 0 {
377		output.WriteString("No definitions found in this file.\n")
378		return output.String()
379	}
380
381	// Group definitions by kind for better organization
382	kindGroups := make(map[string][]DefinitionResult)
383	for _, def := range definitions {
384		kindGroups[def.Kind] = append(kindGroups[def.Kind], def)
385	}
386
387	// Define order of kinds for consistent output
388	kindOrder := []string{"Class", "Interface", "Struct", "Enum", "Function", "Method", "Variable", "Constant", "Property", "Field"}
389
390	// Write definitions grouped by kind
391	for _, kind := range kindOrder {
392		if defs, exists := kindGroups[kind]; exists {
393			output.WriteString(fmt.Sprintf("## %ss\n", kind))
394			for _, def := range defs {
395				t.writeDefinition(&output, def, 0)
396			}
397			output.WriteString("\n")
398			delete(kindGroups, kind)
399		}
400	}
401
402	// Write any remaining kinds not in the predefined order
403	for kind, defs := range kindGroups {
404		output.WriteString(fmt.Sprintf("## %ss\n", kind))
405		for _, def := range defs {
406			t.writeDefinition(&output, def, 0)
407		}
408		output.WriteString("\n")
409	}
410
411	return output.String()
412}
413
414// writeDefinition recursively writes a definition and its children
415func (t *DefinitionsTool) writeDefinition(output *strings.Builder, def DefinitionResult, indent int) {
416	indentStr := strings.Repeat("  ", indent)
417
418	// Format the main definition line
419	output.WriteString(fmt.Sprintf("%s- **%s** (%s)", indentStr, def.Name, def.Range))
420
421	if def.Detail != "" {
422		output.WriteString(fmt.Sprintf(" - %s", def.Detail))
423	}
424	output.WriteString("\n")
425
426	// Write children with increased indentation
427	if len(def.Children) > 0 {
428		for _, child := range def.Children {
429			t.writeDefinition(output, child, indent+1)
430		}
431	}
432}
433