fix(mcp): tool/prompt list update

Carlos Alexandro Becker created

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

Change summary

internal/agent/tools/mcp-tools.go                  |  2 
internal/agent/tools/mcp/init.go                   |  3 
internal/agent/tools/mcp/prompts.go                | 49 ++++++++++-----
internal/agent/tools/mcp/tools.go                  | 41 ++++++++++--
internal/app/app.go                                |  5 -
internal/tui/components/dialogs/commands/loader.go |  2 
internal/tui/tui.go                                | 23 +++++++
7 files changed, 94 insertions(+), 31 deletions(-)

Detailed changes

internal/agent/tools/mcp-tools.go 🔗

@@ -12,7 +12,7 @@ import (
 // GetMCPTools gets all the currently available MCP tools.
 func GetMCPTools(permissions permission.Service, wd string) []*Tool {
 	var result []*Tool
-	for name, tool := range mcp.GetMCPTools() {
+	for name, tool := range mcp.Tools() {
 		result = append(result, &Tool{
 			mcpName:     name,
 			tool:        tool,

internal/agent/tools/mcp/init.go 🔗

@@ -281,6 +281,9 @@ func createSession(ctx context.Context, name string, m config.MCPConfig, resolve
 					Name: name,
 				})
 			},
+			LoggingMessageHandler: func(_ context.Context, req *mcp.LoggingMessageRequest) {
+				slog.Info("mcp log", "name", name, "data", req.Params.Data)
+			},
 			KeepAlive: time.Minute * 10,
 		},
 	)

internal/agent/tools/mcp/prompts.go 🔗

@@ -3,6 +3,7 @@ package mcp
 import (
 	"context"
 	"iter"
+	"log/slog"
 
 	"github.com/charmbracelet/crush/internal/csync"
 	"github.com/modelcontextprotocol/go-sdk/mcp"
@@ -11,25 +12,15 @@ import (
 type Prompt = mcp.Prompt
 
 var (
-	allPrompts     = csync.NewMap[string, *Prompt]()
-	client2Prompts = csync.NewMap[string, []*Prompt]()
+	allPrompts    = csync.NewMap[string, *Prompt]()
+	clientPrompts = csync.NewMap[string, []*Prompt]()
 )
 
-// GetPrompts returns all available MCP prompts.
-func GetPrompts() iter.Seq2[string, *Prompt] {
+// Prompts returns all available MCP prompts.
+func Prompts() iter.Seq2[string, *Prompt] {
 	return allPrompts.Seq2()
 }
 
-// GetPrompt returns a specific MCP prompt by name.
-func GetPrompt(name string) (*Prompt, bool) {
-	return allPrompts.Get(name)
-}
-
-// GetPromptsByClient returns all prompts for a specific MCP client.
-func GetPromptsByClient(clientName string) ([]*Prompt, bool) {
-	return client2Prompts.Get(clientName)
-}
-
 // GetPromptMessages retrieves the content of an MCP prompt with the given arguments.
 func GetPromptMessages(ctx context.Context, clientName, promptName string, args map[string]string) ([]string, error) {
 	c, err := getOrRenewClient(ctx, clientName)
@@ -56,6 +47,30 @@ func GetPromptMessages(ctx context.Context, clientName, promptName string, args
 	return messages, nil
 }
 
+// RefreshPrompts gets the updated list of prompts from the MCP and updates the
+// global state.
+func RefreshPrompts(ctx context.Context, name string) {
+	session, ok := sessions.Get(name)
+	if !ok {
+		slog.Warn("refresh prompts: no session", "name", name)
+		return
+	}
+
+	prompts, err := getPrompts(ctx, session)
+	if err != nil {
+		updateState(name, StateError, err, nil, Counts{})
+		return
+	}
+
+	updatePrompts(name, prompts)
+
+	prev, _ := states.Get(name)
+	updateState(name, StateConnected, nil, session, Counts{
+		Prompts: len(prompts),
+		Tools:   prev.Counts.Tools,
+	})
+}
+
 func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*Prompt, error) {
 	if c.InitializeResult().Capabilities.Prompts == nil {
 		return nil, nil
@@ -70,11 +85,11 @@ func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*Prompt, error) {
 // updatePrompts updates the global mcpPrompts and mcpClient2Prompts maps
 func updatePrompts(mcpName string, prompts []*Prompt) {
 	if len(prompts) == 0 {
-		client2Prompts.Del(mcpName)
+		clientPrompts.Del(mcpName)
 	} else {
-		client2Prompts.Set(mcpName, prompts)
+		clientPrompts.Set(mcpName, prompts)
 	}
-	for mcpName, prompts := range client2Prompts.Seq2() {
+	for mcpName, prompts := range clientPrompts.Seq2() {
 		for _, p := range prompts {
 			key := mcpName + ":" + p.Name
 			allPrompts.Set(key, p)

internal/agent/tools/mcp/tools.go 🔗

@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"iter"
+	"log/slog"
 	"strings"
 
 	"github.com/charmbracelet/crush/internal/csync"
@@ -14,12 +15,12 @@ import (
 type Tool = mcp.Tool
 
 var (
-	allTools     = csync.NewMap[string, *Tool]()
-	client2Tools = csync.NewMap[string, []*Tool]()
+	allTools    = csync.NewMap[string, *Tool]()
+	clientTools = csync.NewMap[string, []*Tool]()
 )
 
-// GetTools returns all available MCP tools.
-func GetTools() iter.Seq2[string, *Tool] {
+// Tools returns all available MCP tools.
+func Tools() iter.Seq2[string, *Tool] {
 	return allTools.Seq2()
 }
 
@@ -53,6 +54,30 @@ func RunTool(ctx context.Context, name, toolName string, input string) (string,
 	return strings.Join(output, "\n"), nil
 }
 
+// RefreshTools gets the updated list of tools from the MCP and updates the
+// global state.
+func RefreshTools(ctx context.Context, name string) {
+	session, ok := sessions.Get(name)
+	if !ok {
+		slog.Warn("refresh tools: no session", "name", name)
+		return
+	}
+
+	tools, err := getTools(ctx, session)
+	if err != nil {
+		updateState(name, StateError, err, nil, Counts{})
+		return
+	}
+
+	updateTools(name, tools)
+
+	prev, _ := states.Get(name)
+	updateState(name, StateConnected, nil, session, Counts{
+		Tools:   len(tools),
+		Prompts: prev.Counts.Prompts,
+	})
+}
+
 func getTools(ctx context.Context, session *mcp.ClientSession) ([]*Tool, error) {
 	if session.InitializeResult().Capabilities.Tools == nil {
 		return nil, nil
@@ -65,13 +90,13 @@ func getTools(ctx context.Context, session *mcp.ClientSession) ([]*Tool, error)
 }
 
 // updateTools updates the global mcpTools and mcpClient2Tools maps
-func updateTools(mcpName string, tools []*Tool) {
+func updateTools(name string, tools []*Tool) {
 	if len(tools) == 0 {
-		client2Tools.Del(mcpName)
+		clientTools.Del(name)
 	} else {
-		client2Tools.Set(mcpName, tools)
+		clientTools.Set(name, tools)
 	}
-	for name, tools := range client2Tools.Seq2() {
+	for name, tools := range clientTools.Seq2() {
 		for _, t := range tools {
 			allTools.Set(name, t)
 		}

internal/app/app.go 🔗

@@ -97,7 +97,7 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
 	}()
 
 	// cleanup database upon app shutdown
-	app.cleanupFuncs = append(app.cleanupFuncs, conn.Close)
+	app.cleanupFuncs = append(app.cleanupFuncs, conn.Close, mcp.Close)
 
 	// TODO: remove the concept of agent config, most likely.
 	if !cfg.IsConfigured() {
@@ -327,9 +327,6 @@ func (app *App) InitCoderAgent(ctx context.Context) error {
 		slog.Error("Failed to create coder agent", "err", err)
 		return err
 	}
-
-	// Add MCP client cleanup to shutdown process
-	app.cleanupFuncs = append(app.cleanupFuncs, mcp.Close)
 	return nil
 }
 

internal/tui/components/dialogs/commands/loader.go 🔗

@@ -223,7 +223,7 @@ type CommandRunCustomMsg struct {
 
 func loadMCPPrompts() []Command {
 	var commands []Command
-	for key, prompt := range mcp.GetPrompts() {
+	for key, prompt := range mcp.Prompts() {
 		clientName, promptName, ok := strings.Cut(key, ":")
 		if !ok {
 			slog.Warn("prompt not found", "key", key)

internal/tui/tui.go 🔗

@@ -10,6 +10,7 @@ import (
 
 	"github.com/charmbracelet/bubbles/v2/key"
 	tea "github.com/charmbracelet/bubbletea/v2"
+	"github.com/charmbracelet/crush/internal/agent/tools/mcp"
 	"github.com/charmbracelet/crush/internal/app"
 	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/event"
@@ -140,6 +141,14 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 		a.completions.Update(msg)
 		return a, a.handleWindowResize(msg.Width, msg.Height)
 
+	case pubsub.Event[mcp.Event]:
+		switch msg.Payload.Type {
+		case mcp.EventPromptsListChanged:
+
+		case mcp.EventToolsListChanged:
+			return a, a.handleMCPToolsEvent(context.Background(), msg.Payload.Name)
+		}
+
 	// Completions messages
 	case completions.OpenCompletionsMsg, completions.FilterCompletionsMsg,
 		completions.CloseCompletionsMsg, completions.RepositionCompletionsMsg:
@@ -618,6 +627,20 @@ func (a *appModel) View() tea.View {
 	return view
 }
 
+func (a *appModel) handleMCPPromptsEvent(ctx context.Context, name string) tea.Cmd {
+	return func() tea.Msg {
+		mcp.RefreshPrompts(ctx, name)
+		return nil
+	}
+}
+
+func (a *appModel) handleMCPToolsEvent(ctx context.Context, name string) tea.Cmd {
+	return func() tea.Msg {
+		mcp.RefreshTools(ctx, name)
+		return nil
+	}
+}
+
 // New creates and initializes a new TUI application model.
 func New(app *app.App) *appModel {
 	chatPage := chat.New(app)