feat(mcp): refactor, support prompts

Carlos Alexandro Becker created

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

Change summary

internal/agent/coordinator.go                         |  15 
internal/agent/tools/mcp-tools.go                     | 525 -------------
internal/agent/tools/mcp/init.go                      | 400 +++++++++
internal/agent/tools/mcp/prompts.go                   |  83 ++
internal/agent/tools/mcp/tools.go                     | 169 ++++
internal/app/app.go                                   |  11 
internal/tui/components/dialogs/commands/arguments.go | 162 ++-
internal/tui/components/dialogs/commands/commands.go  | 104 +
internal/tui/components/dialogs/commands/loader.go    | 101 ++
internal/tui/components/mcp/mcp.go                    |  24 
internal/tui/tui.go                                   |  35 
11 files changed, 972 insertions(+), 657 deletions(-)

Detailed changes

internal/agent/coordinator.go 🔗

@@ -18,6 +18,7 @@ import (
 	"github.com/charmbracelet/catwalk/pkg/catwalk"
 	"github.com/charmbracelet/crush/internal/agent/prompt"
 	"github.com/charmbracelet/crush/internal/agent/tools"
+	"github.com/charmbracelet/crush/internal/agent/tools/mcp"
 	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/csync"
 	"github.com/charmbracelet/crush/internal/history"
@@ -344,25 +345,23 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan
 		}
 	}
 
