1package tools
2
import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"

	"github.com/charmbracelet/crush/internal/lsp"
	"github.com/charmbracelet/crush/internal/lsp/protocol"
)
14
15type DefinitionsTool struct {
16 lspClients map[string]*lsp.Client
17}
18
// DefinitionsToolName is the name under which the definitions tool is exposed.
const DefinitionsToolName = "definitions"
22
23func NewDefinitionsTool(lspClients map[string]*lsp.Client) BaseTool {
24 return &DefinitionsTool{
25 lspClients: lspClients,
26 }
27}
28
29func (t *DefinitionsTool) Name() string {
30 return DefinitionsToolName
31}
32
33func (t *DefinitionsTool) Info() ToolInfo {
34 return ToolInfo{
35 Name: DefinitionsToolName,
36 Description: `Gets all symbol definitions from a file using the appropriate LSP server.
37
38WHEN TO USE THIS TOOL:
39- Use when you need to understand the structure and symbols in a code file
40- Helpful for exploring unfamiliar codebases and understanding what's defined in a file
41- Good for finding functions, classes, variables, types, interfaces, and other symbols
42- Useful before making changes to understand existing code structure
43
44HOW TO USE:
45- Provide the path to a source code file
46- The tool will automatically select the appropriate LSP server based on file extension
47- Results show symbol names, types, locations, and hierarchical relationships
48
49FEATURES:
50- Supports multiple programming languages (Go, TypeScript, JavaScript, Rust, Python, etc.)
51- Shows hierarchical symbol relationships (classes with their methods, etc.)
52- Provides precise line number locations for each symbol
53- Includes symbol details and documentation when available
54
55LIMITATIONS:
56- Requires an active LSP server for the file's language
57- Only works with files that the LSP server can parse
58- Results depend on LSP server capabilities and file syntax correctness
59
60TIPS:
61- Use this tool to get an overview of a file's structure before editing
62- Combine with other tools like View to see the actual code implementation
63- Helpful for understanding large files with many symbols`,
64 Parameters: map[string]any{
65 "file_path": map[string]any{
66 "type": "string",
67 "description": "The path to the file to get definitions from",
68 },
69 },
70 Required: []string{"file_path"},
71 }
72}
73
// DefinitionsParams is the JSON input accepted by DefinitionsTool.Run.
type DefinitionsParams struct {
	// FilePath names the file whose symbols are listed (JSON key "file_path").
	FilePath string `json:"file_path"`
}
77
// DefinitionResult is one symbol as presented by the definitions tool.
type DefinitionResult struct {
	// Name is the symbol's name as reported by the LSP server.
	Name string `json:"name"`
	// Kind is a readable symbol kind such as "Function" or "Struct".
	Kind string `json:"kind"`
	// Detail is the server's optional extra text for the symbol.
	Detail string `json:"detail,omitempty"`
	// Range is a 1-based line description, e.g. "line 4" or "lines 4-9".
	Range string `json:"range"`
	// Children holds nested symbols (e.g. methods of a class), if any.
	Children []DefinitionResult `json:"children,omitempty"`
}
85
86func (t *DefinitionsTool) Run(ctx context.Context, params ToolCall) (ToolResponse, error) {
87 var definitionsParams DefinitionsParams
88 if err := json.Unmarshal([]byte(params.Input), &definitionsParams); err != nil {
89 return NewTextErrorResponse(fmt.Sprintf("Invalid parameters: %v", err)), nil
90 }
91
92 if definitionsParams.FilePath == "" {
93 return NewTextErrorResponse("file_path parameter is required"), nil
94 }
95
96 // Check if file exists
97 if _, err := os.Stat(definitionsParams.FilePath); err != nil {
98 return NewTextErrorResponse(fmt.Sprintf("File does not exist: %s", definitionsParams.FilePath)), nil
99 }
100
101 // Find the appropriate LSP client for this file
102 client, clientName, err := t.findLSPClient(definitionsParams.FilePath)
103 if err != nil {
104 return NewTextErrorResponse(fmt.Sprintf("No suitable LSP client found for file %s: %v\n\nAvailable LSP clients: %s",
105 definitionsParams.FilePath, err, t.getAvailableClients())), nil
106 }
107
108 // Check if the server is ready
109 if client.GetServerState() != lsp.StateReady {
110 return NewTextErrorResponse(fmt.Sprintf("LSP server %s is not ready (state: %v). Please wait for the server to initialize or check the server configuration.",
111 clientName, client.GetServerState())), nil
112 }
113
114 // Ensure the file is open in the LSP
115 if err := client.OpenFileOnDemand(ctx, definitionsParams.FilePath); err != nil {
116 return NewTextErrorResponse(fmt.Sprintf("Failed to open file in LSP: %v", err)), nil
117 }
118
119 // Get document symbols
120 documentURI := protocol.URIFromPath(definitionsParams.FilePath)
121 symbolParams := protocol.DocumentSymbolParams{
122 TextDocument: protocol.TextDocumentIdentifier{
123 URI: documentURI,
124 },
125 }
126
127 result, err := client.DocumentSymbol(ctx, symbolParams)
128 if err != nil {
129 return NewTextErrorResponse(fmt.Sprintf("Failed to get document symbols from %s LSP: %v\n\nThis might happen if:\n- The file has syntax errors\n- The LSP server doesn't support this file type\n- The file is not part of a recognized project structure",
130 clientName, err)), nil
131 }
132
133 // Parse the result and format it
134 definitions, err := t.parseSymbolResult(result)
135 if err != nil {
136 return NewTextErrorResponse(fmt.Sprintf("Failed to parse symbols: %v", err)), nil
137 }
138
139 if len(definitions) == 0 {
140 return NewTextResponse(fmt.Sprintf("No definitions found in file %s.\n\nThis might happen if:\n- The file is empty or contains only comments\n- The file has syntax errors preventing symbol extraction\n- The LSP server doesn't recognize symbols in this file type",
141 definitionsParams.FilePath)), nil
142 }
143
144 // Format the output
145 output := t.formatDefinitions(definitions, definitionsParams.FilePath, clientName)
146 return NewTextResponse(output), nil
147}
148
149// findLSPClient finds the most appropriate LSP client for the given file
150func (t *DefinitionsTool) findLSPClient(filePath string) (*lsp.Client, string, error) {
151 ext := strings.ToLower(filepath.Ext(filePath))
152
153 // Map file extensions to preferred LSP client types
154 preferredClients := map[string][]string{
155 ".go": {"gopls", "go"},
156 ".ts": {"typescript", "vtsls", "tsserver"},
157 ".tsx": {"typescript", "vtsls", "tsserver"},
158 ".js": {"typescript", "vtsls", "tsserver"},
159 ".jsx": {"typescript", "vtsls", "tsserver"},
160 ".rs": {"rust-analyzer", "rust"},
161 ".py": {"pyright", "pylsp", "python"},
162 ".java": {"jdtls", "java"},
163 ".c": {"clangd", "ccls", "c"},
164 ".cpp": {"clangd", "ccls", "cpp", "c++"},
165 ".cc": {"clangd", "ccls", "cpp", "c++"},
166 ".cxx": {"clangd", "ccls", "cpp", "c++"},
167 ".h": {"clangd", "ccls", "c", "cpp"},
168 ".hpp": {"clangd", "ccls", "cpp", "c++"},
169 ".cs": {"omnisharp", "csharp"},
170 ".php": {"intelephense", "php"},
171 ".rb": {"solargraph", "ruby"},
172 ".lua": {"lua-language-server", "lua"},
173 ".sh": {"bash-language-server", "bash"},
174 ".bash": {"bash-language-server", "bash"},
175 ".zsh": {"bash-language-server", "bash"},
176 }
177
178 // First, try to find a client that matches the preferred types for this file extension
179 if preferred, exists := preferredClients[ext]; exists {
180 for _, clientType := range preferred {
181 for name, client := range t.lspClients {
182 if strings.Contains(strings.ToLower(name), clientType) && client.GetServerState() == lsp.StateReady {
183 return client, name, nil
184 }
185 }
186 }
187 }
188
189 // If no preferred client found, try any available ready client
190 // This is a fallback for generic LSP servers that might support multiple languages
191 for name, client := range t.lspClients {
192 if client.GetServerState() == lsp.StateReady {
193 return client, name, nil
194 }
195 }
196
197 return nil, "", fmt.Errorf("no suitable LSP client found for file extension %s", ext)
198}
199
200// getAvailableClients returns a string listing all available LSP clients
201func (t *DefinitionsTool) getAvailableClients() string {
202 if len(t.lspClients) == 0 {
203 return "none"
204 }
205
206 var clients []string
207 for name := range t.lspClients {
208 clients = append(clients, name)
209 }
210 return strings.Join(clients, ", ")
211}
212
213// parseSymbolResult parses the LSP symbol result into our format
214func (t *DefinitionsTool) parseSymbolResult(result protocol.Or_Result_textDocument_documentSymbol) ([]DefinitionResult, error) {
215 var definitions []DefinitionResult
216
217 // The result can be either []DocumentSymbol or []SymbolInformation
218 // Try to unmarshal as DocumentSymbol first (newer format with hierarchy)
219 if result.Value != nil {
220 // Convert interface{} to []byte for unmarshaling
221 resultBytes, err := json.Marshal(result.Value)
222 if err != nil {
223 return definitions, fmt.Errorf("failed to marshal result: %v", err)
224 }
225
226 // Try DocumentSymbol format first (hierarchical)
227 var docSymbols []protocol.DocumentSymbol
228 if err := json.Unmarshal(resultBytes, &docSymbols); err == nil && len(docSymbols) > 0 {
229 for _, symbol := range docSymbols {
230 definitions = append(definitions, t.convertDocumentSymbol(symbol))
231 }
232 return definitions, nil
233 }
234
235 // Try SymbolInformation format (flat list)
236 var symbolInfos []protocol.SymbolInformation
237 if err := json.Unmarshal(resultBytes, &symbolInfos); err == nil && len(symbolInfos) > 0 {
238 for _, symbol := range symbolInfos {
239 definitions = append(definitions, t.convertSymbolInformation(symbol))
240 }
241 return definitions, nil
242 }
243
244 // If both fail, try to handle as a single symbol (some servers return single objects)
245 var singleDocSymbol protocol.DocumentSymbol
246 if err := json.Unmarshal(resultBytes, &singleDocSymbol); err == nil && singleDocSymbol.Name != "" {
247 definitions = append(definitions, t.convertDocumentSymbol(singleDocSymbol))
248 return definitions, nil
249 }
250
251 var singleSymbolInfo protocol.SymbolInformation
252 if err := json.Unmarshal(resultBytes, &singleSymbolInfo); err == nil && singleSymbolInfo.Name != "" {
253 definitions = append(definitions, t.convertSymbolInformation(singleSymbolInfo))
254 return definitions, nil
255 }
256 }
257
258 return definitions, nil
259}
260
261// convertDocumentSymbol converts a DocumentSymbol to our format
262func (t *DefinitionsTool) convertDocumentSymbol(symbol protocol.DocumentSymbol) DefinitionResult {
263 result := DefinitionResult{
264 Name: symbol.Name,
265 Kind: t.symbolKindToString(symbol.Kind),
266 Detail: symbol.Detail,
267 Range: t.formatRange(symbol.Range),
268 }
269
270 // Convert children recursively
271 for _, child := range symbol.Children {
272 result.Children = append(result.Children, t.convertDocumentSymbol(child))
273 }
274
275 return result
276}
277
278// convertSymbolInformation converts a SymbolInformation to our format
279func (t *DefinitionsTool) convertSymbolInformation(symbol protocol.SymbolInformation) DefinitionResult {
280 return DefinitionResult{
281 Name: symbol.Name,
282 Kind: t.symbolKindToString(symbol.Kind),
283 Range: t.formatLocation(symbol.Location),
284 }
285}
286
287// symbolKindToString converts SymbolKind to a readable string
288func (t *DefinitionsTool) symbolKindToString(kind protocol.SymbolKind) string {
289 switch kind {
290 case protocol.File:
291 return "File"
292 case protocol.Module:
293 return "Module"
294 case protocol.Namespace:
295 return "Namespace"
296 case protocol.Package:
297 return "Package"
298 case protocol.Class:
299 return "Class"
300 case protocol.Method:
301 return "Method"
302 case protocol.Property:
303 return "Property"
304 case protocol.Field:
305 return "Field"
306 case protocol.Constructor:
307 return "Constructor"
308 case protocol.Enum:
309 return "Enum"
310 case protocol.Interface:
311 return "Interface"
312 case protocol.Function:
313 return "Function"
314 case protocol.Variable:
315 return "Variable"
316 case protocol.Constant:
317 return "Constant"
318 case protocol.String:
319 return "String"
320 case protocol.Number:
321 return "Number"
322 case protocol.Boolean:
323 return "Boolean"
324 case protocol.Array:
325 return "Array"
326 case protocol.Object:
327 return "Object"
328 case protocol.Key:
329 return "Key"
330 case protocol.Null:
331 return "Null"
332 case protocol.EnumMember:
333 return "EnumMember"
334 case protocol.Struct:
335 return "Struct"
336 case protocol.Event:
337 return "Event"
338 case protocol.Operator:
339 return "Operator"
340 case protocol.TypeParameter:
341 return "TypeParameter"
342 default:
343 return fmt.Sprintf("Unknown(%d)", kind)
344 }
345}
346
347// formatRange formats a Range to a readable string
348func (t *DefinitionsTool) formatRange(r protocol.Range) string {
349 startLine := r.Start.Line + 1
350 endLine := r.End.Line + 1
351
352 if startLine == endLine {
353 return fmt.Sprintf("line %d", startLine)
354 }
355 return fmt.Sprintf("lines %d-%d", startLine, endLine)
356}
357
358// formatLocation formats a Location to a readable string
359func (t *DefinitionsTool) formatLocation(loc protocol.Location) string {
360 startLine := loc.Range.Start.Line + 1
361 endLine := loc.Range.End.Line + 1
362
363 if startLine == endLine {
364 return fmt.Sprintf("line %d", startLine)
365 }
366 return fmt.Sprintf("lines %d-%d", startLine, endLine)
367}
368
369// formatDefinitions formats the definitions into a readable output
370func (t *DefinitionsTool) formatDefinitions(definitions []DefinitionResult, filePath, clientName string) string {
371 var output strings.Builder
372
373 output.WriteString(fmt.Sprintf("# Definitions in %s\n", filePath))
374 output.WriteString(fmt.Sprintf("*Using %s LSP server*\n\n", clientName))
375
376 if len(definitions) == 0 {
377 output.WriteString("No definitions found in this file.\n")
378 return output.String()
379 }
380
381 // Group definitions by kind for better organization
382 kindGroups := make(map[string][]DefinitionResult)
383 for _, def := range definitions {
384 kindGroups[def.Kind] = append(kindGroups[def.Kind], def)
385 }
386
387 // Define order of kinds for consistent output
388 kindOrder := []string{"Class", "Interface", "Struct", "Enum", "Function", "Method", "Variable", "Constant", "Property", "Field"}
389
390 // Write definitions grouped by kind
391 for _, kind := range kindOrder {
392 if defs, exists := kindGroups[kind]; exists {
393 output.WriteString(fmt.Sprintf("## %ss\n", kind))
394 for _, def := range defs {
395 t.writeDefinition(&output, def, 0)
396 }
397 output.WriteString("\n")
398 delete(kindGroups, kind)
399 }
400 }
401
402 // Write any remaining kinds not in the predefined order
403 for kind, defs := range kindGroups {
404 output.WriteString(fmt.Sprintf("## %ss\n", kind))
405 for _, def := range defs {
406 t.writeDefinition(&output, def, 0)
407 }
408 output.WriteString("\n")
409 }
410
411 return output.String()
412}
413
414// writeDefinition recursively writes a definition and its children
415func (t *DefinitionsTool) writeDefinition(output *strings.Builder, def DefinitionResult, indent int) {
416 indentStr := strings.Repeat(" ", indent)
417
418 // Format the main definition line
419 output.WriteString(fmt.Sprintf("%s- **%s** (%s)", indentStr, def.Name, def.Range))
420
421 if def.Detail != "" {
422 output.WriteString(fmt.Sprintf(" - %s", def.Detail))
423 }
424 output.WriteString("\n")
425
426 // Write children with increased indentation
427 if len(def.Children) > 0 {
428 for _, child := range def.Children {
429 t.writeDefinition(output, child, indent+1)
430 }
431 }
432}
433