refactor: mcp tool item (#1923)

Kujtim Hoxha created

Change summary

internal/ui/chat/mcp.go           | 121 +++++++++++++++++++++++++++++
internal/ui/chat/tools.go         |  20 ++-
internal/ui/dialog/permissions.go | 135 ++++++++++++++++++++++++--------
internal/ui/styles/styles.go      |  10 ++
4 files changed, 243 insertions(+), 43 deletions(-)

Detailed changes

internal/ui/chat/mcp.go 🔗

@@ -0,0 +1,121 @@
+package chat
+
+import (
+	"encoding/json"
+	"fmt"
+	"strings"
+
+	"github.com/charmbracelet/crush/internal/message"
+	"github.com/charmbracelet/crush/internal/stringext"
+	"github.com/charmbracelet/crush/internal/ui/styles"
+)
+
+// MCPToolMessageItem is a message item that represents a bash tool call.
+type MCPToolMessageItem struct {
+	*baseToolMessageItem
+}
+
+var _ ToolMessageItem = (*MCPToolMessageItem)(nil)
+
+// NewMCPToolMessageItem creates a new [MCPToolMessageItem].
+func NewMCPToolMessageItem(
+	sty *styles.Styles,
+	toolCall message.ToolCall,
+	result *message.ToolResult,
+	canceled bool,
+) ToolMessageItem {
+	return newBaseToolMessageItem(sty, toolCall, result, &MCPToolRenderContext{}, canceled)
+}
+
+// MCPToolRenderContext renders bash tool messages.
+type MCPToolRenderContext struct{}
+
+// RenderTool implements the [ToolRenderer] interface.
+func (b *MCPToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string {
+	cappedWidth := cappedMessageWidth(width)
+	toolNameParts := strings.SplitN(opts.ToolCall.Name, "_", 3)
+	if len(toolNameParts) != 3 {
+		return toolErrorContent(sty, &message.ToolResult{Content: "Invalid tool name"}, cappedWidth)
+	}
+	mcpName := prettyName(toolNameParts[1])
+	toolName := prettyName(toolNameParts[2])
+
+	mcpName = sty.Tool.MCPName.Render(mcpName)
+	toolName = sty.Tool.MCPToolName.Render(toolName)
+
+	name := fmt.Sprintf("%s %s %s", mcpName, sty.Tool.MCPArrow.String(), toolName)
+
+	if opts.IsPending() {
+		return pendingTool(sty, name, opts.Anim)
+	}
+
+	var params map[string]any
+	if err := json.Unmarshal([]byte(opts.ToolCall.Input), &params); err != nil {
+		return toolErrorContent(sty, &message.ToolResult{Content: "Invalid parameters"}, cappedWidth)
+	}
+
+	var toolParams []string
+	if len(params) > 0 {
+		parsed, _ := json.Marshal(params)
+		toolParams = append(toolParams, string(parsed))
+	}
+
+	header := toolHeader(sty, opts.Status, name, cappedWidth, opts.Compact, toolParams...)
+	if opts.Compact {
+		return header
+	}
+
+	if earlyState, ok := toolEarlyStateContent(sty, opts, cappedWidth); ok {
+		return joinToolParts(header, earlyState)
+	}
+
+	if !opts.HasResult() || opts.Result.Content == "" {
+		return header
+	}
+
+	bodyWidth := cappedWidth - toolBodyLeftPaddingTotal
+	// see if the result is json
+	var result json.RawMessage
+	var body string
+	if err := json.Unmarshal([]byte(opts.Result.Content), &result); err == nil {
+		prettyResult, err := json.MarshalIndent(result, "", "  ")
+		if err == nil {
+			body = sty.Tool.Body.Render(toolOutputCodeContent(sty, "result.json", string(prettyResult), 0, bodyWidth, opts.ExpandedContent))
+		} else {
+			body = sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.ExpandedContent))
+		}
+	} else if looksLikeMarkdown(opts.Result.Content) {
+		body = sty.Tool.Body.Render(toolOutputCodeContent(sty, "result.md", opts.Result.Content, 0, bodyWidth, opts.ExpandedContent))
+	} else {
+		body = sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.ExpandedContent))
+	}
+	return joinToolParts(header, body)
+}
+
+func prettyName(name string) string {
+	name = strings.ReplaceAll(name, "_", " ")
+	name = strings.ReplaceAll(name, "-", " ")
+	return stringext.Capitalize(name)
+}
+
+// looksLikeMarkdown checks if content appears to be markdown by looking for
+// common markdown patterns.
+func looksLikeMarkdown(content string) bool {
+	patterns := []string{
+		"# ",  // headers
+		"## ", // headers
+		"**",  // bold
+		"```", // code fence
+		"- ",  // unordered list
+		"1. ", // ordered list
+		"> ",  // blockquote
+		"---", // horizontal rule
+		"***", // horizontal rule
+	}
+	for _, p := range patterns {
+		if strings.Contains(content, p) {
+			return true
+		}
+	}
+	return false
+}