-	mcpTools := tools.GetMCPTools(context.Background(), c.permissions, c.cfg)
-
-	for _, mcpTool := range mcpTools {
+	for tool := range mcp.GetMCPTools() {
 		if agent.AllowedMCP == nil {
 			// No MCP restrictions
-			filteredTools = append(filteredTools, mcpTool)
+			filteredTools = append(filteredTools, tool)
 		} else if len(agent.AllowedMCP) == 0 {
 			// no mcps allowed
 			break
 		}
 
 		for mcp, tools := range agent.AllowedMCP {
-			if mcp == mcpTool.MCP() {
+			if mcp == tool.MCP() {
 				if len(tools) == 0 {
-					filteredTools = append(filteredTools, mcpTool)
+					filteredTools = append(filteredTools, tool)
 				}
 				for _, t := range tools {
-					if t == mcpTool.MCPToolName() {
-						filteredTools = append(filteredTools, mcpTool)
+					if t == tool.MCPToolName() {
+						filteredTools = append(filteredTools, tool)
 					}
 				}
 				break

internal/agent/tools/mcp-tools.go 🔗

@@ -1,525 +0,0 @@
-package tools
-
-import (
-	"cmp"
-	"context"
-	"encoding/json"
-	"errors"
-	"fmt"
-	"io"
-	"log/slog"
-	"maps"
-	"net/http"
-	"os"
-	"os/exec"
-	"slices"
-	"strings"
-	"sync"
-	"time"
-
-	"charm.land/fantasy"
-	"github.com/charmbracelet/crush/internal/config"
-	"github.com/charmbracelet/crush/internal/csync"
-	"github.com/charmbracelet/crush/internal/home"
-	"github.com/charmbracelet/crush/internal/permission"
-	"github.com/charmbracelet/crush/internal/pubsub"
-	"github.com/charmbracelet/crush/internal/version"
-	"github.com/modelcontextprotocol/go-sdk/mcp"
-)
-
-// MCPState represents the current state of an MCP client
-type MCPState int
-
-const (
-	MCPStateDisabled MCPState = iota
-	MCPStateStarting
-	MCPStateConnected
-	MCPStateError
-)
-
-func (s MCPState) String() string {
-	switch s {
-	case MCPStateDisabled:
-		return "disabled"
-	case MCPStateStarting:
-		return "starting"
-	case MCPStateConnected:
-		return "connected"
-	case MCPStateError:
-		return "error"
-	default:
-		return "unknown"
-	}
-}
-
-// MCPEventType represents the type of MCP event
-type MCPEventType string
-
-const (
-	MCPEventStateChanged     MCPEventType = "state_changed"
-	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
-)
-
-// MCPEvent represents an event in the MCP system
-type MCPEvent struct {
-	Type      MCPEventType
-	Name      string
-	State     MCPState
-	Error     error
-	ToolCount int
-}
-
-// MCPClientInfo holds information about an MCP client's state
-type MCPClientInfo struct {
-	Name        string
-	State       MCPState
-	Error       error
-	Client      *mcp.ClientSession
-	ToolCount   int
-	ConnectedAt time.Time
-}
-
-var (
-	mcpToolsOnce    sync.Once
-	mcpTools        = csync.NewMap[string, *McpTool]()
-	mcpClient2Tools = csync.NewMap[string, []*McpTool]()
-	mcpClients      = csync.NewMap[string, *mcp.ClientSession]()
-	mcpStates       = csync.NewMap[string, MCPClientInfo]()
-	mcpBroker       = pubsub.NewBroker[MCPEvent]()
-)
-
-type McpTool struct {
-	mcpName         string
-	tool            *mcp.Tool
-	permissions     permission.Service
-	workingDir      string
-	providerOptions fantasy.ProviderOptions
-}
-
-func (m *McpTool) SetProviderOptions(opts fantasy.ProviderOptions) {
-	m.providerOptions = opts
-}
-
-func (m *McpTool) ProviderOptions() fantasy.ProviderOptions {
-	return m.providerOptions
-}
-
-func (m *McpTool) Name() string {
-	return fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name)
-}
-
-func (m *McpTool) MCP() string {
-	return m.mcpName
-}
-
-func (m *McpTool) MCPToolName() string {
-	return m.tool.Name
-}
-
-func (b *McpTool) Info() fantasy.ToolInfo {
-	parameters := make(map[string]any)
-	required := make([]string, 0)
-
-	if input, ok := b.tool.InputSchema.(map[string]any); ok {
-		if props, ok := input["properties"].(map[string]any); ok {
-			parameters = props
-		}
-		if req, ok := input["required"].([]any); ok {
-			// Convert []any -> []string when elements are strings
-			for _, v := range req {
-				if s, ok := v.(string); ok {
-					required = append(required, s)
-				}
-			}
-		} else if reqStr, ok := input["required"].([]string); ok {
-			// Handle case where it's already []string
-			required = reqStr
-		}
-	}
-
-	return fantasy.ToolInfo{
-		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
-		Description: b.tool.Description,
-		Parameters:  parameters,
-		Required:    required,
-	}
-}
-
-func runTool(ctx context.Context, name, toolName string, input string) (fantasy.ToolResponse, error) {
-	var args map[string]any
-	if err := json.Unmarshal([]byte(input), &args); err != nil {
-		return fantasy.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
-	}
-
-	c, err := getOrRenewClient(ctx, name)
-	if err != nil {
-		return fantasy.NewTextErrorResponse(err.Error()), nil
-	}
-	result, err := c.CallTool(ctx, &mcp.CallToolParams{
-		Name:      toolName,
-		Arguments: args,
-	})
-	if err != nil {
-		return fantasy.NewTextErrorResponse(err.Error()), nil
-	}
-
-	output := make([]string, 0, len(result.Content))
-	for _, v := range result.Content {
-		if vv, ok := v.(*mcp.TextContent); ok {
-			output = append(output, vv.Text)
-		} else {
-			output = append(output, fmt.Sprintf("%v", v))
-		}
-	}
-	return fantasy.NewTextResponse(strings.Join(output, "\n")), nil
-}
-
-func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
-	sess, ok := mcpClients.Get(name)
-	if !ok {
-		return nil, fmt.Errorf("mcp '%s' not available", name)
-	}
-
-	cfg := config.Get()
-	m := cfg.MCP[name]
-	state, _ := mcpStates.Get(name)
-
-	timeout := mcpTimeout(m)
-	pingCtx, cancel := context.WithTimeout(ctx, timeout)
-	defer cancel()
-	err := sess.Ping(pingCtx, nil)
-	if err == nil {
-		return sess, nil
-	}
-	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
-
-	sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
-	if err != nil {
-		return nil, err
-	}
-
-	updateMCPState(name, MCPStateConnected, nil, sess, state.ToolCount)
-	mcpClients.Set(name, sess)
-	return sess, nil
-}
-
-func (m *McpTool) Run(ctx context.Context, params fantasy.ToolCall) (fantasy.ToolResponse, error) {
-	sessionID := GetSessionFromContext(ctx)
-	if sessionID == "" {
-		return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
-	}
-	permissionDescription := fmt.Sprintf("execute %s with the following parameters:", m.Info().Name)
-	p := m.permissions.Request(
-		permission.CreatePermissionRequest{
-			SessionID:   sessionID,
-			ToolCallID:  params.ID,
-			Path:        m.workingDir,
-			ToolName:    m.Info().Name,
-			Action:      "execute",
-			Description: permissionDescription,
-			Params:      params.Input,
-		},
-	)
-	if !p {
-		return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
-	}
-
-	return runTool(ctx, m.mcpName, m.tool.Name, params.Input)
-}
-
-func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]*McpTool, error) {
-	result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
-	if err != nil {
-		return nil, err
-	}
-	mcpTools := make([]*McpTool, 0, len(result.Tools))
-	for _, tool := range result.Tools {
-		mcpTools = append(mcpTools, &McpTool{
-			mcpName:     name,
-			tool:        tool,
-			permissions: permissions,
-			workingDir:  workingDir,
-		})
-	}
-	return mcpTools, nil
-}
-
-// SubscribeMCPEvents returns a channel for MCP events
-func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
-	return mcpBroker.Subscribe(ctx)
-}
-
-// GetMCPStates returns the current state of all MCP clients
-func GetMCPStates() map[string]MCPClientInfo {
-	return maps.Collect(mcpStates.Seq2())
-}
-
-// GetMCPState returns the state of a specific MCP client
-func GetMCPState(name string) (MCPClientInfo, bool) {
-	return mcpStates.Get(name)
-}
-
-// updateMCPState updates the state of an MCP client and publishes an event
-func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, toolCount int) {
-	info := MCPClientInfo{
-		Name:      name,
-		State:     state,
-		Error:     err,
-		Client:    client,
-		ToolCount: toolCount,
-	}
-	switch state {
-	case MCPStateConnected:
-		info.ConnectedAt = time.Now()
-	case MCPStateError:
-		updateMcpTools(name, nil)
-		mcpClients.Del(name)
-	}
-	mcpStates.Set(name, info)
-
-	// Publish state change event
-	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
-		Type:      MCPEventStateChanged,
-		Name:      name,
-		State:     state,
-		Error:     err,
-		ToolCount: toolCount,
-	})
-}
-
-// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
-func CloseMCPClients() error {
-	var errs []error
-	for name, c := range mcpClients.Seq2() {
-		if err := c.Close(); err != nil &&
-			!errors.Is(err, io.EOF) &&
-			!errors.Is(err, context.Canceled) &&
-			err.Error() != "signal: killed" {
-			errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
-		}
-	}
-	mcpBroker.Shutdown()
-	return errors.Join(errs...)
-}
-
-func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []*McpTool {
-	mcpToolsOnce.Do(func() {
-		var wg sync.WaitGroup
-		// Initialize states for all configured MCPs
-		for name, m := range cfg.MCP {
-			if m.Disabled {
-				updateMCPState(name, MCPStateDisabled, nil, nil, 0)
-				slog.Debug("skipping disabled mcp", "name", name)
-				continue
-			}
-
-			// Set initial starting state
-			updateMCPState(name, MCPStateStarting, nil, nil, 0)
-
-			wg.Add(1)
-			go func(name string, m config.MCPConfig) {
-				defer func() {
-					wg.Done()
-					if r := recover(); r != nil {
-						var err error
-						switch v := r.(type) {
-						case error:
-							err = v
-						case string:
-							err = fmt.Errorf("panic: %s", v)
-						default:
-							err = fmt.Errorf("panic: %v", v)
-						}
-						updateMCPState(name, MCPStateError, err, nil, 0)
-						slog.Error("panic in mcp client initialization", "error", err, "name", name)
-					}
-				}()
-
-				ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
-				defer cancel()
-
-				c, err := createMCPSession(ctx, name, m, cfg.Resolver())
-				if err != nil {
-					return
-				}
-
-				mcpClients.Set(name, c)
-
-				tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
-				if err != nil {
-					slog.Error("error listing tools", "error", err)
-					updateMCPState(name, MCPStateError, err, nil, 0)
-					c.Close()
-					return
-				}
-
-				updateMcpTools(name, tools)
-				mcpClients.Set(name, c)
-				updateMCPState(name, MCPStateConnected, nil, c, len(tools))
-			}(name, m)
-		}
-		wg.Wait()
-	})
-	return slices.Collect(mcpTools.Seq())
-}
-
-// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
-func updateMcpTools(mcpName string, tools []*McpTool) {
-	if len(tools) == 0 {
-		mcpClient2Tools.Del(mcpName)
-	} else {
-		mcpClient2Tools.Set(mcpName, tools)
-	}
-	for _, tools := range mcpClient2Tools.Seq2() {
-		for _, t := range tools {
-			mcpTools.Set(t.Name(), t)
-		}
-	}
-}
-
-func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
-	timeout := mcpTimeout(m)
-	mcpCtx, cancel := context.WithCancel(ctx)
-	cancelTimer := time.AfterFunc(timeout, cancel)
-
-	transport, err := createMCPTransport(mcpCtx, m, resolver)
-	if err != nil {
-		updateMCPState(name, MCPStateError, err, nil, 0)
-		slog.Error("error creating mcp client", "error", err, "name", name)
-		cancel()
-		cancelTimer.Stop()
-		return nil, err
-	}
-
-	client := mcp.NewClient(
-		&mcp.Implementation{
-			Name:    "crush",
-			Version: version.Version,
-			Title:   "Crush",
-		},
-		&mcp.ClientOptions{
-			ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
-				mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
-					Type: MCPEventToolsListChanged,
-					Name: name,
-				})
-			},
-			KeepAlive: time.Minute * 10,
-		},
-	)
-
-	session, err := client.Connect(mcpCtx, transport, nil)
-	if err != nil {
-		err = maybeStdioErr(err, transport)
-		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
-		slog.Error("error starting mcp client", "error", err, "name", name)
-		cancel()
-		cancelTimer.Stop()
-		return nil, err
-	}
-
-	cancelTimer.Stop()
-	slog.Info("Initialized mcp client", "name", name)
-	return session, nil
-}
-
-// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
-// to parse, and the cli will then close it, causing the EOF error.
-// so, if we got an EOF err, and the transport is STDIO, we try to exec it
-// again with a timeout and collect the output so we can add details to the
-// error.
-// this happens particularly when starting things with npx, e.g. if node can't
-// be found or some other error like that.
-func maybeStdioErr(err error, transport mcp.Transport) error {
-	if !errors.Is(err, io.EOF) {
-		return err
-	}
-	ct, ok := transport.(*mcp.CommandTransport)
-	if !ok {
-		return err
-	}
-	if err2 := stdioMCPCheck(ct.Command); err2 != nil {
-		err = errors.Join(err, err2)
-	}
-	return err
-}
-
-func maybeTimeoutErr(err error, timeout time.Duration) error {
-	if errors.Is(err, context.Canceled) {
-		return fmt.Errorf("timed out after %s", timeout)
-	}
-	return err
-}
-
-func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
-	switch m.Type {
-	case config.MCPStdio:
-		command, err := resolver.ResolveValue(m.Command)
-		if err != nil {
-			return nil, fmt.Errorf("invalid mcp command: %w", err)
-		}
-		if strings.TrimSpace(command) == "" {
-			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
-		}
-		cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
-		cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
-		return &mcp.CommandTransport{
-			Command: cmd,
-		}, nil
-	case config.MCPHttp:
-		if strings.TrimSpace(m.URL) == "" {
-			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
-		}
-		client := &http.Client{
-			Transport: &headerRoundTripper{
-				headers: m.ResolvedHeaders(),
-			},
-		}
-		return &mcp.StreamableClientTransport{
-			Endpoint:   m.URL,
-			HTTPClient: client,
-		}, nil
-	case config.MCPSSE:
-		if strings.TrimSpace(m.URL) == "" {
-			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
-		}
-		client := &http.Client{
-			Transport: &headerRoundTripper{
-				headers: m.ResolvedHeaders(),
-			},
-		}
-		return &mcp.SSEClientTransport{
-			Endpoint:   m.URL,
-			HTTPClient: client,
-		}, nil
-	default:
-		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
-	}
-}
-
-type headerRoundTripper struct {
-	headers map[string]string
-}
-
-func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
-	for k, v := range rt.headers {
-		req.Header.Set(k, v)
-	}
-	return http.DefaultTransport.RoundTrip(req)
-}
-
-func mcpTimeout(m config.MCPConfig) time.Duration {
-	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
-}
-
-func stdioMCPCheck(old *exec.Cmd) error {
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
-	defer cancel()
-	cmd := exec.CommandContext(ctx, old.Path, old.Args...)
-	cmd.Env = old.Env
-	out, err := cmd.CombinedOutput()
-	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
-		return nil
-	}
-	return fmt.Errorf("%w: %s", err, string(out))
-}

