diff --git a/internal/config/config.go b/internal/config/config.go index b8a70505da30a1e3f274e95ef89f606d7b6db9d4..3c5fcfbd19d7518d9009197f29cc219b7955c4e2 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -324,6 +324,7 @@ func loadConfig(cwd string, debug bool) (*Config, error) { "ls", "sourcegraph", "view", + "definitions", }, // NO MCPs or LSPs by default AllowedMCP: map[string][]string{}, @@ -1264,7 +1265,7 @@ func (c *Config) validateAgents(errors *ValidationErrors) { } validTools := []string{ - "bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent", + "bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent", "definitions", "diagnostics", } for agentID, agent := range c.Agents { @@ -1342,7 +1343,7 @@ func (c *Config) validateAgents(errors *ValidationErrors) { if agent.Description != "An agent that helps with searching for context and finding implementation details." { errors.Add(fieldPrefix+".description", "task agent description cannot be changed") } - expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view"} + expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view", "definitions"} if agent.AllowedTools != nil && !slices.Equal(agent.AllowedTools, expectedTools) { errors.Add(fieldPrefix+".allowed_tools", "task agent allowed tools cannot be changed") } diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 5568c8a619287619900e8d5e5d5d44c2e85de446..eb9c4ebec9dbc8e59d57a3d733be8a0ea64918e4 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -94,8 +94,9 @@ func NewAgent( ctx := context.Background() cfg := config.Get() otherTools := GetMcpTools(ctx, permissions) - if len(lspClients) > 0 { + if len(cfg.LSP) > 0 { otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients)) + otherTools = append(otherTools, tools.NewDefinitionsTool(lspClients)) } allTools := []tools.BaseTool{ diff --git a/internal/llm/tools/definitions.go b/internal/llm/tools/definitions.go new file mode 100644 index 0000000000000000000000000000000000000000..46981800ee17aec26db5d83e49e1d119653c7d76 --- /dev/null +++ b/internal/llm/tools/definitions.go @@ -0,0 +1,433 @@ +package tools + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/charmbracelet/crush/internal/lsp" + "github.com/charmbracelet/crush/internal/lsp/protocol" +) + +type DefinitionsTool struct { + lspClients map[string]*lsp.Client +} + +const ( + DefinitionsToolName = "definitions" +) + +func NewDefinitionsTool(lspClients map[string]*lsp.Client) BaseTool { + return &DefinitionsTool{ + lspClients: lspClients, + } +} + +func (t *DefinitionsTool) Name() string { + return DefinitionsToolName +} + +func (t *DefinitionsTool) Info() ToolInfo { + return ToolInfo{ + Name: DefinitionsToolName, + Description: `Gets all symbol definitions from a file using the appropriate LSP server. + +WHEN TO USE THIS TOOL: +- Use when you need to understand the structure and symbols in a code file +- Helpful for exploring unfamiliar codebases and understanding what's defined in a file +- Good for finding functions, classes, variables, types, interfaces, and other symbols +- Useful before making changes to understand existing code structure + +HOW TO USE: +- Provide the path to a source code file +- The tool will automatically select the appropriate LSP server based on file extension +- Results show symbol names, types, locations, and hierarchical relationships + +FEATURES: +- Supports multiple programming languages (Go, TypeScript, JavaScript, Rust, Python, etc.) +- Shows hierarchical symbol relationships (classes with their methods, etc.) +- Provides precise line number locations for each symbol +- Includes symbol details and documentation when available + +LIMITATIONS: +- Requires an active LSP server for the file's language +- Only works with files that the LSP server can parse +- Results depend on LSP server capabilities and file syntax correctness + +TIPS: +- Use this tool to get an overview of a file's structure before editing +- Combine with other tools like View to see the actual code implementation +- Helpful for understanding large files with many symbols`, + Parameters: map[string]any{ + "file_path": map[string]any{ + "type": "string", + "description": "The path to the file to get definitions from", + }, + }, + Required: []string{"file_path"}, + } +} + +type DefinitionsParams struct { + FilePath string `json:"file_path"` +} + +type DefinitionResult struct { + Name string `json:"name"` + Kind string `json:"kind"` + Detail string `json:"detail,omitempty"` + Range string `json:"range"` + Children []DefinitionResult `json:"children,omitempty"` +} + +func (t *DefinitionsTool) Run(ctx context.Context, params ToolCall) (ToolResponse, error) { + var definitionsParams DefinitionsParams + if err := json.Unmarshal([]byte(params.Input), &definitionsParams); err != nil { + return NewTextErrorResponse(fmt.Sprintf("Invalid parameters: %v", err)), nil + } + + if definitionsParams.FilePath == "" { + return NewTextErrorResponse("file_path parameter is required"), nil + } + + // Check if file exists + if _, err := os.Stat(definitionsParams.FilePath); err != nil { + return NewTextErrorResponse(fmt.Sprintf("File does not exist: %s", definitionsParams.FilePath)), nil + } + + // Find the appropriate LSP client for this file + client, clientName, err := t.findLSPClient(definitionsParams.FilePath) + if err != nil { + return NewTextErrorResponse(fmt.Sprintf("No suitable LSP client found for file %s: %v\n\nAvailable LSP clients: %s", + definitionsParams.FilePath, err, t.getAvailableClients())), nil + } + + // Check if the server is ready + if client.GetServerState() != lsp.StateReady { + return NewTextErrorResponse(fmt.Sprintf("LSP server %s is not ready (state: %v). Please wait for the server to initialize or check the server configuration.", + clientName, client.GetServerState())), nil + } + + // Ensure the file is open in the LSP + if err := client.OpenFileOnDemand(ctx, definitionsParams.FilePath); err != nil { + return NewTextErrorResponse(fmt.Sprintf("Failed to open file in LSP: %v", err)), nil + } + + // Get document symbols + documentURI := protocol.URIFromPath(definitionsParams.FilePath) + symbolParams := protocol.DocumentSymbolParams{ + TextDocument: protocol.TextDocumentIdentifier{ + URI: documentURI, + }, + } + + result, err := client.DocumentSymbol(ctx, symbolParams) + if err != nil { + return NewTextErrorResponse(fmt.Sprintf("Failed to get document symbols from %s LSP: %v\n\nThis might happen if:\n- The file has syntax errors\n- The LSP server doesn't support this file type\n- The file is not part of a recognized project structure", + clientName, err)), nil + } + + // Parse the result and format it + definitions, err := t.parseSymbolResult(result) + if err != nil { + return NewTextErrorResponse(fmt.Sprintf("Failed to parse symbols: %v", err)), nil + } + + if len(definitions) == 0 { + return NewTextResponse(fmt.Sprintf("No definitions found in file %s.\n\nThis might happen if:\n- The file is empty or contains only comments\n- The file has syntax errors preventing symbol extraction\n- The LSP server doesn't recognize symbols in this file type", + definitionsParams.FilePath)), nil + } + + // Format the output + output := t.formatDefinitions(definitions, definitionsParams.FilePath, clientName) + return NewTextResponse(output), nil +} + +// findLSPClient finds the most appropriate LSP client for the given file +func (t *DefinitionsTool) findLSPClient(filePath string) (*lsp.Client, string, error) { + ext := strings.ToLower(filepath.Ext(filePath)) + + // Map file extensions to preferred LSP client types + preferredClients := map[string][]string{ + ".go": {"gopls", "go"}, + ".ts": {"typescript", "vtsls", "tsserver"}, + ".tsx": {"typescript", "vtsls", "tsserver"}, + ".js": {"typescript", "vtsls", "tsserver"}, + ".jsx": {"typescript", "vtsls", "tsserver"}, + ".rs": {"rust-analyzer", "rust"}, + ".py": {"pyright", "pylsp", "python"}, + ".java": {"jdtls", "java"}, + ".c": {"clangd", "ccls", "c"}, + ".cpp": {"clangd", "ccls", "cpp", "c++"}, + ".cc": {"clangd", "ccls", "cpp", "c++"}, + ".cxx": {"clangd", "ccls", "cpp", "c++"}, + ".h": {"clangd", "ccls", "c", "cpp"}, + ".hpp": {"clangd", "ccls", "cpp", "c++"}, + ".cs": {"omnisharp", "csharp"}, + ".php": {"intelephense", "php"}, + ".rb": {"solargraph", "ruby"}, + ".lua": {"lua-language-server", "lua"}, + ".sh": {"bash-language-server", "bash"}, + ".bash": {"bash-language-server", "bash"}, + ".zsh": {"bash-language-server", "bash"}, + } + + // First, try to find a client that matches the preferred types for this file extension + if preferred, exists := preferredClients[ext]; exists { + for _, clientType := range preferred { + for name, client := range t.lspClients { + if strings.Contains(strings.ToLower(name), clientType) && client.GetServerState() == lsp.StateReady { + return client, name, nil + } + } + } + } + + // If no preferred client found, try any available ready client + // This is a fallback for generic LSP servers that might support multiple languages + for name, client := range t.lspClients { + if client.GetServerState() == lsp.StateReady { + return client, name, nil + } + } + + return nil, "", fmt.Errorf("no suitable LSP client found for file extension %s", ext) +} + +// getAvailableClients returns a string listing all available LSP clients +func (t *DefinitionsTool) getAvailableClients() string { + if len(t.lspClients) == 0 { + return "none" + } + + var clients []string + for name := range t.lspClients { + clients = append(clients, name) + } + return strings.Join(clients, ", ") +} + +// parseSymbolResult parses the LSP symbol result into our format +func (t *DefinitionsTool) parseSymbolResult(result protocol.Or_Result_textDocument_documentSymbol) ([]DefinitionResult, error) { + var definitions []DefinitionResult + + // The result can be either []DocumentSymbol or []SymbolInformation + // Try to unmarshal as DocumentSymbol first (newer format with hierarchy) + if result.Value != nil { + // Convert interface{} to []byte for unmarshaling + resultBytes, err := json.Marshal(result.Value) + if err != nil { + return definitions, fmt.Errorf("failed to marshal result: %v", err) + } + + // Try DocumentSymbol format first (hierarchical) + var docSymbols []protocol.DocumentSymbol + if err := json.Unmarshal(resultBytes, &docSymbols); err == nil && len(docSymbols) > 0 { + for _, symbol := range docSymbols { + definitions = append(definitions, t.convertDocumentSymbol(symbol)) + } + return definitions, nil + } + + // Try SymbolInformation format (flat list) + var symbolInfos []protocol.SymbolInformation + if err := json.Unmarshal(resultBytes, &symbolInfos); err == nil && len(symbolInfos) > 0 { + for _, symbol := range symbolInfos { + definitions = append(definitions, t.convertSymbolInformation(symbol)) + } + return definitions, nil + } + + // If both fail, try to handle as a single symbol (some servers return single objects) + var singleDocSymbol protocol.DocumentSymbol + if err := json.Unmarshal(resultBytes, &singleDocSymbol); err == nil && singleDocSymbol.Name != "" { + definitions = append(definitions, t.convertDocumentSymbol(singleDocSymbol)) + return definitions, nil + } + + var singleSymbolInfo protocol.SymbolInformation + if err := json.Unmarshal(resultBytes, &singleSymbolInfo); err == nil && singleSymbolInfo.Name != "" { + definitions = append(definitions, t.convertSymbolInformation(singleSymbolInfo)) + return definitions, nil + } + } + + return definitions, nil +} + +// convertDocumentSymbol converts a DocumentSymbol to our format +func (t *DefinitionsTool) convertDocumentSymbol(symbol protocol.DocumentSymbol) DefinitionResult { + result := DefinitionResult{ + Name: symbol.Name, + Kind: t.symbolKindToString(symbol.Kind), + Detail: symbol.Detail, + Range: t.formatRange(symbol.Range), + } + + // Convert children recursively + for _, child := range symbol.Children { + result.Children = append(result.Children, t.convertDocumentSymbol(child)) + } + + return result +} + +// convertSymbolInformation converts a SymbolInformation to our format +func (t *DefinitionsTool) convertSymbolInformation(symbol protocol.SymbolInformation) DefinitionResult { + return DefinitionResult{ + Name: symbol.Name, + Kind: t.symbolKindToString(symbol.Kind), + Range: t.formatLocation(symbol.Location), + } +} + +// symbolKindToString converts SymbolKind to a readable string +func (t *DefinitionsTool) symbolKindToString(kind protocol.SymbolKind) string { + switch kind { + case protocol.File: + return "File" + case protocol.Module: + return "Module" + case protocol.Namespace: + return "Namespace" + case protocol.Package: + return "Package" + case protocol.Class: + return "Class" + case protocol.Method: + return "Method" + case protocol.Property: + return "Property" + case protocol.Field: + return "Field" + case protocol.Constructor: + return "Constructor" + case protocol.Enum: + return "Enum" + case protocol.Interface: + return "Interface" + case protocol.Function: + return "Function" + case protocol.Variable: + return "Variable" + case protocol.Constant: + return "Constant" + case protocol.String: + return "String" + case protocol.Number: + return "Number" + case protocol.Boolean: + return "Boolean" + case protocol.Array: + return "Array" + case protocol.Object: + return "Object" + case protocol.Key: + return "Key" + case protocol.Null: + return "Null" + case protocol.EnumMember: + return "EnumMember" + case protocol.Struct: + return "Struct" + case protocol.Event: + return "Event" + case protocol.Operator: + return "Operator" + case protocol.TypeParameter: + return "TypeParameter" + default: + return fmt.Sprintf("Unknown(%d)", kind) + } +} + +// formatRange formats a Range to a readable string +func (t *DefinitionsTool) formatRange(r protocol.Range) string { + startLine := r.Start.Line + 1 + endLine := r.End.Line + 1 + + if startLine == endLine { + return fmt.Sprintf("line %d", startLine) + } + return fmt.Sprintf("lines %d-%d", startLine, endLine) +} + +// formatLocation formats a Location to a readable string +func (t *DefinitionsTool) formatLocation(loc protocol.Location) string { + startLine := loc.Range.Start.Line + 1 + endLine := loc.Range.End.Line + 1 + + if startLine == endLine { + return fmt.Sprintf("line %d", startLine) + } + return fmt.Sprintf("lines %d-%d", startLine, endLine) +} + +// formatDefinitions formats the definitions into a readable output +func (t *DefinitionsTool) formatDefinitions(definitions []DefinitionResult, filePath, clientName string) string { + var output strings.Builder + + output.WriteString(fmt.Sprintf("# Definitions in %s\n", filePath)) + output.WriteString(fmt.Sprintf("*Using %s LSP server*\n\n", clientName)) + + if len(definitions) == 0 { + output.WriteString("No definitions found in this file.\n") + return output.String() + } + + // Group definitions by kind for better organization + kindGroups := make(map[string][]DefinitionResult) + for _, def := range definitions { + kindGroups[def.Kind] = append(kindGroups[def.Kind], def) + } + + // Define order of kinds for consistent output + kindOrder := []string{"Class", "Interface", "Struct", "Enum", "Function", "Method", "Variable", "Constant", "Property", "Field"} + + // Write definitions grouped by kind + for _, kind := range kindOrder { + if defs, exists := kindGroups[kind]; exists { + output.WriteString(fmt.Sprintf("## %ss\n", kind)) + for _, def := range defs { + t.writeDefinition(&output, def, 0) + } + output.WriteString("\n") + delete(kindGroups, kind) + } + } + + // Write any remaining kinds not in the predefined order + for kind, defs := range kindGroups { + output.WriteString(fmt.Sprintf("## %ss\n", kind)) + for _, def := range defs { + t.writeDefinition(&output, def, 0) + } + output.WriteString("\n") + } + + return output.String() +} + +// writeDefinition recursively writes a definition and its children +func (t *DefinitionsTool) writeDefinition(output *strings.Builder, def DefinitionResult, indent int) { + indentStr := strings.Repeat(" ", indent) + + // Format the main definition line + output.WriteString(fmt.Sprintf("%s- **%s** (%s)", indentStr, def.Name, def.Range)) + + if def.Detail != "" { + output.WriteString(fmt.Sprintf(" - %s", def.Detail)) + } + output.WriteString("\n") + + // Write children with increased indentation + if len(def.Children) > 0 { + for _, child := range def.Children { + t.writeDefinition(output, child, indent+1) + } + } +} + diff --git a/internal/tui/components/chat/messages/renderer.go b/internal/tui/components/chat/messages/renderer.go index 54bdd4c84ef4a7914e16d994e94ed84158d64f4e..d5ad07ac8170591de7d8d03b414f410e54e5e9d0 100644 --- a/internal/tui/components/chat/messages/renderer.go +++ b/internal/tui/components/chat/messages/renderer.go @@ -160,6 +160,7 @@ func init() { registry.register(tools.LSToolName, func() renderer { return lsRenderer{} }) registry.register(tools.SourcegraphToolName, func() renderer { return sourcegraphRenderer{} }) registry.register(tools.DiagnosticsToolName, func() renderer { return diagnosticsRenderer{} }) + registry.register(tools.DefinitionsToolName, func() renderer { return definitionsRenderer{} }) registry.register(agent.AgentToolName, func() renderer { return agentRenderer{} }) } @@ -488,6 +489,31 @@ func (dr diagnosticsRenderer) Render(v *toolCallCmp) string { }) } +// ----------------------------------------------------------------------------- +// Definitions renderer +// ----------------------------------------------------------------------------- + +// definitionsRenderer handles file symbol definitions display +type definitionsRenderer struct { + baseRenderer +} + +// Render displays file definitions with enhanced formatting for symbol structure +func (dr definitionsRenderer) Render(v *toolCallCmp) string { + var params tools.DefinitionsParams + var args []string + if err := dr.unmarshalParams(v.call.Input, ¶ms); err == nil { + file := fsext.PrettyPath(params.FilePath) + args = newParamBuilder().addMain(file).build() + } + + return dr.renderWithParams(v, "Definitions", args, func() string { + // The definitions tool outputs well-structured markdown content + // We'll render it as markdown to preserve the formatting and hierarchy + return renderCodeContent(v, "definitions.md", v.result.Content, 0) + }) +} + // ----------------------------------------------------------------------------- // Task renderer // ----------------------------------------------------------------------------- @@ -747,6 +773,10 @@ func prettifyToolName(name string) string { return "View" case tools.WriteToolName: return "Write" + case tools.DiagnosticsToolName: + return "Diagnostics" + case tools.DefinitionsToolName: + return "Definitions" default: return name }