internal/ui/chat/tools.go 🔗

@@ -243,14 +243,18 @@ func NewToolMessageItem(
 	case tools.TodosToolName:
 		item = NewTodosToolMessageItem(sty, toolCall, result, canceled)
 	default:
-		// TODO: Implement other tool items
-		item = newBaseToolMessageItem(
-			sty,
-			toolCall,
-			result,
-			&DefaultToolRenderContext{},
-			canceled,
-		)
+		if strings.HasPrefix(toolCall.Name, "mcp_") {
+			item = NewMCPToolMessageItem(sty, toolCall, result, canceled)
+		} else {
+			// TODO: Implement other tool items
+			item = newBaseToolMessageItem(
+				sty,
+				toolCall,
+				result,
+				&DefaultToolRenderContext{},
+				canceled,
+			)
+		}
 	}
 	item.SetMessageID(messageID)
 	return item

internal/ui/dialog/permissions.go 🔗

@@ -13,7 +13,9 @@ import (
 	"github.com/charmbracelet/crush/internal/agent/tools"
 	"github.com/charmbracelet/crush/internal/fsext"
 	"github.com/charmbracelet/crush/internal/permission"
+	"github.com/charmbracelet/crush/internal/stringext"
 	"github.com/charmbracelet/crush/internal/ui/common"
+	"github.com/charmbracelet/crush/internal/ui/styles"
 	uv "github.com/charmbracelet/ultraviolet"
 )
 
@@ -314,19 +316,19 @@ func (p *Permissions) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
 	forceFullscreen := area.Dx() <= minWindowWidth || area.Dy() <= minWindowHeight
 
 	// Calculate dialog dimensions based on fullscreen state and content type.
-	var width, height int
+	var width, maxHeight int
 	if forceFullscreen || (p.fullscreen && p.hasDiffView()) {
 		// Use nearly full window for fullscreen.
 		width = area.Dx()
-		height = area.Dy()
+		maxHeight = area.Dy()
 	} else if p.hasDiffView() {
 		// Wide for side-by-side diffs, capped for readability.
 		width = min(int(float64(area.Dx())*diffSizeRatio), diffMaxWidth)
-		height = int(float64(area.Dy()) * diffSizeRatio)
+		maxHeight = int(float64(area.Dy()) * diffSizeRatio)
 	} else {
 		// Narrower for simple content like commands/URLs.
 		width = min(int(float64(area.Dx())*simpleSizeRatio), simpleMaxWidth)
-		height = int(float64(area.Dy()) * simpleHeightRatio)
+		maxHeight = int(float64(area.Dy()) * simpleHeightRatio)
 	}
 
 	dialogStyle := t.Dialog.View.Width(width).Padding(0, 1)
