@@ -324,6 +324,7 @@ func loadConfig(cwd string, debug bool) (*Config, error) {
"ls",
"sourcegraph",
"view",
+ "definitions",
},
// NO MCPs or LSPs by default
AllowedMCP: map[string][]string{},
@@ -1264,7 +1265,7 @@ func (c *Config) validateAgents(errors *ValidationErrors) {
}
validTools := []string{
- "bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent",
+ "bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent", "definitions", "diagnostics",
}
for agentID, agent := range c.Agents {
@@ -1342,7 +1343,7 @@ func (c *Config) validateAgents(errors *ValidationErrors) {
if agent.Description != "An agent that helps with searching for context and finding implementation details." {
errors.Add(fieldPrefix+".description", "task agent description cannot be changed")
}
- expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view"}
+ expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view", "definitions"}
if agent.AllowedTools != nil && !slices.Equal(agent.AllowedTools, expectedTools) {
errors.Add(fieldPrefix+".allowed_tools", "task agent allowed tools cannot be changed")
}
@@ -0,0 +1,433 @@
+package tools
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/charmbracelet/crush/internal/lsp"
+ "github.com/charmbracelet/crush/internal/lsp/protocol"
+)
+
+type DefinitionsTool struct {
+ lspClients map[string]*lsp.Client
+}
+
+const (
+ DefinitionsToolName = "definitions"
+)
+
+func NewDefinitionsTool(lspClients map[string]*lsp.Client) BaseTool {
+ return &DefinitionsTool{
+ lspClients: lspClients,
+ }
+}
+
+func (t *DefinitionsTool) Name() string {
+ return DefinitionsToolName
+}
+
+func (t *DefinitionsTool) Info() ToolInfo {
+ return ToolInfo{
+ Name: DefinitionsToolName,
+ Description: `Gets all symbol definitions from a file using the appropriate LSP server.
+
+WHEN TO USE THIS TOOL:
+- Use when you need to understand the structure and symbols in a code file
+- Helpful for exploring unfamiliar codebases and understanding what's defined in a file
+- Good for finding functions, classes, variables, types, interfaces, and other symbols
+- Useful before making changes to understand existing code structure
+
+HOW TO USE:
+- Provide the path to a source code file
+- The tool will automatically select the appropriate LSP server based on file extension
+- Results show symbol names, types, locations, and hierarchical relationships
+
+FEATURES:
+- Supports multiple programming languages (Go, TypeScript, JavaScript, Rust, Python, etc.)
+- Shows hierarchical symbol relationships (classes with their methods, etc.)
+- Provides precise line number locations for each symbol
+- Includes symbol details and documentation when available
+
+LIMITATIONS:
+- Requires an active LSP server for the file's language
+- Only works with files that the LSP server can parse
+- Results depend on LSP server capabilities and file syntax correctness
+
+TIPS:
+- Use this tool to get an overview of a file's structure before editing
+- Combine with other tools like View to see the actual code implementation
+- Helpful for understanding large files with many symbols`,
+ Parameters: map[string]any{
+ "file_path": map[string]any{
+ "type": "string",
+ "description": "The path to the file to get definitions from",
+ },
+ },
+ Required: []string{"file_path"},
+ }
+}
+
+type DefinitionsParams struct {
+ FilePath string `json:"file_path"`
+}
+
+type DefinitionResult struct {
+ Name string `json:"name"`
+ Kind string `json:"kind"`
+ Detail string `json:"detail,omitempty"`
+ Range string `json:"range"`
+ Children []DefinitionResult `json:"children,omitempty"`
+}
+
+func (t *DefinitionsTool) Run(ctx context.Context, params ToolCall) (ToolResponse, error) {
+ var definitionsParams DefinitionsParams
+ if err := json.Unmarshal([]byte(params.Input), &definitionsParams); err != nil {
+ return NewTextErrorResponse(fmt.Sprintf("Invalid parameters: %v", err)), nil
+ }
+
+ if definitionsParams.FilePath == "" {
+ return NewTextErrorResponse("file_path parameter is required"), nil
+ }
+
+ // Check if file exists
+ if _, err := os.Stat(definitionsParams.FilePath); err != nil {
+ return NewTextErrorResponse(fmt.Sprintf("File does not exist: %s", definitionsParams.FilePath)), nil
+ }
+
+ // Find the appropriate LSP client for this file
+ client, clientName, err := t.findLSPClient(definitionsParams.FilePath)
+ if err != nil {
+ return NewTextErrorResponse(fmt.Sprintf("No suitable LSP client found for file %s: %v\n\nAvailable LSP clients: %s",
+ definitionsParams.FilePath, err, t.getAvailableClients())), nil
+ }
+
+ // Check if the server is ready
+ if client.GetServerState() != lsp.StateReady {
+ return NewTextErrorResponse(fmt.Sprintf("LSP server %s is not ready (state: %v). Please wait for the server to initialize or check the server configuration.",
+ clientName, client.GetServerState())), nil
+ }
+
+ // Ensure the file is open in the LSP
+ if err := client.OpenFileOnDemand(ctx, definitionsParams.FilePath); err != nil {
+ return NewTextErrorResponse(fmt.Sprintf("Failed to open file in LSP: %v", err)), nil
+ }
+
+ // Get document symbols
+ documentURI := protocol.URIFromPath(definitionsParams.FilePath)
+ symbolParams := protocol.DocumentSymbolParams{
+ TextDocument: protocol.TextDocumentIdentifier{
+ URI: documentURI,
+ },
+ }
+
+ result, err := client.DocumentSymbol(ctx, symbolParams)
+ if err != nil {
+ return NewTextErrorResponse(fmt.Sprintf("Failed to get document symbols from %s LSP: %v\n\nThis might happen if:\n- The file has syntax errors\n- The LSP server doesn't support this file type\n- The file is not part of a recognized project structure",
+ clientName, err)), nil
+ }
+
+ // Parse the result and format it
+ definitions, err := t.parseSymbolResult(result)
+ if err != nil {
+ return NewTextErrorResponse(fmt.Sprintf("Failed to parse symbols: %v", err)), nil
+ }
+
+ if len(definitions) == 0 {
+ return NewTextResponse(fmt.Sprintf("No definitions found in file %s.\n\nThis might happen if:\n- The file is empty or contains only comments\n- The file has syntax errors preventing symbol extraction\n- The LSP server doesn't recognize symbols in this file type",
+ definitionsParams.FilePath)), nil
+ }
+
+ // Format the output
+ output := t.formatDefinitions(definitions, definitionsParams.FilePath, clientName)
+ return NewTextResponse(output), nil
+}
+
+// findLSPClient finds the most appropriate LSP client for the given file
+func (t *DefinitionsTool) findLSPClient(filePath string) (*lsp.Client, string, error) {
+ ext := strings.ToLower(filepath.Ext(filePath))
+
+ // Map file extensions to preferred LSP client types
+ preferredClients := map[string][]string{
+ ".go": {"gopls", "go"},
+ ".ts": {"typescript", "vtsls", "tsserver"},
+ ".tsx": {"typescript", "vtsls", "tsserver"},
+ ".js": {"typescript", "vtsls", "tsserver"},
+ ".jsx": {"typescript", "vtsls", "tsserver"},
+ ".rs": {"rust-analyzer", "rust"},
+ ".py": {"pyright", "pylsp", "python"},
+ ".java": {"jdtls", "java"},
+ ".c": {"clangd", "ccls", "c"},
+ ".cpp": {"clangd", "ccls", "cpp", "c++"},
+ ".cc": {"clangd", "ccls", "cpp", "c++"},
+ ".cxx": {"clangd", "ccls", "cpp", "c++"},
+ ".h": {"clangd", "ccls", "c", "cpp"},
+ ".hpp": {"clangd", "ccls", "cpp", "c++"},
+ ".cs": {"omnisharp", "csharp"},
+ ".php": {"intelephense", "php"},
+ ".rb": {"solargraph", "ruby"},
+ ".lua": {"lua-language-server", "lua"},
+ ".sh": {"bash-language-server", "bash"},
+ ".bash": {"bash-language-server", "bash"},
+ ".zsh": {"bash-language-server", "bash"},
+ }
+
+ // First, try to find a client that matches the preferred types for this file extension
+ if preferred, exists := preferredClients[ext]; exists {
+ for _, clientType := range preferred {
+ for name, client := range t.lspClients {
+ if strings.Contains(strings.ToLower(name), clientType) && client.GetServerState() == lsp.StateReady {
+ return client, name, nil
+ }
+ }
+ }
+ }
+
+ // If no preferred client found, try any available ready client
+ // This is a fallback for generic LSP servers that might support multiple languages
+ for name, client := range t.lspClients {
+ if client.GetServerState() == lsp.StateReady {
+ return client, name, nil
+ }
+ }
+
+ return nil, "", fmt.Errorf("no suitable LSP client found for file extension %s", ext)
+}
+
+// getAvailableClients returns a string listing all available LSP clients
+func (t *DefinitionsTool) getAvailableClients() string {
+ if len(t.lspClients) == 0 {
+ return "none"
+ }
+
+ var clients []string
+ for name := range t.lspClients {
+ clients = append(clients, name)
+ }
+ return strings.Join(clients, ", ")
+}
+
+// parseSymbolResult parses the LSP symbol result into our format
+func (t *DefinitionsTool) parseSymbolResult(result protocol.Or_Result_textDocument_documentSymbol) ([]DefinitionResult, error) {
+ var definitions []DefinitionResult
+
+ // The result can be either []DocumentSymbol or []SymbolInformation
+ // Try to unmarshal as DocumentSymbol first (newer format with hierarchy)
+ if result.Value != nil {
+ // Convert interface{} to []byte for unmarshaling
+ resultBytes, err := json.Marshal(result.Value)
+ if err != nil {
+ return definitions, fmt.Errorf("failed to marshal result: %v", err)
+ }
+
+ // Try DocumentSymbol format first (hierarchical)
+ var docSymbols []protocol.DocumentSymbol
+ if err := json.Unmarshal(resultBytes, &docSymbols); err == nil && len(docSymbols) > 0 {
+ for _, symbol := range docSymbols {
+ definitions = append(definitions, t.convertDocumentSymbol(symbol))
+ }
+ return definitions, nil
+ }
+
+ // Try SymbolInformation format (flat list)
+ var symbolInfos []protocol.SymbolInformation
+ if err := json.Unmarshal(resultBytes, &symbolInfos); err == nil && len(symbolInfos) > 0 {
+ for _, symbol := range symbolInfos {
+ definitions = append(definitions, t.convertSymbolInformation(symbol))
+ }
+ return definitions, nil
+ }
+
+ // If both fail, try to handle as a single symbol (some servers return single objects)
+ var singleDocSymbol protocol.DocumentSymbol
+ if err := json.Unmarshal(resultBytes, &singleDocSymbol); err == nil && singleDocSymbol.Name != "" {
+ definitions = append(definitions, t.convertDocumentSymbol(singleDocSymbol))
+ return definitions, nil
+ }
+
+ var singleSymbolInfo protocol.SymbolInformation
+ if err := json.Unmarshal(resultBytes, &singleSymbolInfo); err == nil && singleSymbolInfo.Name != "" {
+ definitions = append(definitions, t.convertSymbolInformation(singleSymbolInfo))
+ return definitions, nil
+ }
+ }
+
+ return definitions, nil
+}
+
+// convertDocumentSymbol converts a DocumentSymbol to our format
+func (t *DefinitionsTool) convertDocumentSymbol(symbol protocol.DocumentSymbol) DefinitionResult {
+ result := DefinitionResult{
+ Name: symbol.Name,
+ Kind: t.symbolKindToString(symbol.Kind),
+ Detail: symbol.Detail,
+ Range: t.formatRange(symbol.Range),
+ }
+
+ // Convert children recursively
+ for _, child := range symbol.Children {
+ result.Children = append(result.Children, t.convertDocumentSymbol(child))
+ }
+
+ return result
+}
+
+// convertSymbolInformation converts a SymbolInformation to our format
+func (t *DefinitionsTool) convertSymbolInformation(symbol protocol.SymbolInformation) DefinitionResult {
+ return DefinitionResult{
+ Name: symbol.Name,
+ Kind: t.symbolKindToString(symbol.Kind),
+ Range: t.formatLocation(symbol.Location),
+ }
+}
+
+// symbolKindToString converts SymbolKind to a readable string
+func (t *DefinitionsTool) symbolKindToString(kind protocol.SymbolKind) string {
+ switch kind {
+ case protocol.File:
+ return "File"
+ case protocol.Module:
+ return "Module"
+ case protocol.Namespace:
+ return "Namespace"
+ case protocol.Package:
+ return "Package"
+ case protocol.Class:
+ return "Class"
+ case protocol.Method:
+ return "Method"
+ case protocol.Property:
+ return "Property"
+ case protocol.Field:
+ return "Field"
+ case protocol.Constructor:
+ return "Constructor"
+ case protocol.Enum:
+ return "Enum"
+ case protocol.Interface:
+ return "Interface"
+ case protocol.Function:
+ return "Function"
+ case protocol.Variable:
+ return "Variable"
+ case protocol.Constant:
+ return "Constant"
+ case protocol.String:
+ return "String"
+ case protocol.Number:
+ return "Number"
+ case protocol.Boolean:
+ return "Boolean"
+ case protocol.Array:
+ return "Array"
+ case protocol.Object:
+ return "Object"
+ case protocol.Key:
+ return "Key"
+ case protocol.Null:
+ return "Null"
+ case protocol.EnumMember:
+ return "EnumMember"
+ case protocol.Struct:
+ return "Struct"
+ case protocol.Event:
+ return "Event"
+ case protocol.Operator:
+ return "Operator"
+ case protocol.TypeParameter:
+ return "TypeParameter"
+ default:
+ return fmt.Sprintf("Unknown(%d)", kind)
+ }
+}
+
+// formatRange formats a Range to a readable string
+func (t *DefinitionsTool) formatRange(r protocol.Range) string {
+ startLine := r.Start.Line + 1
+ endLine := r.End.Line + 1
+
+ if startLine == endLine {
+ return fmt.Sprintf("line %d", startLine)
+ }
+ return fmt.Sprintf("lines %d-%d", startLine, endLine)
+}
+
+// formatLocation formats a Location to a readable string
+func (t *DefinitionsTool) formatLocation(loc protocol.Location) string {
+ startLine := loc.Range.Start.Line + 1
+ endLine := loc.Range.End.Line + 1
+
+ if startLine == endLine {
+ return fmt.Sprintf("line %d", startLine)
+ }
+ return fmt.Sprintf("lines %d-%d", startLine, endLine)
+}
+
+// formatDefinitions formats the definitions into a readable output
+func (t *DefinitionsTool) formatDefinitions(definitions []DefinitionResult, filePath, clientName string) string {
+ var output strings.Builder
+
+ output.WriteString(fmt.Sprintf("# Definitions in %s\n", filePath))
+ output.WriteString(fmt.Sprintf("*Using %s LSP server*\n\n", clientName))
+
+ if len(definitions) == 0 {
+ output.WriteString("No definitions found in this file.\n")
+ return output.String()
+ }
+
+ // Group definitions by kind for better organization
+ kindGroups := make(map[string][]DefinitionResult)
+ for _, def := range definitions {
+ kindGroups[def.Kind] = append(kindGroups[def.Kind], def)
+ }
+
+ // Define order of kinds for consistent output
+ kindOrder := []string{"Class", "Interface", "Struct", "Enum", "Function", "Method", "Variable", "Constant", "Property", "Field"}
+
+ // Write definitions grouped by kind
+ for _, kind := range kindOrder {
+ if defs, exists := kindGroups[kind]; exists {
+ output.WriteString(fmt.Sprintf("## %ss\n", kind))
+ for _, def := range defs {
+ t.writeDefinition(&output, def, 0)
+ }
+ output.WriteString("\n")
+ delete(kindGroups, kind)
+ }
+ }
+
+ // Write any remaining kinds not in the predefined order
+ for kind, defs := range kindGroups {
+ output.WriteString(fmt.Sprintf("## %ss\n", kind))
+ for _, def := range defs {
+ t.writeDefinition(&output, def, 0)
+ }
+ output.WriteString("\n")
+ }
+
+ return output.String()
+}
+
+// writeDefinition recursively writes a definition and its children
+func (t *DefinitionsTool) writeDefinition(output *strings.Builder, def DefinitionResult, indent int) {
+ indentStr := strings.Repeat(" ", indent)
+
+ // Format the main definition line
+ output.WriteString(fmt.Sprintf("%s- **%s** (%s)", indentStr, def.Name, def.Range))
+
+ if def.Detail != "" {
+ output.WriteString(fmt.Sprintf(" - %s", def.Detail))
+ }
+ output.WriteString("\n")
+
+ // Write children with increased indentation
+ if len(def.Children) > 0 {
+ for _, child := range def.Children {
+ t.writeDefinition(output, child, indent+1)
+ }
+ }
+}
+