internal/agent/tools/mcp/init.go 🔗

@@ -0,0 +1,400 @@
+package mcp
+
+import (
+	"cmp"
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"log/slog"
+	"maps"
+	"net/http"
+	"os"
+	"os/exec"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/charmbracelet/crush/internal/config"
+	"github.com/charmbracelet/crush/internal/csync"
+	"github.com/charmbracelet/crush/internal/home"
+	"github.com/charmbracelet/crush/internal/permission"
+	"github.com/charmbracelet/crush/internal/pubsub"
+	"github.com/charmbracelet/crush/internal/version"
+	"github.com/modelcontextprotocol/go-sdk/mcp"
+)
+
+var (
+	sessions = csync.NewMap[string, *mcp.ClientSession]()
+	states   = csync.NewMap[string, ClientInfo]()
+	broker   = pubsub.NewBroker[Event]()
+)
+
+// State represents the current state of an MCP client
+type State int
+
+const (
+	StateDisabled State = iota
+	StateStarting
+	StateConnected
+	StateError
+)
+
+func (s State) String() string {
+	switch s {
+	case StateDisabled:
+		return "disabled"
+	case StateStarting:
+		return "starting"
+	case StateConnected:
+		return "connected"
+	case StateError:
+		return "error"
+	default:
+		return "unknown"
+	}
+}
+
+// EventType represents the type of MCP event
+type EventType string
+
+const (
+	EventStateChanged       EventType = "state_changed"
+	EventToolsListChanged   EventType = "tools_list_changed"
+	EventPromptsListChanged EventType = "prompts_list_changed"
+)
+
+// Event represents an event in the MCP system
+type Event struct {
+	Type   EventType
+	Name   string
+	State  State
+	Error  error
+	Counts Counts
+}
+
+// Counts number of available tools, prompts, etc.
+type Counts struct {
+	Tools   int
+	Prompts int
+}
+
+// ClientInfo holds information about an MCP client's state
+type ClientInfo struct {
+	Name        string
+	State       State
+	Error       error
+	Client      *mcp.ClientSession
+	Counts      Counts
+	ConnectedAt time.Time
+}
+
+// SubscribeEvents returns a channel for MCP events
+func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] {
+	return broker.Subscribe(ctx)
+}
+
+// GetStates returns the current state of all MCP clients
+func GetStates() map[string]ClientInfo {
+	return maps.Collect(states.Seq2())
+}
+
+// GetState returns the state of a specific MCP client
+func GetState(name string) (ClientInfo, bool) {
+	return states.Get(name)
+}
+
+// Close closes all MCP clients. This should be called during application shutdown.
+func Close() error {
+	var errs []error
+	for name, c := range sessions.Seq2() {
+		if err := c.Close(); err != nil &&
+			!errors.Is(err, io.EOF) &&
+			!errors.Is(err, context.Canceled) &&
+			err.Error() != "signal: killed" {
+			errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
+		}
+	}
+	broker.Shutdown()
+	return errors.Join(errs...)
+}
+
+// Initialize initializes MCP clients based on the provided configuration.
+func Initialize(ctx context.Context, permissions permission.Service, cfg *config.Config) {
+	var wg sync.WaitGroup
+	// Initialize states for all configured MCPs
+	for name, m := range cfg.MCP {
+		if m.Disabled {
+			updateState(name, StateDisabled, nil, nil, Counts{})
+			slog.Debug("skipping disabled mcp", "name", name)
+			continue
+		}
+
+		// Set initial starting state
+		updateState(name, StateStarting, nil, nil, Counts{})
+
+		wg.Add(1)
+		go func(name string, m config.MCPConfig) {
+			defer func() {
+				wg.Done()
+				if r := recover(); r != nil {
+					var err error
+					switch v := r.(type) {
+					case error:
+						err = v
+					case string:
+						err = fmt.Errorf("panic: %s", v)
+					default:
+						err = fmt.Errorf("panic: %v", v)
+					}
+					updateState(name, StateError, err, nil, Counts{})
+					slog.Error("panic in mcp client initialization", "error", err, "name", name)
+				}
+			}()
+
+			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
+			defer cancel()
+
+			session, err := createSession(ctx, name, m, cfg.Resolver())
+			if err != nil {
+				return
+			}
+
+			tools, err := getTools(ctx, name, permissions, session, cfg.WorkingDir())
+			if err != nil {
+				slog.Error("error listing tools", "error", err)
+				updateState(name, StateError, err, nil, Counts{})
+				session.Close()
+				return
+			}
+
+			prompts, err := getPrompts(ctx, session)
+			if err != nil {
+				slog.Error("error listing prompts", "error", err)
+				updateState(name, StateError, err, nil, Counts{})
+				session.Close()
+				return
+			}
+
+			updateTools(name, tools)
+			updatePrompts(name, prompts)
+			sessions.Set(name, session)
+
+			updateState(name, StateConnected, nil, session, Counts{
+				Tools:   len(tools),
+				Prompts: len(prompts),
+			})
+		}(name, m)
+	}
+	wg.Wait()
+}
+
+func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
+	sess, ok := sessions.Get(name)
+	if !ok {
+		return nil, fmt.Errorf("mcp '%s' not available", name)
+	}
+
+	cfg := config.Get()
+	m := cfg.MCP[name]
+	state, _ := states.Get(name)
+
+	timeout := mcpTimeout(m)
+	pingCtx, cancel := context.WithTimeout(ctx, timeout)
+	defer cancel()
+	err := sess.Ping(pingCtx, nil)
+	if err == nil {
+		return sess, nil
+	}
+	updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
+
+	sess, err = createSession(ctx, name, m, cfg.Resolver())
+	if err != nil {
+		return nil, err
+	}
+
+	updateState(name, StateConnected, nil, sess, state.Counts)
+	sessions.Set(name, sess)
+	return sess, nil
+}
+
+// updateState updates the state of an MCP client and publishes an event
+func updateState(name string, state State, err error, client *mcp.ClientSession, counts Counts) {
+	info := ClientInfo{
+		Name:   name,
+		State:  state,
+		Error:  err,
+		Client: client,
+		Counts: counts,
+	}
+	switch state {
+	case StateConnected:
+		info.ConnectedAt = time.Now()
+	case StateError:
+		updateTools(name, nil)
+		sessions.Del(name)
+	}
+	states.Set(name, info)
+
+	// Publish state change event
+	broker.Publish(pubsub.UpdatedEvent, Event{
+		Type:   EventStateChanged,
+		Name:   name,
+		State:  state,
+		Error:  err,
+		Counts: counts,
+	})
+}
+
+func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
+	timeout := mcpTimeout(m)
+	mcpCtx, cancel := context.WithCancel(ctx)
+	cancelTimer := time.AfterFunc(timeout, cancel)
+
+	transport, err := createTransport(mcpCtx, m, resolver)
+	if err != nil {
+		updateState(name, StateError, err, nil, Counts{})
+		slog.Error("error creating mcp client", "error", err, "name", name)
+		cancel()
+		cancelTimer.Stop()
+		return nil, err
+	}
+
+	client := mcp.NewClient(
+		&mcp.Implementation{
+			Name:    "crush",
+			Version: version.Version,
+			Title:   "Crush",
+		},
+		&mcp.ClientOptions{
+			ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
+				broker.Publish(pubsub.UpdatedEvent, Event{
+					Type: EventToolsListChanged,
+					Name: name,
+				})
+			},
+			PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
+				broker.Publish(pubsub.UpdatedEvent, Event{
+					Type: EventPromptsListChanged,
+					Name: name,
+				})
+			},
+			KeepAlive: time.Minute * 10,
+		},
+	)
+
+	session, err := client.Connect(mcpCtx, transport, nil)
+	if err != nil {
+		err = maybeStdioErr(err, transport)
+		updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
+		slog.Error("error starting mcp client", "error", err, "name", name)
+		cancel()
+		cancelTimer.Stop()
+		return nil, err
+	}
+
+	cancelTimer.Stop()
+	slog.Info("Initialized mcp client", "name", name)
+	return session, nil
+}
+
+// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
+// to parse, and the cli will then close it, causing the EOF error.
+// so, if we got an EOF err, and the transport is STDIO, we try to exec it
+// again with a timeout and collect the output so we can add details to the
+// error.
+// this happens particularly when starting things with npx, e.g. if node can't
+// be found or some other error like that.
+func maybeStdioErr(err error, transport mcp.Transport) error {
+	if !errors.Is(err, io.EOF) {
+		return err
+	}
+	ct, ok := transport.(*mcp.CommandTransport)
+	if !ok {
+		return err
+	}
+	if err2 := stdioCheck(ct.Command); err2 != nil {
+		err = errors.Join(err, err2)
+	}
+	return err
+}
+
+func maybeTimeoutErr(err error, timeout time.Duration) error {
+	if errors.Is(err, context.Canceled) {
+		return fmt.Errorf("timed out after %s", timeout)
+	}
+	return err
+}
+
+func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
+	switch m.Type {
+	case config.MCPStdio:
+		command, err := resolver.ResolveValue(m.Command)
+		if err != nil {
+			return nil, fmt.Errorf("invalid mcp command: %w", err)
+		}
+		if strings.TrimSpace(command) == "" {
+			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
+		}
+		cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
+		cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
+		return &mcp.CommandTransport{
+			Command: cmd,
+		}, nil
+	case config.MCPHttp:
+		if strings.TrimSpace(m.URL) == "" {
+			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
+		}
+		client := &http.Client{
+			Transport: &headerRoundTripper{
+				headers: m.ResolvedHeaders(),
+			},
+		}
+		return &mcp.StreamableClientTransport{
+			Endpoint:   m.URL,
+			HTTPClient: client,
+		}, nil
+	case config.MCPSSE:
+		if strings.TrimSpace(m.URL) == "" {
+			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
+		}
+		client := &http.Client{
+			Transport: &headerRoundTripper{
+				headers: m.ResolvedHeaders(),
+			},
+		}
+		return &mcp.SSEClientTransport{
+			Endpoint:   m.URL,
+			HTTPClient: client,
+		}, nil
+	default:
+		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
+	}
+}
+
+type headerRoundTripper struct {
+	headers map[string]string
+}
+
+func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+	for k, v := range rt.headers {
+		req.Header.Set(k, v)
+	}
+	return http.DefaultTransport.RoundTrip(req)
+}
+
+func mcpTimeout(m config.MCPConfig) time.Duration {
+	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
+}
+
+func stdioCheck(old *exec.Cmd) error {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+	defer cancel()
+	cmd := exec.CommandContext(ctx, old.Path, old.Args...)
+	cmd.Env = old.Env
+	out, err := cmd.CombinedOutput()
+	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
+		return nil
+	}
+	return fmt.Errorf("%w: %s", err, string(out))
+}