@@ -341,27 +343,51 @@ func (p *Permissions) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
 	buttonsHeight := lipgloss.Height(buttons)
 	helpHeight := lipgloss.Height(helpView)
 	frameHeight := dialogStyle.GetVerticalFrameSize() + layoutSpacingLines
-	availableHeight := height - headerHeight - buttonsHeight - helpHeight - frameHeight
 
 	p.defaultDiffSplitMode = width >= splitModeMinWidth
 
-	if p.viewport.Width() != contentWidth-1 {
-		// Mark diff content as dirty if width has changed
+	// Pre-render content to measure its actual height.
+	renderedContent := p.renderContent(contentWidth)
+	contentHeight := lipgloss.Height(renderedContent)
+
+	// For non-diff views, shrink dialog to fit content if it's smaller than max.
+	var availableHeight int
+	if !p.hasDiffView() && !forceFullscreen {
+		fixedHeight := headerHeight + buttonsHeight + helpHeight + frameHeight
+		neededHeight := fixedHeight + contentHeight
+		if neededHeight < maxHeight {
+			availableHeight = contentHeight
+		} else {
+			availableHeight = maxHeight - fixedHeight
+		}
+	} else {
+		availableHeight = maxHeight - headerHeight - buttonsHeight - helpHeight - frameHeight
+	}
+
+	// Determine if scrollbar is needed.
+	needsScrollbar := p.hasDiffView() || contentHeight > availableHeight
+	viewportWidth := contentWidth
+	if needsScrollbar {
+		viewportWidth = contentWidth - 1 // Reserve space for scrollbar.
+	}
+
+	if p.viewport.Width() != viewportWidth {
+		// Mark content as dirty if width has changed.
 		p.viewportDirty = true
+		renderedContent = p.renderContent(viewportWidth)
 	}
 
 	var content string
 	var scrollbar string
-	// Non-diff content uses the viewport for scrolling.
-	p.viewport.SetWidth(contentWidth - 1) // -1 for scrollbar
+	p.viewport.SetWidth(viewportWidth)
 	p.viewport.SetHeight(availableHeight)
 	if p.viewportDirty {
-		p.viewport.SetContent(p.renderContent(contentWidth - 1))
+		p.viewport.SetContent(renderedContent)
 		p.viewportWidth = p.viewport.Width()
 		p.viewportDirty = false
 	}
 	content = p.viewport.View()
-	if p.canScroll() {
+	if needsScrollbar {
 		scrollbar = common.Scrollbar(t, availableHeight, p.viewport.TotalLineCount(), availableHeight, p.viewport.YOffset())
 	}
 
@@ -388,7 +414,7 @@ func (p *Permissions) renderHeader(contentWidth int) string {
 	title = t.Dialog.Title.Render(title)
 
 	// Tool info.
-	toolLine := p.renderKeyValue("Tool", p.permission.ToolName, contentWidth)
+	toolLine := p.renderToolName(contentWidth)
 	pathLine := p.renderKeyValue("Path", fsext.PrettyPath(p.permission.Path), contentWidth)
 
 	lines := []string{title, "", toolLine, pathLine}
@@ -439,10 +465,33 @@ func (p *Permissions) renderKeyValue(key, value string, width int) string {
 	return lipgloss.JoinHorizontal(lipgloss.Left, keyStr, valueStr)
 }
 
+func (p *Permissions) renderToolName(width int) string {
+	toolName := p.permission.ToolName
+
+	// Check if this is an MCP tool (format: mcp_<mcpname>_<toolname>).
+	if strings.HasPrefix(toolName, "mcp_") {
+		parts := strings.SplitN(toolName, "_", 3)
+		if len(parts) == 3 {
+			mcpName := prettyName(parts[1])
+			toolPart := prettyName(parts[2])
+			toolName = fmt.Sprintf("%s %s %s", mcpName, styles.ArrowRightIcon, toolPart)
+		}
+	}
+
+	return p.renderKeyValue("Tool", toolName, width)
+}
+
+// prettyName converts snake_case or kebab-case to Title Case.
+func prettyName(name string) string {
+	name = strings.ReplaceAll(name, "_", " ")
+	name = strings.ReplaceAll(name, "-", " ")
+	return stringext.Capitalize(name)
+}
+
 func (p *Permissions) renderContent(width int) string {
 	switch p.permission.ToolName {
 	case tools.BashToolName:
-		return p.renderBashContent()
+		return p.renderBashContent(width)
 	case tools.EditToolName:
 		return p.renderEditContent(width)
 	case tools.WriteToolName:
@@ -450,27 +499,27 @@ func (p *Permissions) renderContent(width int) string {
 	case tools.MultiEditToolName:
 		return p.renderMultiEditContent(width)
 	case tools.DownloadToolName:
-		return p.renderDownloadContent()
+		return p.renderDownloadContent(width)
 	case tools.FetchToolName:
-		return p.renderFetchContent()
+		return p.renderFetchContent(width)
 	case tools.AgenticFetchToolName:
-		return p.renderAgenticFetchContent()
+		return p.renderAgenticFetchContent(width)
 	case tools.ViewToolName:
-		return p.renderViewContent()
+		return p.renderViewContent(width)
 	case tools.LSToolName:
-		return p.renderLSContent()
+		return p.renderLSContent(width)
 	default:
-		return p.renderDefaultContent()
+		return p.renderDefaultContent(width)
 	}
 }
 
-func (p *Permissions) renderBashContent() string {
+func (p *Permissions) renderBashContent(width int) string {
 	params, ok := p.permission.Params.(tools.BashPermissionsParams)
 	if !ok {
 		return ""
 	}
 
-	return p.com.Styles.Dialog.ContentPanel.Render(params.Command)
+	return p.renderContentPanel(params.Command, width)
 }
 
 func (p *Permissions) renderEditContent(contentWidth int) string {
@@ -529,7 +578,7 @@ func (p *Permissions) renderDiff(filePath, oldContent, newContent string, conten
 	return result
 }
 
-func (p *Permissions) renderDownloadContent() string {
+func (p *Permissions) renderDownloadContent(width int) string {
 	params, ok := p.permission.Params.(tools.DownloadPermissionsParams)
 	if !ok {
 		return ""
@@ -540,19 +589,19 @@ func (p *Permissions) renderDownloadContent() string {
 		content += fmt.Sprintf("\nTimeout: %ds", params.Timeout)
 	}
 
-	return p.com.Styles.Dialog.ContentPanel.Render(content)
+	return p.renderContentPanel(content, width)
 }
 
-func (p *Permissions) renderFetchContent() string {
+func (p *Permissions) renderFetchContent(width int) string {
 	params, ok := p.permission.Params.(tools.FetchPermissionsParams)
 	if !ok {
 		return ""
 	}
 
-	return p.com.Styles.Dialog.ContentPanel.Render(params.URL)
+	return p.renderContentPanel(params.URL, width)
 }
 
-func (p *Permissions) renderAgenticFetchContent() string {
+func (p *Permissions) renderAgenticFetchContent(width int) string {
 	params, ok := p.permission.Params.(tools.AgenticFetchPermissionsParams)
 	if !ok {
 		return ""
@@ -565,10 +614,10 @@ func (p *Permissions) renderAgenticFetchContent() string {
 		content = fmt.Sprintf("Prompt: %s", params.Prompt)
 	}
 
-	return p.com.Styles.Dialog.ContentPanel.Render(content)
+	return p.renderContentPanel(content, width)
 }
 
-func (p *Permissions) renderViewContent() string {
+func (p *Permissions) renderViewContent(width int) string {
 	params, ok := p.permission.Params.(tools.ViewPermissionsParams)
 	if !ok {
 		return ""
@@ -582,10 +631,10 @@ func (p *Permissions) renderViewContent() string {
 		content += fmt.Sprintf("\nLines to read: %d", params.Limit)
 	}
 
-	return p.com.Styles.Dialog.ContentPanel.Render(content)
+	return p.renderContentPanel(content, width)
 }
 
-func (p *Permissions) renderLSContent() string {
+func (p *Permissions) renderLSContent(width int) string {
 	params, ok := p.permission.Params.(tools.LSPermissionsParams)
 	if !ok {
 		return ""
@@ -596,11 +645,16 @@ func (p *Permissions) renderLSContent() string {
 		content += fmt.Sprintf("\nIgnore patterns: %s", strings.Join(params.Ignore, ", "))
 	}
 
-	return p.com.Styles.Dialog.ContentPanel.Render(content)
+	return p.renderContentPanel(content, width)
 }
 
-func (p *Permissions) renderDefaultContent() string {
-	content := p.permission.Description
+func (p *Permissions) renderDefaultContent(width int) string {
+	t := p.com.Styles
+	var content string
+	// do not add the description for mcp tools
+	if !strings.HasPrefix(p.permission.ToolName, "mcp_") {
+		content = p.permission.Description
+	}
 
 	// Pretty-print JSON params if available.
 	if p.permission.Params != nil {
@@ -614,10 +668,15 @@ func (p *Permissions) renderDefaultContent() string {
 		var parsed any
 		if err := json.Unmarshal([]byte(paramStr), &parsed); err == nil {
 			if b, err := json.MarshalIndent(parsed, "", "  "); err == nil {
+				jsonContent := string(b)
+				highlighted, err := common.SyntaxHighlight(t, jsonContent, "params.json", t.BgSubtle)
+				if err == nil {
+					jsonContent = highlighted
+				}
 				if content != "" {
 					content += "\n\n"
 				}
-				content += string(b)
+				content += jsonContent
 			}
 		} else if paramStr != "" {
 			if content != "" {
@@ -631,7 +690,13 @@ func (p *Permissions) renderDefaultContent() string {
 		return ""
 	}
 
-	return p.com.Styles.Dialog.ContentPanel.Render(strings.TrimSpace(content))
+	return p.renderContentPanel(strings.TrimSpace(content), width)
+}
+
+// renderContentPanel renders content in a panel with the full width.
+func (p *Permissions) renderContentPanel(content string, width int) string {
+	panelStyle := p.com.Styles.Dialog.ContentPanel
+	return panelStyle.Width(width).Render(content)
 }
 
 func (p *Permissions) renderButtons(contentWidth int) string {

internal/ui/styles/styles.go 🔗

@@ -310,6 +310,11 @@ type Styles struct {
 		TodoCompletedIcon  lipgloss.Style // Completed todo icon
 		TodoInProgressIcon lipgloss.Style // In-progress todo icon
 		TodoPendingIcon    lipgloss.Style // Pending todo icon
+
+		// MCP tools
+		MCPName     lipgloss.Style // The mcp name
+		MCPToolName lipgloss.Style // The mcp tool name
+		MCPArrow    lipgloss.Style // The mcp arrow icon
 	}
 
 	// Dialog styles
@@ -1130,6 +1135,11 @@ func DefaultStyles() Styles {
 	s.Tool.TodoInProgressIcon = base.Foreground(greenDark)
 	s.Tool.TodoPendingIcon = base.Foreground(fgMuted)
 
+	// MCP styles
+	s.Tool.MCPName = base.Foreground(blue)
+	s.Tool.MCPToolName = base.Foreground(blueDark)
+	s.Tool.MCPArrow = base.Foreground(blue).SetString(ArrowRightIcon)
+
 	// Buttons
 	s.ButtonFocus = lipgloss.NewStyle().Foreground(white).Background(secondary)
 	s.ButtonBlur = s.Base.Background(bgSubtle)