Detailed changes
@@ -12,7 +12,7 @@ import (
// GetMCPTools gets all the currently available MCP tools.
func GetMCPTools(permissions permission.Service, wd string) []*Tool {
var result []*Tool
- for name, tool := range mcp.GetMCPTools() {
+ for name, tool := range mcp.Tools() {
result = append(result, &Tool{
mcpName: name,
tool: tool,
@@ -281,6 +281,9 @@ func createSession(ctx context.Context, name string, m config.MCPConfig, resolve
Name: name,
})
},
+ LoggingMessageHandler: func(_ context.Context, req *mcp.LoggingMessageRequest) {
+ slog.Info("mcp log", "name", name, "data", req.Params.Data)
+ },
KeepAlive: time.Minute * 10,
},
)
@@ -3,6 +3,7 @@ package mcp
import (
"context"
"iter"
+ "log/slog"
"github.com/charmbracelet/crush/internal/csync"
"github.com/modelcontextprotocol/go-sdk/mcp"
@@ -11,25 +12,15 @@ import (
type Prompt = mcp.Prompt
var (
- allPrompts = csync.NewMap[string, *Prompt]()
- client2Prompts = csync.NewMap[string, []*Prompt]()
+ allPrompts = csync.NewMap[string, *Prompt]()
+ clientPrompts = csync.NewMap[string, []*Prompt]()
)
-// GetPrompts returns all available MCP prompts.
-func GetPrompts() iter.Seq2[string, *Prompt] {
+// Prompts returns all available MCP prompts.
+func Prompts() iter.Seq2[string, *Prompt] {
return allPrompts.Seq2()
}
-// GetPrompt returns a specific MCP prompt by name.
-func GetPrompt(name string) (*Prompt, bool) {
- return allPrompts.Get(name)
-}
-
-// GetPromptsByClient returns all prompts for a specific MCP client.
-func GetPromptsByClient(clientName string) ([]*Prompt, bool) {
- return client2Prompts.Get(clientName)
-}
-
// GetPromptMessages retrieves the content of an MCP prompt with the given arguments.
func GetPromptMessages(ctx context.Context, clientName, promptName string, args map[string]string) ([]string, error) {
c, err := getOrRenewClient(ctx, clientName)
@@ -56,6 +47,30 @@ func GetPromptMessages(ctx context.Context, clientName, promptName string, args
return messages, nil
}
+// RefreshPrompts gets the updated list of prompts from the MCP and updates the
+// global state.
+func RefreshPrompts(ctx context.Context, name string) {
+ session, ok := sessions.Get(name)
+ if !ok {
+ slog.Warn("refresh prompts: no session", "name", name)
+ return
+ }
+
+ prompts, err := getPrompts(ctx, session)
+ if err != nil {
+ updateState(name, StateError, err, nil, Counts{})
+ return
+ }
+
+ updatePrompts(name, prompts)
+
+ prev, _ := states.Get(name)
+ updateState(name, StateConnected, nil, session, Counts{
+ Prompts: len(prompts),
+ Tools: prev.Counts.Tools,
+ })
+}
+
func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*Prompt, error) {
if c.InitializeResult().Capabilities.Prompts == nil {
return nil, nil
@@ -70,11 +85,11 @@ func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*Prompt, error) {
// updatePrompts updates the global mcpPrompts and mcpClient2Prompts maps
func updatePrompts(mcpName string, prompts []*Prompt) {
if len(prompts) == 0 {
- client2Prompts.Del(mcpName)
+ clientPrompts.Del(mcpName)
} else {
- client2Prompts.Set(mcpName, prompts)
+ clientPrompts.Set(mcpName, prompts)
}
- for mcpName, prompts := range client2Prompts.Seq2() {
+ for mcpName, prompts := range clientPrompts.Seq2() {
for _, p := range prompts {
key := mcpName + ":" + p.Name
allPrompts.Set(key, p)
@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"iter"
+ "log/slog"
"strings"
"github.com/charmbracelet/crush/internal/csync"
@@ -14,12 +15,12 @@ import (
type Tool = mcp.Tool
var (
- allTools = csync.NewMap[string, *Tool]()
- client2Tools = csync.NewMap[string, []*Tool]()
+ allTools = csync.NewMap[string, *Tool]()
+ clientTools = csync.NewMap[string, []*Tool]()
)
-// GetTools returns all available MCP tools.
-func GetTools() iter.Seq2[string, *Tool] {
+// Tools returns all available MCP tools.
+func Tools() iter.Seq2[string, *Tool] {
return allTools.Seq2()
}
@@ -53,6 +54,30 @@ func RunTool(ctx context.Context, name, toolName string, input string) (string,
return strings.Join(output, "\n"), nil
}
+// RefreshTools gets the updated list of tools from the MCP and updates the
+// global state.
+func RefreshTools(ctx context.Context, name string) {
+ session, ok := sessions.Get(name)
+ if !ok {
+ slog.Warn("refresh tools: no session", "name", name)
+ return
+ }
+
+ tools, err := getTools(ctx, session)
+ if err != nil {
+ updateState(name, StateError, err, nil, Counts{})
+ return
+ }
+
+ updateTools(name, tools)
+
+ prev, _ := states.Get(name)
+ updateState(name, StateConnected, nil, session, Counts{
+ Tools: len(tools),
+ Prompts: prev.Counts.Prompts,
+ })
+}
+
func getTools(ctx context.Context, session *mcp.ClientSession) ([]*Tool, error) {
if session.InitializeResult().Capabilities.Tools == nil {
return nil, nil
@@ -65,13 +90,13 @@ func getTools(ctx context.Context, session *mcp.ClientSession) ([]*Tool, error)
}
// updateTools updates the global mcpTools and mcpClient2Tools maps
-func updateTools(mcpName string, tools []*Tool) {
+func updateTools(name string, tools []*Tool) {
if len(tools) == 0 {
- client2Tools.Del(mcpName)
+ clientTools.Del(name)
} else {
- client2Tools.Set(mcpName, tools)
+ clientTools.Set(name, tools)
}
- for name, tools := range client2Tools.Seq2() {
+ for name, tools := range clientTools.Seq2() {
for _, t := range tools {
allTools.Set(name, t)
}
@@ -97,7 +97,7 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
}()
// cleanup database upon app shutdown
- app.cleanupFuncs = append(app.cleanupFuncs, conn.Close)
+ app.cleanupFuncs = append(app.cleanupFuncs, conn.Close, mcp.Close)
// TODO: remove the concept of agent config, most likely.
if !cfg.IsConfigured() {
@@ -327,9 +327,6 @@ func (app *App) InitCoderAgent(ctx context.Context) error {
slog.Error("Failed to create coder agent", "err", err)
return err
}
-
- // Add MCP client cleanup to shutdown process
- app.cleanupFuncs = append(app.cleanupFuncs, mcp.Close)
return nil
}
@@ -223,7 +223,7 @@ type CommandRunCustomMsg struct {
func loadMCPPrompts() []Command {
var commands []Command
- for key, prompt := range mcp.GetPrompts() {
+ for key, prompt := range mcp.Prompts() {
clientName, promptName, ok := strings.Cut(key, ":")
if !ok {
slog.Warn("prompt not found", "key", key)
@@ -10,6 +10,7 @@ import (
"github.com/charmbracelet/bubbles/v2/key"
tea "github.com/charmbracelet/bubbletea/v2"
+ "github.com/charmbracelet/crush/internal/agent/tools/mcp"
"github.com/charmbracelet/crush/internal/app"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/event"
@@ -140,6 +141,14 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
a.completions.Update(msg)
return a, a.handleWindowResize(msg.Width, msg.Height)
+ case pubsub.Event[mcp.Event]:
+ switch msg.Payload.Type {
+ case mcp.EventPromptsListChanged:
+
+ case mcp.EventToolsListChanged:
+ return a, a.handleMCPToolsEvent(context.Background(), msg.Payload.Name)
+ }
+
// Completions messages
case completions.OpenCompletionsMsg, completions.FilterCompletionsMsg,
completions.CloseCompletionsMsg, completions.RepositionCompletionsMsg:
@@ -618,6 +627,20 @@ func (a *appModel) View() tea.View {
return view
}
+func (a *appModel) handleMCPPromptsEvent(ctx context.Context, name string) tea.Cmd {
+ return func() tea.Msg {
+ mcp.RefreshPrompts(ctx, name)
+ return nil
+ }
+}
+
+func (a *appModel) handleMCPToolsEvent(ctx context.Context, name string) tea.Cmd {
+ return func() tea.Msg {
+ mcp.RefreshTools(ctx, name)
+ return nil
+ }
+}
+
// New creates and initializes a new TUI application model.
func New(app *app.App) *appModel {
chatPage := chat.New(app)