internal/agent/tools/mcp/prompts.go 🔗

@@ -0,0 +1,83 @@
+package mcp
+
+import (
+	"context"
+	"iter"
+
+	"github.com/charmbracelet/crush/internal/csync"
+	"github.com/modelcontextprotocol/go-sdk/mcp"
+)
+
+type Prompt = mcp.Prompt
+
+var (
+	allPrompts     = csync.NewMap[string, *Prompt]()
+	client2Prompts = csync.NewMap[string, []*Prompt]()
+)
+
+// GetPrompts returns all available MCP prompts.
+func GetPrompts() iter.Seq2[string, *Prompt] {
+	return allPrompts.Seq2()
+}
+
+// GetPrompt returns a specific MCP prompt by name.
+func GetPrompt(name string) (*Prompt, bool) {
+	return allPrompts.Get(name)
+}
+
+// GetPromptsByClient returns all prompts for a specific MCP client.
+func GetPromptsByClient(clientName string) ([]*Prompt, bool) {
+	return client2Prompts.Get(clientName)
+}
+
+// GetPromptMessages retrieves the content of an MCP prompt with the given arguments.
+func GetPromptMessages(ctx context.Context, clientName, promptName string, args map[string]string) ([]string, error) {
+	c, err := getOrRenewClient(ctx, clientName)
+	if err != nil {
+		return nil, err
+	}
+	result, err := c.GetPrompt(ctx, &mcp.GetPromptParams{
+		Name:      promptName,
+		Arguments: args,
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	var messages []string
+	for _, msg := range result.Messages {
+		if msg.Role != "user" {
+			continue
+		}
+		if textContent, ok := msg.Content.(*mcp.TextContent); ok {
+			messages = append(messages, textContent.Text)
+		}
+	}
+	return messages, nil
+}
+
+func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*Prompt, error) {
+	if c.InitializeResult().Capabilities.Prompts == nil {
+		return nil, nil
+	}
+	result, err := c.ListPrompts(ctx, &mcp.ListPromptsParams{})
+	if err != nil {
+		return nil, err
+	}
+	return result.Prompts, nil
+}
+
+// updatePrompts updates the global mcpPrompts and mcpClient2Prompts maps
+func updatePrompts(mcpName string, prompts []*Prompt) {
+	if len(prompts) == 0 {
+		client2Prompts.Del(mcpName)
+	} else {
+		client2Prompts.Set(mcpName, prompts)
+	}
+	for mcpName, prompts := range client2Prompts.Seq2() {
+		for _, p := range prompts {
+			key := mcpName + ":" + p.Name
+			allPrompts.Set(key, p)
+		}
+	}
+}

internal/agent/tools/mcp/tools.go 🔗

@@ -0,0 +1,169 @@
+package mcp
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"iter"
+	"strings"
+
+	"charm.land/fantasy"
+	"github.com/charmbracelet/crush/internal/agent/tools"
+	"github.com/charmbracelet/crush/internal/csync"
+	"github.com/charmbracelet/crush/internal/permission"
+	"github.com/modelcontextprotocol/go-sdk/mcp"
+)
+
+var (
+	allTools     = csync.NewMap[string, *Tool]()
+	client2Tools = csync.NewMap[string, []*Tool]()
+)
+
+// GetMCPTools returns all available MCP tools.
+func GetMCPTools() iter.Seq[*Tool] {
+	return allTools.Seq()
+}
+
+type Tool struct {
+	mcpName         string
+	tool            *mcp.Tool
+	permissions     permission.Service
+	workingDir      string
+	providerOptions fantasy.ProviderOptions
+}
+
+func (m *Tool) SetProviderOptions(opts fantasy.ProviderOptions) {
+	m.providerOptions = opts
+}
+
+func (m *Tool) ProviderOptions() fantasy.ProviderOptions {
+	return m.providerOptions
+}
+
+func (m *Tool) Name() string {
+	return fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name)
+}
+
+func (m *Tool) MCP() string {
+	return m.mcpName
+}
+
+func (m *Tool) MCPToolName() string {
+	return m.tool.Name
+}
+
+func (m *Tool) Info() fantasy.ToolInfo {
+	parameters := make(map[string]any)
+	required := make([]string, 0)
+
+	if input, ok := m.tool.InputSchema.(map[string]any); ok {
+		if props, ok := input["properties"].(map[string]any); ok {
+			parameters = props
+		}
+		if req, ok := input["required"].([]any); ok {
+			// Convert []any -> []string when elements are strings
+			for _, v := range req {
+				if s, ok := v.(string); ok {
+					required = append(required, s)
+				}
+			}
+		} else if reqStr, ok := input["required"].([]string); ok {
+			// Handle case where it's already []string
+			required = reqStr
+		}
+	}
+
+	return fantasy.ToolInfo{
+		Name:        fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name),
+		Description: m.tool.Description,
+		Parameters:  parameters,
+		Required:    required,
+	}
+}
+
+func (m *Tool) Run(ctx context.Context, params fantasy.ToolCall) (fantasy.ToolResponse, error) {
+	sessionID := tools.GetSessionFromContext(ctx)
+	if sessionID == "" {
+		return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
+	}
+	permissionDescription := fmt.Sprintf("execute %s with the following parameters:", m.Info().Name)
+	p := m.permissions.Request(
+		permission.CreatePermissionRequest{
+			SessionID:   sessionID,
+			ToolCallID:  params.ID,
+			Path:        m.workingDir,
+			ToolName:    m.Info().Name,
+			Action:      "execute",
+			Description: permissionDescription,
+			Params:      params.Input,
+		},
+	)
+	if !p {
+		return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
+	}
+
+	return runTool(ctx, m.mcpName, m.tool.Name, params.Input)
+}
+
+func runTool(ctx context.Context, name, toolName string, input string) (fantasy.ToolResponse, error) {
+	var args map[string]any
+	if err := json.Unmarshal([]byte(input), &args); err != nil {
+		return fantasy.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
+	}
+
+	c, err := getOrRenewClient(ctx, name)
+	if err != nil {
+		return fantasy.NewTextErrorResponse(err.Error()), nil
+	}
+	result, err := c.CallTool(ctx, &mcp.CallToolParams{
+		Name:      toolName,
+		Arguments: args,
+	})
+	if err != nil {
+		return fantasy.NewTextErrorResponse(err.Error()), nil
+	}
+
+	output := make([]string, 0, len(result.Content))
+	for _, v := range result.Content {
+		if vv, ok := v.(*mcp.TextContent); ok {
+			output = append(output, vv.Text)
+		} else {
+			output = append(output, fmt.Sprintf("%v", v))
+		}
+	}
+	return fantasy.NewTextResponse(strings.Join(output, "\n")), nil
+}
+
+func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]*Tool, error) {
+	if c.InitializeResult().Capabilities.Tools == nil {
+		return nil, nil
+	}
+	result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
+	if err != nil {
+		return nil, err
+	}
+	mcpTools := make([]*Tool, 0, len(result.Tools))
+	for _, tool := range result.Tools {
+		mcpTools = append(mcpTools, &Tool{
+			mcpName:     name,
+			tool:        tool,
+			permissions: permissions,
+			workingDir:  workingDir,
+		})
+	}
+	return mcpTools, nil
+}
+
+// updateTools updates the global mcpTools and mcpClient2Tools maps
+func updateTools(mcpName string, tools []*Tool) {
+	if len(tools) == 0 {
+		client2Tools.Del(mcpName)
+	} else {
+		client2Tools.Set(mcpName, tools)
+	}
+	for _, tools := range client2Tools.Seq2() {
+		for _, t := range tools {
+			allTools.Set(t.Name(), t)
+		}
+	}
+}

