feat: add definitions tool

Kujtim Hoxha created

Change summary

internal/config/config.go                         |   5 
internal/llm/agent/agent.go                       |   3 
internal/llm/tools/definitions.go                 | 433 +++++++++++++++++
internal/tui/components/chat/messages/renderer.go |  30 +
4 files changed, 468 insertions(+), 3 deletions(-)

Detailed changes

internal/config/config.go 🔗

@@ -324,6 +324,7 @@ func loadConfig(cwd string, debug bool) (*Config, error) {
 				"ls",
 				"sourcegraph",
 				"view",
+				"definitions",
 			},
 			// NO MCPs or LSPs by default
 			AllowedMCP: map[string][]string{},
@@ -1264,7 +1265,7 @@ func (c *Config) validateAgents(errors *ValidationErrors) {
 	}
 
 	validTools := []string{
-		"bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent",
+		"bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent", "definitions", "diagnostics",
 	}
 
 	for agentID, agent := range c.Agents {
@@ -1342,7 +1343,7 @@ func (c *Config) validateAgents(errors *ValidationErrors) {
 			if agent.Description != "An agent that helps with searching for context and finding implementation details." {
 				errors.Add(fieldPrefix+".description", "task agent description cannot be changed")
 			}
-			expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view"}
+			expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view", "definitions"}
 			if agent.AllowedTools != nil && !slices.Equal(agent.AllowedTools, expectedTools) {
 				errors.Add(fieldPrefix+".allowed_tools", "task agent allowed tools cannot be changed")
 			}

internal/llm/agent/agent.go 🔗

@@ -94,8 +94,9 @@ func NewAgent(
 	ctx := context.Background()
 	cfg := config.Get()
 	otherTools := GetMcpTools(ctx, permissions)
-	if len(lspClients) > 0 {
+	if len(cfg.LSP) > 0 {
 		otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
+		otherTools = append(otherTools, tools.NewDefinitionsTool(lspClients))
 	}
 
 	allTools := []tools.BaseTool{

internal/llm/tools/definitions.go 🔗

@@ -0,0 +1,433 @@
+package tools
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"os"
+	"path/filepath"
+	"strings"
+
+	"github.com/charmbracelet/crush/internal/lsp"
+	"github.com/charmbracelet/crush/internal/lsp/protocol"
+)
+
+type DefinitionsTool struct {
+	lspClients map[string]*lsp.Client
+}
+
+const (
+	DefinitionsToolName = "definitions"
+)
+
+func NewDefinitionsTool(lspClients map[string]*lsp.Client) BaseTool {
+	return &DefinitionsTool{
+		lspClients: lspClients,
+	}
+}
+
+func (t *DefinitionsTool) Name() string {
+	return DefinitionsToolName
+}
+
+func (t *DefinitionsTool) Info() ToolInfo {
+	return ToolInfo{
+		Name: DefinitionsToolName,
+		Description: `Gets all symbol definitions from a file using the appropriate LSP server.
+
+WHEN TO USE THIS TOOL:
+- Use when you need to understand the structure and symbols in a code file
+- Helpful for exploring unfamiliar codebases and understanding what's defined in a file
+- Good for finding functions, classes, variables, types, interfaces, and other symbols
+- Useful before making changes to understand existing code structure
+
+HOW TO USE:
+- Provide the path to a source code file
+- The tool will automatically select the appropriate LSP server based on file extension
+- Results show symbol names, types, locations, and hierarchical relationships
+
+FEATURES:
+- Supports multiple programming languages (Go, TypeScript, JavaScript, Rust, Python, etc.)
+- Shows hierarchical symbol relationships (classes with their methods, etc.)
+- Provides precise line number locations for each symbol
+- Includes symbol details and documentation when available
+
+LIMITATIONS:
+- Requires an active LSP server for the file's language
+- Only works with files that the LSP server can parse
+- Results depend on LSP server capabilities and file syntax correctness
+
+TIPS:
+- Use this tool to get an overview of a file's structure before editing
+- Combine with other tools like View to see the actual code implementation
+- Helpful for understanding large files with many symbols`,
+		Parameters: map[string]any{
+			"file_path": map[string]any{
+				"type":        "string",
+				"description": "The path to the file to get definitions from",
+			},
+		},
+		Required: []string{"file_path"},
+	}
+}
+
+type DefinitionsParams struct {
+	FilePath string `json:"file_path"`
+}
+
+type DefinitionResult struct {
+	Name     string             `json:"name"`
+	Kind     string             `json:"kind"`
+	Detail   string             `json:"detail,omitempty"`
+	Range    string             `json:"range"`
+	Children []DefinitionResult `json:"children,omitempty"`
+}
+
+func (t *DefinitionsTool) Run(ctx context.Context, params ToolCall) (ToolResponse, error) {
+	var definitionsParams DefinitionsParams
+	if err := json.Unmarshal([]byte(params.Input), &definitionsParams); err != nil {
+		return NewTextErrorResponse(fmt.Sprintf("Invalid parameters: %v", err)), nil
+	}
+
+	if definitionsParams.FilePath == "" {
+		return NewTextErrorResponse("file_path parameter is required"), nil
+	}
+
+	// Check if file exists
+	if _, err := os.Stat(definitionsParams.FilePath); err != nil {
+		return NewTextErrorResponse(fmt.Sprintf("File does not exist: %s", definitionsParams.FilePath)), nil
+	}
+
+	// Find the appropriate LSP client for this file
+	client, clientName, err := t.findLSPClient(definitionsParams.FilePath)
+	if err != nil {
+		return NewTextErrorResponse(fmt.Sprintf("No suitable LSP client found for file %s: %v\n\nAvailable LSP clients: %s",
+			definitionsParams.FilePath, err, t.getAvailableClients())), nil
+	}
+
+	// Check if the server is ready
+	if client.GetServerState() != lsp.StateReady {
+		return NewTextErrorResponse(fmt.Sprintf("LSP server %s is not ready (state: %v). Please wait for the server to initialize or check the server configuration.",
+			clientName, client.GetServerState())), nil
+	}
+
+	// Ensure the file is open in the LSP
+	if err := client.OpenFileOnDemand(ctx, definitionsParams.FilePath); err != nil {
+		return NewTextErrorResponse(fmt.Sprintf("Failed to open file in LSP: %v", err)), nil
+	}
+
+	// Get document symbols
+	documentURI := protocol.URIFromPath(definitionsParams.FilePath)
+	symbolParams := protocol.DocumentSymbolParams{
+		TextDocument: protocol.TextDocumentIdentifier{
+			URI: documentURI,
+		},
+	}
+
+	result, err := client.DocumentSymbol(ctx, symbolParams)
+	if err != nil {
+		return NewTextErrorResponse(fmt.Sprintf("Failed to get document symbols from %s LSP: %v\n\nThis might happen if:\n- The file has syntax errors\n- The LSP server doesn't support this file type\n- The file is not part of a recognized project structure",
+			clientName, err)), nil
+	}
+
+	// Parse the result and format it
+	definitions, err := t.parseSymbolResult(result)
+	if err != nil {
+		return NewTextErrorResponse(fmt.Sprintf("Failed to parse symbols: %v", err)), nil
+	}
+
+	if len(definitions) == 0 {
+		return NewTextResponse(fmt.Sprintf("No definitions found in file %s.\n\nThis might happen if:\n- The file is empty or contains only comments\n- The file has syntax errors preventing symbol extraction\n- The LSP server doesn't recognize symbols in this file type",
+			definitionsParams.FilePath)), nil
+	}
+
+	// Format the output
+	output := t.formatDefinitions(definitions, definitionsParams.FilePath, clientName)
+	return NewTextResponse(output), nil
+}
+
+// findLSPClient finds the most appropriate LSP client for the given file
+func (t *DefinitionsTool) findLSPClient(filePath string) (*lsp.Client, string, error) {
+	ext := strings.ToLower(filepath.Ext(filePath))
+
+	// Map file extensions to preferred LSP client types
+	preferredClients := map[string][]string{
+		".go":   {"gopls", "go"},
+		".ts":   {"typescript", "vtsls", "tsserver"},
+		".tsx":  {"typescript", "vtsls", "tsserver"},
+		".js":   {"typescript", "vtsls", "tsserver"},
+		".jsx":  {"typescript", "vtsls", "tsserver"},
+		".rs":   {"rust-analyzer", "rust"},
+		".py":   {"pyright", "pylsp", "python"},
+		".java": {"jdtls", "java"},
+		".c":    {"clangd", "ccls", "c"},
+		".cpp":  {"clangd", "ccls", "cpp", "c++"},
+		".cc":   {"clangd", "ccls", "cpp", "c++"},
+		".cxx":  {"clangd", "ccls", "cpp", "c++"},
+		".h":    {"clangd", "ccls", "c", "cpp"},
+		".hpp":  {"clangd", "ccls", "cpp", "c++"},
+		".cs":   {"omnisharp", "csharp"},
+		".php":  {"intelephense", "php"},
+		".rb":   {"solargraph", "ruby"},
+		".lua":  {"lua-language-server", "lua"},
+		".sh":   {"bash-language-server", "bash"},
+		".bash": {"bash-language-server", "bash"},
+		".zsh":  {"bash-language-server", "bash"},
+	}
+
+	// First, try to find a client that matches the preferred types for this file extension
+	if preferred, exists := preferredClients[ext]; exists {
+		for _, clientType := range preferred {
+			for name, client := range t.lspClients {
+				if strings.Contains(strings.ToLower(name), clientType) && client.GetServerState() == lsp.StateReady {
+					return client, name, nil
+				}
+			}
+		}
+	}
+
+	// If no preferred client found, try any available ready client
+	// This is a fallback for generic LSP servers that might support multiple languages
+	for name, client := range t.lspClients {
+		if client.GetServerState() == lsp.StateReady {
+			return client, name, nil
+		}
+	}
+
+	return nil, "", fmt.Errorf("no suitable LSP client found for file extension %s", ext)
+}
+
+// getAvailableClients returns a string listing all available LSP clients
+func (t *DefinitionsTool) getAvailableClients() string {
+	if len(t.lspClients) == 0 {
+		return "none"
+	}
+
+	var clients []string
+	for name := range t.lspClients {
+		clients = append(clients, name)
+	}
+	return strings.Join(clients, ", ")
+}
+
+// parseSymbolResult parses the LSP symbol result into our format
+func (t *DefinitionsTool) parseSymbolResult(result protocol.Or_Result_textDocument_documentSymbol) ([]DefinitionResult, error) {
+	var definitions []DefinitionResult
+
+	// The result can be either []DocumentSymbol or []SymbolInformation
+	// Try to unmarshal as DocumentSymbol first (newer format with hierarchy)
+	if result.Value != nil {
+		// Convert interface{} to []byte for unmarshaling
+		resultBytes, err := json.Marshal(result.Value)
+		if err != nil {
+			return definitions, fmt.Errorf("failed to marshal result: %v", err)
+		}
+
+		// Try DocumentSymbol format first (hierarchical)
+		var docSymbols []protocol.DocumentSymbol
+		if err := json.Unmarshal(resultBytes, &docSymbols); err == nil && len(docSymbols) > 0 {
+			for _, symbol := range docSymbols {
+				definitions = append(definitions, t.convertDocumentSymbol(symbol))
+			}
+			return definitions, nil
+		}
+
+		// Try SymbolInformation format (flat list)
+		var symbolInfos []protocol.SymbolInformation
+		if err := json.Unmarshal(resultBytes, &symbolInfos); err == nil && len(symbolInfos) > 0 {
+			for _, symbol := range symbolInfos {
+				definitions = append(definitions, t.convertSymbolInformation(symbol))
+			}
+			return definitions, nil
+		}
+
+		// If both fail, try to handle as a single symbol (some servers return single objects)
+		var singleDocSymbol protocol.DocumentSymbol
+		if err := json.Unmarshal(resultBytes, &singleDocSymbol); err == nil && singleDocSymbol.Name != "" {
+			definitions = append(definitions, t.convertDocumentSymbol(singleDocSymbol))
+			return definitions, nil
+		}
+
+		var singleSymbolInfo protocol.SymbolInformation
+		if err := json.Unmarshal(resultBytes, &singleSymbolInfo); err == nil && singleSymbolInfo.Name != "" {
+			definitions = append(definitions, t.convertSymbolInformation(singleSymbolInfo))
+			return definitions, nil
+		}
+	}
+
+	return definitions, nil
+}
+
+// convertDocumentSymbol converts a DocumentSymbol to our format
+func (t *DefinitionsTool) convertDocumentSymbol(symbol protocol.DocumentSymbol) DefinitionResult {
+	result := DefinitionResult{
+		Name:   symbol.Name,
+		Kind:   t.symbolKindToString(symbol.Kind),
+		Detail: symbol.Detail,
+		Range:  t.formatRange(symbol.Range),
+	}
+
+	// Convert children recursively
+	for _, child := range symbol.Children {
+		result.Children = append(result.Children, t.convertDocumentSymbol(child))
+	}
+
+	return result
+}
+
+// convertSymbolInformation converts a SymbolInformation to our format
+func (t *DefinitionsTool) convertSymbolInformation(symbol protocol.SymbolInformation) DefinitionResult {
+	return DefinitionResult{
+		Name:  symbol.Name,
+		Kind:  t.symbolKindToString(symbol.Kind),
+		Range: t.formatLocation(symbol.Location),
+	}
+}
+
+// symbolKindToString converts SymbolKind to a readable string
+func (t *DefinitionsTool) symbolKindToString(kind protocol.SymbolKind) string {
+	switch kind {
+	case protocol.File:
+		return "File"
+	case protocol.Module:
+		return "Module"
+	case protocol.Namespace:
+		return "Namespace"
+	case protocol.Package:
+		return "Package"
+	case protocol.Class:
+		return "Class"
+	case protocol.Method:
+		return "Method"
+	case protocol.Property:
+		return "Property"
+	case protocol.Field:
+		return "Field"
+	case protocol.Constructor:
+		return "Constructor"
+	case protocol.Enum:
+		return "Enum"
+	case protocol.Interface:
+		return "Interface"
+	case protocol.Function:
+		return "Function"
+	case protocol.Variable:
+		return "Variable"
+	case protocol.Constant:
+		return "Constant"
+	case protocol.String:
+		return "String"
+	case protocol.Number:
+		return "Number"
+	case protocol.Boolean:
+		return "Boolean"
+	case protocol.Array:
+		return "Array"
+	case protocol.Object:
+		return "Object"
+	case protocol.Key:
+		return "Key"
+	case protocol.Null:
+		return "Null"
+	case protocol.EnumMember:
+		return "EnumMember"
+	case protocol.Struct:
+		return "Struct"
+	case protocol.Event:
+		return "Event"
+	case protocol.Operator:
+		return "Operator"
+	case protocol.TypeParameter:
+		return "TypeParameter"
+	default:
+		return fmt.Sprintf("Unknown(%d)", kind)
+	}
+}
+
+// formatRange formats a Range to a readable string
+func (t *DefinitionsTool) formatRange(r protocol.Range) string {
+	startLine := r.Start.Line + 1
+	endLine := r.End.Line + 1
+	
+	if startLine == endLine {
+		return fmt.Sprintf("line %d", startLine)
+	}
+	return fmt.Sprintf("lines %d-%d", startLine, endLine)
+}
+
+// formatLocation formats a Location to a readable string
+func (t *DefinitionsTool) formatLocation(loc protocol.Location) string {
+	startLine := loc.Range.Start.Line + 1
+	endLine := loc.Range.End.Line + 1
+	
+	if startLine == endLine {
+		return fmt.Sprintf("line %d", startLine)
+	}
+	return fmt.Sprintf("lines %d-%d", startLine, endLine)
+}
+
+// formatDefinitions formats the definitions into a readable output
+func (t *DefinitionsTool) formatDefinitions(definitions []DefinitionResult, filePath, clientName string) string {
+	var output strings.Builder
+
+	output.WriteString(fmt.Sprintf("# Definitions in %s\n", filePath))
+	output.WriteString(fmt.Sprintf("*Using %s LSP server*\n\n", clientName))
+
+	if len(definitions) == 0 {
+		output.WriteString("No definitions found in this file.\n")
+		return output.String()
+	}
+
+	// Group definitions by kind for better organization
+	kindGroups := make(map[string][]DefinitionResult)
+	for _, def := range definitions {
+		kindGroups[def.Kind] = append(kindGroups[def.Kind], def)
+	}
+
+	// Define order of kinds for consistent output
+	kindOrder := []string{"Class", "Interface", "Struct", "Enum", "Function", "Method", "Variable", "Constant", "Property", "Field"}
+
+	// Write definitions grouped by kind
+	for _, kind := range kindOrder {
+		if defs, exists := kindGroups[kind]; exists {
+			output.WriteString(fmt.Sprintf("## %ss\n", kind))
+			for _, def := range defs {
+				t.writeDefinition(&output, def, 0)
+			}
+			output.WriteString("\n")
+			delete(kindGroups, kind)
+		}
+	}
+
+	// Write any remaining kinds not in the predefined order
+	for kind, defs := range kindGroups {
+		output.WriteString(fmt.Sprintf("## %ss\n", kind))
+		for _, def := range defs {
+			t.writeDefinition(&output, def, 0)
+		}
+		output.WriteString("\n")
+	}
+
+	return output.String()
+}
+
+// writeDefinition recursively writes a definition and its children
+func (t *DefinitionsTool) writeDefinition(output *strings.Builder, def DefinitionResult, indent int) {
+	indentStr := strings.Repeat("  ", indent)
+
+	// Format the main definition line
+	output.WriteString(fmt.Sprintf("%s- **%s** (%s)", indentStr, def.Name, def.Range))
+
+	if def.Detail != "" {
+		output.WriteString(fmt.Sprintf(" - %s", def.Detail))
+	}
+	output.WriteString("\n")
+
+	// Write children with increased indentation
+	if len(def.Children) > 0 {
+		for _, child := range def.Children {
+			t.writeDefinition(output, child, indent+1)
+		}
+	}
+}
+

internal/tui/components/chat/messages/renderer.go 🔗

@@ -160,6 +160,7 @@ func init() {
 	registry.register(tools.LSToolName, func() renderer { return lsRenderer{} })
 	registry.register(tools.SourcegraphToolName, func() renderer { return sourcegraphRenderer{} })
 	registry.register(tools.DiagnosticsToolName, func() renderer { return diagnosticsRenderer{} })
+	registry.register(tools.DefinitionsToolName, func() renderer { return definitionsRenderer{} })
 	registry.register(agent.AgentToolName, func() renderer { return agentRenderer{} })
 }
 
@@ -488,6 +489,31 @@ func (dr diagnosticsRenderer) Render(v *toolCallCmp) string {
 	})
 }
 
+// -----------------------------------------------------------------------------
+//  Definitions renderer
+// -----------------------------------------------------------------------------
+
+// definitionsRenderer handles file symbol definitions display
+type definitionsRenderer struct {
+	baseRenderer
+}
+
+// Render displays file definitions with enhanced formatting for symbol structure
+func (dr definitionsRenderer) Render(v *toolCallCmp) string {
+	var params tools.DefinitionsParams
+	var args []string
+	if err := dr.unmarshalParams(v.call.Input, &params); err == nil {
+		file := fsext.PrettyPath(params.FilePath)
+		args = newParamBuilder().addMain(file).build()
+	}
+
+	return dr.renderWithParams(v, "Definitions", args, func() string {
+		// The definitions tool outputs well-structured markdown content
+		// We'll render it as markdown to preserve the formatting and hierarchy
+		return renderCodeContent(v, "definitions.md", v.result.Content, 0)
+	})
+}
+
 // -----------------------------------------------------------------------------
 //  Task renderer
 // -----------------------------------------------------------------------------
@@ -747,6 +773,10 @@ func prettifyToolName(name string) string {
 		return "View"
 	case tools.WriteToolName:
 		return "Write"
+	case tools.DiagnosticsToolName:
+		return "Diagnostics"
+	case tools.DefinitionsToolName:
+		return "Definitions"
 	default:
 		return name
 	}