internal/app/app.go 🔗

@@ -16,7 +16,7 @@ import (
 	"charm.land/fantasy"
 	tea "github.com/charmbracelet/bubbletea/v2"
 	"github.com/charmbracelet/crush/internal/agent"
-	"github.com/charmbracelet/crush/internal/agent/tools"
+	"github.com/charmbracelet/crush/internal/agent/tools/mcp"
 	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/csync"
 	"github.com/charmbracelet/crush/internal/db"
@@ -91,6 +91,11 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
 	// Initialize LSP clients in the background.
 	app.initLSPClients(ctx)
 
+	go func() {
+		slog.Info("Initializing MCP clients")
+		mcp.Initialize(ctx, app.Permissions, cfg)
+	}()
+
 	// cleanup database upon app shutdown
 	app.cleanupFuncs = append(app.cleanupFuncs, conn.Close)
 
@@ -260,7 +265,7 @@ func (app *App) setupEvents() {
 	setupSubscriber(ctx, app.serviceEventsWG, "permissions", app.Permissions.Subscribe, app.events)
 	setupSubscriber(ctx, app.serviceEventsWG, "permissions-notifications", app.Permissions.SubscribeNotifications, app.events)
 	setupSubscriber(ctx, app.serviceEventsWG, "history", app.History.Subscribe, app.events)
-	setupSubscriber(ctx, app.serviceEventsWG, "mcp", tools.SubscribeMCPEvents, app.events)
+	setupSubscriber(ctx, app.serviceEventsWG, "mcp", mcp.SubscribeEvents, app.events)
 	setupSubscriber(ctx, app.serviceEventsWG, "lsp", SubscribeLSPEvents, app.events)
 	cleanupFunc := func() error {
 		cancel()
@@ -324,7 +329,7 @@ func (app *App) InitCoderAgent(ctx context.Context) error {
 	}
 
 	// Add MCP client cleanup to shutdown process
-	app.cleanupFuncs = append(app.cleanupFuncs, tools.CloseMCPClients)
+	app.cleanupFuncs = append(app.cleanupFuncs, mcp.Close)
 	return nil
 }
 

internal/tui/components/dialogs/commands/arguments.go 🔗

@@ -1,8 +1,7 @@
 package commands
 
 import (
-	"fmt"
-	"strings"
+	"cmp"
 
 	"github.com/charmbracelet/bubbles/v2/help"
 	"github.com/charmbracelet/bubbles/v2/key"
@@ -20,9 +19,10 @@ const (
 
 // ShowArgumentsDialogMsg is a message that is sent to show the arguments dialog.
 type ShowArgumentsDialogMsg struct {
-	CommandID string
-	Content   string
-	ArgNames  []string
+	CommandID   string
+	Description string
+	ArgNames    []string
+	OnSubmit    func(args map[string]string) tea.Cmd
 }
 
 // CloseArgumentsDialogMsg is a message that is sent when the arguments dialog is closed.
@@ -39,26 +39,39 @@ type CommandArgumentsDialog interface {
 }
 
 type commandArgumentsDialogCmp struct {
-	width   int
-	wWidth  int // Width of the terminal window
-	wHeight int // Height of the terminal window
-
-	inputs     []textinput.Model
-	focusIndex int
-	keys       ArgumentsDialogKeyMap
-	commandID  string
-	content    string
-	argNames   []string
-	help       help.Model
+	wWidth, wHeight int
+	width, height   int
+
+	inputs    []textinput.Model
+	focused   int
+	keys      ArgumentsDialogKeyMap
+	arguments []Argument
+	help      help.Model
+
+	id          string
+	title       string
+	name        string
+	description string
+
+	onSubmit func(args map[string]string) tea.Cmd
+}
+
+type Argument struct {
+	Name, Title, Description string
+	Required                 bool
 }
 
-func NewCommandArgumentsDialog(commandID, content string, argNames []string) CommandArgumentsDialog {
+func NewCommandArgumentsDialog(
+	id, title, name, description string,
+	arguments []Argument,
+	onSubmit func(args map[string]string) tea.Cmd,
+) CommandArgumentsDialog {
 	t := styles.CurrentTheme()
-	inputs := make([]textinput.Model, len(argNames))
+	inputs := make([]textinput.Model, len(arguments))
 
-	for i, name := range argNames {
+	for i, arg := range arguments {
 		ti := textinput.New()
-		ti.Placeholder = fmt.Sprintf("Enter value for %s...", name)
+		ti.Placeholder = cmp.Or(arg.Description, "Enter value for "+arg.Title)
 		ti.SetWidth(40)
 		ti.SetVirtualCursor(false)
 		ti.Prompt = ""
@@ -75,14 +88,16 @@ func NewCommandArgumentsDialog(commandID, content string, argNames []string) Com
 	}
 
 	return &commandArgumentsDialogCmp{
-		inputs:     inputs,
-		keys:       DefaultArgumentsDialogKeyMap(),
-		commandID:  commandID,
-		content:    content,
-		argNames:   argNames,
-		focusIndex: 0,
-		width:      60,
-		help:       help.New(),
+		inputs:      inputs,
+		keys:        DefaultArgumentsDialogKeyMap(),
+		id:          id,
+		name:        name,
+		title:       title,
+		description: description,
+		arguments:   arguments,
+		width:       60,
+		help:        help.New(),
+		onSubmit:    onSubmit,
 	}
 }
 
@@ -97,47 +112,51 @@ func (c *commandArgumentsDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
 	case tea.WindowSizeMsg:
 		c.wWidth = msg.Width
 		c.wHeight = msg.Height
+		c.width = min(90, c.wWidth)
+		c.height = min(15, c.wHeight)
+		for i := range c.inputs {
+			c.inputs[i].SetWidth(c.width - (paddingHorizontal * 2))
+		}
 	case tea.KeyPressMsg:
 		switch {
+		case key.Matches(msg, c.keys.Close):
+			return c, util.CmdHandler(dialogs.CloseDialogMsg{})
 		case key.Matches(msg, c.keys.Confirm):
-			if c.focusIndex == len(c.inputs)-1 {
-				content := c.content
-				for i, name := range c.argNames {
+			if c.focused == len(c.inputs)-1 {
+				args := make(map[string]string)
+				for i, arg := range c.arguments {
 					value := c.inputs[i].Value()
-					placeholder := "$" + name
-					content = strings.ReplaceAll(content, placeholder, value)
+					args[arg.Name] = value
 				}
 				return c, tea.Sequence(
 					util.CmdHandler(dialogs.CloseDialogMsg{}),
-					util.CmdHandler(CommandRunCustomMsg{
-						Content: content,
-					}),
+					c.onSubmit(args),
 				)
 			}
 			// Otherwise, move to the next input
-			c.inputs[c.focusIndex].Blur()
-			c.focusIndex++
-			c.inputs[c.focusIndex].Focus()
+			c.inputs[c.focused].Blur()
+			c.focused++
+			c.inputs[c.focused].Focus()
 		case key.Matches(msg, c.keys.Next):
 			// Move to the next input
-			c.inputs[c.focusIndex].Blur()
-			c.focusIndex = (c.focusIndex + 1) % len(c.inputs)
-			c.inputs[c.focusIndex].Focus()
+			c.inputs[c.focused].Blur()
+			c.focused = (c.focused + 1) % len(c.inputs)
+			c.inputs[c.focused].Focus()
 		case key.Matches(msg, c.keys.Previous):
 			// Move to the previous input
-			c.inputs[c.focusIndex].Blur()
-			c.focusIndex = (c.focusIndex - 1 + len(c.inputs)) % len(c.inputs)
-			c.inputs[c.focusIndex].Focus()
+			c.inputs[c.focused].Blur()
+			c.focused = (c.focused - 1 + len(c.inputs)) % len(c.inputs)
+			c.inputs[c.focused].Focus()
 		case key.Matches(msg, c.keys.Close):
 			return c, util.CmdHandler(dialogs.CloseDialogMsg{})
 		default:
 			var cmd tea.Cmd
-			c.inputs[c.focusIndex], cmd = c.inputs[c.focusIndex].Update(msg)
+			c.inputs[c.focused], cmd = c.inputs[c.focused].Update(msg)
 			return c, cmd
 		}
 	case tea.PasteMsg:
 		var cmd tea.Cmd
-		c.inputs[c.focusIndex], cmd = c.inputs[c.focusIndex].Update(msg)
+		c.inputs[c.focused], cmd = c.inputs[c.focused].Update(msg)
 		return c, cmd
 	}
 	return c, nil
@@ -152,26 +171,28 @@ func (c *commandArgumentsDialogCmp) View() string {
 		Foreground(t.Primary).
 		Bold(true).
 		Padding(0, 1).
-		Render("Command Arguments")
+		Render(cmp.Or(c.title, c.name))
 
-	explanation := t.S().Text.
+	promptName := t.S().Text.
 		Padding(0, 1).
-		Render("This command requires arguments.")
+		Render(c.description)
 
-	// Create input fields for each argument
 	inputFields := make([]string, len(c.inputs))
 	for i, input := range c.inputs {
-		// Highlight the label of the focused input
-		labelStyle := baseStyle.
-			Padding(1, 1, 0, 1)
+		labelStyle := baseStyle.Padding(1, 1, 0, 1)
 
-		if i == c.focusIndex {
+		if i == c.focused {
 			labelStyle = labelStyle.Foreground(t.FgBase).Bold(true)
 		} else {
 			labelStyle = labelStyle.Foreground(t.FgMuted)
 		}
 
-		label := labelStyle.Render(c.argNames[i] + ":")
+		arg := c.arguments[i]
+		argName := cmp.Or(arg.Title, arg.Name)
+		if arg.Required {
+			argName += "*"
+		}
+		label := labelStyle.Render(argName + ":")
 
 		field := t.S().Text.
 			Padding(0, 1).
@@ -180,18 +201,14 @@ func (c *commandArgumentsDialogCmp) View() string {
 		inputFields[i] = lipgloss.JoinVertical(lipgloss.Left, label, field)
 	}
 
-	// Join all elements vertically
-	elements := []string{title, explanation}
+	elements := []string{title, promptName}
 	elements = append(elements, inputFields...)
 
 	c.help.ShowAll = false
 	helpText := baseStyle.Padding(0, 1).Render(c.help.View(c.keys))
 	elements = append(elements, "", helpText)
 
-	content := lipgloss.JoinVertical(
-		lipgloss.Left,
-		elements...,
-	)
+	content := lipgloss.JoinVertical(lipgloss.Left, elements...)
 
 	return baseStyle.Padding(1, 1, 0, 1).
 		Border(lipgloss.RoundedBorder()).
@@ -201,26 +218,33 @@ func (c *commandArgumentsDialogCmp) View() string {
 }
 
 func (c *commandArgumentsDialogCmp) Cursor() *tea.Cursor {
-	cursor := c.inputs[c.focusIndex].Cursor()
+	if len(c.inputs) == 0 {
+		return nil
+	}
+	cursor := c.inputs[c.focused].Cursor()
 	if cursor != nil {
 		cursor = c.moveCursor(cursor)
 	}
 	return cursor
 }
 
+const (
+	headerHeight      = 3
+	itemHeight        = 3
+	paddingHorizontal = 3
+)
+
 func (c *commandArgumentsDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
 	row, col := c.Position()
-	offset := row + 3 + (1+c.focusIndex)*3
+	offset := row + headerHeight + (1+c.focused)*itemHeight
 	cursor.Y += offset
-	cursor.X = cursor.X + col + 3
+	cursor.X = cursor.X + col + paddingHorizontal
 	return cursor
 }
 
 func (c *commandArgumentsDialogCmp) Position() (int, int) {
-	row := c.wHeight / 2
-	row -= c.wHeight / 2
-	col := c.wWidth / 2
-	col -= c.width / 2
+	row := (c.wHeight / 2) - (c.height / 2)
+	col := (c.wWidth / 2) - (c.width / 2)
 	return row, col
 }
 

internal/tui/components/dialogs/commands/commands.go 🔗

@@ -2,6 +2,8 @@ package commands
 
 import (
 	"os"
+	"slices"
+	"strings"
 
 	"github.com/charmbracelet/bubbles/v2/help"
 	"github.com/charmbracelet/bubbles/v2/key"
@@ -10,7 +12,10 @@ import (
 	"github.com/charmbracelet/lipgloss/v2"
 
 	"github.com/charmbracelet/crush/internal/agent"
+	"github.com/charmbracelet/crush/internal/agent/tools/mcp"
 	"github.com/charmbracelet/crush/internal/config"
+	"github.com/charmbracelet/crush/internal/csync"
+	"github.com/charmbracelet/crush/internal/pubsub"
 	"github.com/charmbracelet/crush/internal/tui/components/chat"
 	"github.com/charmbracelet/crush/internal/tui/components/core"
 	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
@@ -25,9 +30,14 @@ const (
 	defaultWidth int = 70
 )
 
+type commandType uint
+
+func (c commandType) String() string { return []string{"System", "User", "MCP"}[c] }
+
 const (
-	SystemCommands int = iota
+	SystemCommands commandType = iota
 	UserCommands
+	MCPPrompts
 )
 
 type listModel = list.FilterableList[list.CompletionItem[Command]]
@@ -54,9 +64,10 @@ type commandDialogCmp struct {
 	commandList  listModel
 	keyMap       CommandsDialogKeyMap
 	help         help.Model
-	commandType  int       // SystemCommands or UserCommands
-	userCommands []Command // User-defined commands
-	sessionID    string    // Current session ID
+	selected     commandType           // Selected SystemCommands, UserCommands, or MCPPrompts
+	userCommands []Command             // User-defined commands
+	mcpPrompts   *csync.Slice[Command] // MCP prompts
+	sessionID    string                // Current session ID
 }
 
 type (
@@ -102,8 +113,9 @@ func NewCommandDialog(sessionID string) CommandsDialog {
 		width:       defaultWidth,
 		keyMap:      DefaultCommandsDialogKeyMap(),
 		help:        help,
-		commandType: SystemCommands,
+		selected:    SystemCommands,
 		sessionID:   sessionID,
+		mcpPrompts:  csync.NewSlice[Command](),
 	}
 }
 
@@ -113,7 +125,8 @@ func (c *commandDialogCmp) Init() tea.Cmd {
 		return util.ReportError(err)
 	}
 	c.userCommands = commands
-	return c.SetCommandType(c.commandType)
+	c.mcpPrompts.SetSlice(loadMCPPrompts())
+	return c.setCommandType(c.selected)
 }
 
 func (c *commandDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
@@ -122,9 +135,19 @@ func (c *commandDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
 		c.wWidth = msg.Width
 		c.wHeight = msg.Height
 		return c, tea.Batch(
-			c.SetCommandType(c.commandType),
+			c.setCommandType(c.selected),
 			c.commandList.SetSize(c.listWidth(), c.listHeight()),
 		)
+	case pubsub.Event[mcp.Event]:
+		// Reload MCP prompts when MCP state changes
+		if msg.Type == pubsub.UpdatedEvent {
+			c.mcpPrompts.SetSlice(loadMCPPrompts())
+			// If we're currently viewing MCP prompts, refresh the list
+			if c.selected == MCPPrompts {
+				return c, c.setCommandType(MCPPrompts)
+			}
+			return c, nil
+		}
 	case tea.KeyPressMsg:
 		switch {
 		case key.Matches(msg, c.keyMap.Select):
@@ -138,15 +161,10 @@ func (c *commandDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
 				command.Handler(command),
 			)
 		case key.Matches(msg, c.keyMap.Tab):
-			if len(c.userCommands) == 0 {
+			if len(c.userCommands) == 0 && c.mcpPrompts.Len() == 0 {
 				return c, nil
 			}
-			// Toggle command type between System and User commands
-			if c.commandType == SystemCommands {
-				return c, c.SetCommandType(UserCommands)
-			} else {
-				return c, c.SetCommandType(SystemCommands)
-			}
+			return c, c.setCommandType(c.next())
 		case key.Matches(msg, c.keyMap.Close):
 			return c, util.CmdHandler(dialogs.CloseDialogMsg{})
 		default:
@@ -158,13 +176,35 @@ func (c *commandDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
 	return c, nil
 }
 
+func (c *commandDialogCmp) next() commandType {
+	switch c.selected {
+	case SystemCommands:
+		if len(c.userCommands) > 0 {
+			return UserCommands
+		}
+		if c.mcpPrompts.Len() > 0 {
+			return MCPPrompts
+		}
+		fallthrough
+	case UserCommands:
+		if c.mcpPrompts.Len() > 0 {
+			return MCPPrompts
+		}
+		fallthrough
+	case MCPPrompts:
+		return SystemCommands
+	default:
+		return SystemCommands
+	}
+}
+
 func (c *commandDialogCmp) View() string {
 	t := styles.CurrentTheme()
 	listView := c.commandList
 	radio := c.commandTypeRadio()
 
 	header := t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Commands", c.width-lipgloss.Width(radio)-5) + " " + radio)
-	if len(c.userCommands) == 0 {
+	if len(c.userCommands) == 0 && c.mcpPrompts.Len() == 0 {
 		header = t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Commands", c.width-4))
 	}
 	content := lipgloss.JoinVertical(
@@ -190,27 +230,41 @@ func (c *commandDialogCmp) Cursor() *tea.Cursor {
 
 func (c *commandDialogCmp) commandTypeRadio() string {
 	t := styles.CurrentTheme()
-	choices := []string{"System", "User"}
-	iconSelected := "◉"
-	iconUnselected := "○"
-	if c.commandType == SystemCommands {
-		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
+
+	fn := func(i commandType) string {
+		if i == c.selected {
+			return "◉ " + i.String()
+		}
+		return "○ " + i.String()
+	}
+
+	parts := []string{
+		fn(SystemCommands),
+	}
+	if len(c.userCommands) > 0 {
+		parts = append(parts, fn(UserCommands))
+	}
+	if c.mcpPrompts.Len() > 0 {
+		parts = append(parts, fn(MCPPrompts))
 	}
-	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
+	return t.S().Base.Foreground(t.FgHalfMuted).Render(strings.Join(parts, " "))
 }
 
 func (c *commandDialogCmp) listWidth() int {
 	return defaultWidth - 2 // 4 for padding
 }
 
-func (c *commandDialogCmp) SetCommandType(commandType int) tea.Cmd {
-	c.commandType = commandType
+func (c *commandDialogCmp) setCommandType(commandType commandType) tea.Cmd {
+	c.selected = commandType
 
 	var commands []Command
-	if c.commandType == SystemCommands {
+	switch c.selected {
+	case SystemCommands:
 		commands = c.defaultCommands()
-	} else {
+	case UserCommands:
 		commands = c.userCommands
+	case MCPPrompts:
+		commands = slices.Collect(c.mcpPrompts.Seq())
 	}
 
 	commandItems := []list.CompletionItem[Command]{}

internal/tui/components/dialogs/commands/loader.go 🔗

@@ -1,22 +1,26 @@
 package commands
 
 import (
+	"context"
 	"fmt"
 	"io/fs"
+	"log/slog"
 	"os"
 	"path/filepath"
 	"regexp"
 	"strings"
 
 	tea "github.com/charmbracelet/bubbletea/v2"
+	"github.com/charmbracelet/crush/internal/agent/tools/mcp"
 	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/home"
+	"github.com/charmbracelet/crush/internal/tui/components/chat"
 	"github.com/charmbracelet/crush/internal/tui/util"
 )
 
 const (
-	UserCommandPrefix    = "user:"
-	ProjectCommandPrefix = "project:"
+	userCommandPrefix    = "user:"
+	projectCommandPrefix = "project:"
 )
 
 var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)
@@ -50,7 +54,7 @@ func buildCommandSources(cfg *config.Config) []commandSource {
 	if dir := getXDGCommandsDir(); dir != "" {
 		sources = append(sources, commandSource{
 			path:   dir,
-			prefix: UserCommandPrefix,
+			prefix: userCommandPrefix,
 		})
 	}
 
@@ -58,14 +62,14 @@ func buildCommandSources(cfg *config.Config) []commandSource {
 	if home := home.Dir(); home != "" {
 		sources = append(sources, commandSource{
 			path:   filepath.Join(home, ".crush", "commands"),
-			prefix: UserCommandPrefix,
+			prefix: userCommandPrefix,
 		})
 	}
 
 	// Project directory
 	sources = append(sources, commandSource{
 		path:   filepath.Join(cfg.Options.DataDirectory, "commands"),
-		prefix: ProjectCommandPrefix,
+		prefix: projectCommandPrefix,
 	})
 
 	return sources
@@ -127,12 +131,13 @@ func (l *commandLoader) loadCommand(path, baseDir, prefix string) (Command, erro
 	}
 
 	id := buildCommandID(path, baseDir, prefix)
+	desc := fmt.Sprintf("Custom command from %s", filepath.Base(path))
 
 	return Command{
 		ID:          id,
 		Title:       id,
-		Description: fmt.Sprintf("Custom command from %s", filepath.Base(path)),
-		Handler:     createCommandHandler(id, string(content)),
+		Description: desc,
+		Handler:     createCommandHandler(id, desc, string(content)),
 	}, nil
 }
 
@@ -149,21 +154,35 @@ func buildCommandID(path, baseDir, prefix string) string {
 	return prefix + strings.Join(parts, ":")
 }
 
-func createCommandHandler(id string, content string) func(Command) tea.Cmd {
+func createCommandHandler(id, desc, content string) func(Command) tea.Cmd {
 	return func(cmd Command) tea.Cmd {
 		args := extractArgNames(content)
 
-		if len(args) > 0 {
-			return util.CmdHandler(ShowArgumentsDialogMsg{
-				CommandID: id,
-				Content:   content,
-				ArgNames:  args,
+		if len(args) == 0 {
+			return util.CmdHandler(CommandRunCustomMsg{
+				Content: content,
 			})
 		}
+		return util.CmdHandler(ShowArgumentsDialogMsg{
+			CommandID:   id,
+			Description: desc,
+			ArgNames:    args,
+			OnSubmit: func(args map[string]string) tea.Cmd {
+				return execUserPrompt(content, args)
+			},
+		})
+	}
+}
 
-		return util.CmdHandler(CommandRunCustomMsg{
+func execUserPrompt(content string, args map[string]string) tea.Cmd {
+	return func() tea.Msg {
+		for name, value := range args {
+			placeholder := "$" + name
+			content = strings.ReplaceAll(content, placeholder, value)
+		}
+		return CommandRunCustomMsg{
 			Content: content,
-		})
+		}
 	}
 }
 
@@ -201,3 +220,55 @@ func isMarkdownFile(name string) bool {
 type CommandRunCustomMsg struct {
 	Content string
 }
+
+func loadMCPPrompts() []Command {
+	var commands []Command
+	for key, prompt := range mcp.GetPrompts() {
+		clientName, promptName, ok := strings.Cut(key, ":")
+		if !ok {
+			slog.Warn("prompt not found", "key", key)
+			continue
+		}
+		commands = append(commands, Command{
+			ID:          key,
+			Title:       clientName + ":" + promptName,
+			Description: prompt.Description,
+			Handler:     createMCPPromptHandler(clientName, promptName, prompt),
+		})
+	}
+
+	return commands
+}
+
+func createMCPPromptHandler(clientName, promptName string, prompt *mcp.Prompt) func(Command) tea.Cmd {
+	return func(cmd Command) tea.Cmd {
+		if len(prompt.Arguments) == 0 {
+			return execMCPPrompt(clientName, promptName, nil)
+		}
+		return util.CmdHandler(ShowMCPPromptArgumentsDialogMsg{
+			Prompt: prompt,
+			OnSubmit: func(args map[string]string) tea.Cmd {
+				return execMCPPrompt(clientName, promptName, args)
+			},
+		})
+	}
+}
+
+func execMCPPrompt(clientName, promptName string, args map[string]string) tea.Cmd {
+	return func() tea.Msg {
+		ctx := context.Background()
+		result, err := mcp.GetPromptMessages(ctx, clientName, promptName, args)
+		if err != nil {
+			return util.ReportError(err)
+		}
+
+		return chat.SendMsg{
+			Text: strings.Join(result, " "),
+		}
+	}
+}
+
+type ShowMCPPromptArgumentsDialogMsg struct {
+	Prompt   *mcp.Prompt
+	OnSubmit func(arg map[string]string) tea.Cmd
+}

internal/tui/components/mcp/mcp.go 🔗

@@ -2,10 +2,11 @@ package mcp
 
 import (
 	"fmt"
+	"strings"
 
 	"github.com/charmbracelet/lipgloss/v2"
 
-	"github.com/charmbracelet/crush/internal/agent/tools"
+	"github.com/charmbracelet/crush/internal/agent/tools/mcp"
 	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/tui/components/core"
 	"github.com/charmbracelet/crush/internal/tui/styles"
@@ -40,7 +41,7 @@ func RenderMCPList(opts RenderOptions) []string {
 	}
 
 	// Get MCP states
-	mcpStates := tools.GetMCPStates()
+	mcpStates := mcp.GetStates()
 
 	// Determine how many items to show
 	maxItems := len(mcps)
@@ -56,21 +57,24 @@ func RenderMCPList(opts RenderOptions) []string {
 		// Determine icon and color based on state
 		icon := t.ItemOfflineIcon
 		description := ""
-		extraContent := ""
+		extraContent := []string{}
 
 		if state, exists := mcpStates[l.Name]; exists {
 			switch state.State {
-			case tools.MCPStateDisabled:
+			case mcp.StateDisabled:
 				description = t.S().Subtle.Render("disabled")
-			case tools.MCPStateStarting:
+			case mcp.StateStarting:
 				icon = t.ItemBusyIcon
 				description = t.S().Subtle.Render("starting...")
-			case tools.MCPStateConnected:
+			case mcp.StateConnected:
 				icon = t.ItemOnlineIcon
-				if state.ToolCount > 0 {
-					extraContent = t.S().Subtle.Render(fmt.Sprintf("%d tools", state.ToolCount))
+				if count := state.Counts.Tools; count > 0 {
+					extraContent = append(extraContent, t.S().Subtle.Render(fmt.Sprintf("%d tools", count)))
 				}
-			case tools.MCPStateError:
+				if count := state.Counts.Prompts; count > 0 {
+					extraContent = append(extraContent, t.S().Subtle.Render(fmt.Sprintf("%d prompts", count)))
+				}
+			case mcp.StateError:
 				icon = t.ItemErrorIcon
 				if state.Error != nil {
 					description = t.S().Subtle.Render(fmt.Sprintf("error: %s", state.Error.Error()))
@@ -88,7 +92,7 @@ func RenderMCPList(opts RenderOptions) []string {
 					Icon:         icon.String(),
 					Title:        l.Name,
 					Description:  description,
-					ExtraContent: extraContent,
+					ExtraContent: strings.Join(extraContent, " "),
 				},
 				opts.MaxWidth,
 			),

internal/tui/tui.go 🔗

@@ -33,6 +33,8 @@ import (
 	"github.com/charmbracelet/crush/internal/tui/styles"
 	"github.com/charmbracelet/crush/internal/tui/util"
 	"github.com/charmbracelet/lipgloss/v2"
+	"golang.org/x/text/cases"
+	"golang.org/x/text/language"
 )
 
 var lastMouseEvent time.Time
@@ -156,15 +158,44 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 		a.dialog = u.(dialogs.DialogCmp)
 		return a, tea.Batch(completionCmd, dialogCmd)
 	case commands.ShowArgumentsDialogMsg:
+		var args []commands.Argument
+		for _, arg := range msg.ArgNames {
+			args = append(args, commands.Argument{
+				Name:     arg,
+				Title:    cases.Title(language.English).String(arg),
+				Required: true,
+			})
+		}
 		return a, util.CmdHandler(
 			dialogs.OpenDialogMsg{
 				Model: commands.NewCommandArgumentsDialog(
 					msg.CommandID,
-					msg.Content,
-					msg.ArgNames,
+					msg.CommandID,
+					msg.CommandID,
+					msg.Description,
+					args,
+					msg.OnSubmit,
 				),
 			},
 		)
+	case commands.ShowMCPPromptArgumentsDialogMsg:
+		args := make([]commands.Argument, 0, len(msg.Prompt.Arguments))
+		for _, arg := range msg.Prompt.Arguments {
+			args = append(args, commands.Argument(*arg))
+		}
+		dialog := commands.NewCommandArgumentsDialog(
+			msg.Prompt.Name,
+			msg.Prompt.Title,
+			msg.Prompt.Name,
+			msg.Prompt.Description,
+			args,
+			msg.OnSubmit,
+		)
+		return a, util.CmdHandler(
+			dialogs.OpenDialogMsg{
+				Model: dialog,
+			},
+		)
 	// Page change messages
 	case page.PageChangeMsg:
 		return a, a.moveToPage(msg.ID)