diff --git a/internal/agent/tools/mcp-tools.go b/internal/agent/tools/mcp-tools.go index 22afb417ef5c4fb2b01046ec4bf3fe90826d371e..ff3be182e87cb2bff82be3041b62aa4f126c0be2 100644 --- a/internal/agent/tools/mcp-tools.go +++ b/internal/agent/tools/mcp-tools.go @@ -12,7 +12,7 @@ import ( // GetMCPTools gets all the currently available MCP tools. func GetMCPTools(permissions permission.Service, wd string) []*Tool { var result []*Tool - for name, tool := range mcp.GetMCPTools() { + for name, tool := range mcp.Tools() { result = append(result, &Tool{ mcpName: name, tool: tool, diff --git a/internal/agent/tools/mcp/init.go b/internal/agent/tools/mcp/init.go index 10cfdc87d3bfb78b7d723a13b0182540bd8bc50f..961ce58071cb90808db52654629fccd7b1030b8c 100644 --- a/internal/agent/tools/mcp/init.go +++ b/internal/agent/tools/mcp/init.go @@ -281,6 +281,9 @@ func createSession(ctx context.Context, name string, m config.MCPConfig, resolve Name: name, }) }, + LoggingMessageHandler: func(_ context.Context, req *mcp.LoggingMessageRequest) { + slog.Info("mcp log", "name", name, "data", req.Params.Data) + }, KeepAlive: time.Minute * 10, }, ) diff --git a/internal/agent/tools/mcp/prompts.go b/internal/agent/tools/mcp/prompts.go index 4439a6d215efb8d8645dcc0e197a05c822672829..e168b71f2137963c4c7e428e21edd11d10ac6a9d 100644 --- a/internal/agent/tools/mcp/prompts.go +++ b/internal/agent/tools/mcp/prompts.go @@ -3,6 +3,7 @@ package mcp import ( "context" "iter" + "log/slog" "github.com/charmbracelet/crush/internal/csync" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -11,25 +12,15 @@ import ( type Prompt = mcp.Prompt var ( - allPrompts = csync.NewMap[string, *Prompt]() - client2Prompts = csync.NewMap[string, []*Prompt]() + allPrompts = csync.NewMap[string, *Prompt]() + clientPrompts = csync.NewMap[string, []*Prompt]() ) -// GetPrompts returns all available MCP prompts. -func GetPrompts() iter.Seq2[string, *Prompt] { +// Prompts returns all available MCP prompts. +func Prompts() iter.Seq2[string, *Prompt] { return allPrompts.Seq2() } -// GetPrompt returns a specific MCP prompt by name. -func GetPrompt(name string) (*Prompt, bool) { - return allPrompts.Get(name) -} - -// GetPromptsByClient returns all prompts for a specific MCP client. -func GetPromptsByClient(clientName string) ([]*Prompt, bool) { - return client2Prompts.Get(clientName) -} - // GetPromptMessages retrieves the content of an MCP prompt with the given arguments. func GetPromptMessages(ctx context.Context, clientName, promptName string, args map[string]string) ([]string, error) { c, err := getOrRenewClient(ctx, clientName) @@ -56,6 +47,30 @@ func GetPromptMessages(ctx context.Context, clientName, promptName string, args return messages, nil } +// RefreshPrompts gets the updated list of prompts from the MCP and updates the +// global state. +func RefreshPrompts(ctx context.Context, name string) { + session, ok := sessions.Get(name) + if !ok { + slog.Warn("refresh prompts: no session", "name", name) + return + } + + prompts, err := getPrompts(ctx, session) + if err != nil { + updateState(name, StateError, err, nil, Counts{}) + return + } + + updatePrompts(name, prompts) + + prev, _ := states.Get(name) + updateState(name, StateConnected, nil, session, Counts{ + Prompts: len(prompts), + Tools: prev.Counts.Tools, + }) +} + func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*Prompt, error) { if c.InitializeResult().Capabilities.Prompts == nil { return nil, nil @@ -70,11 +85,11 @@ func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*Prompt, error) { // updatePrompts updates the global mcpPrompts and mcpClient2Prompts maps func updatePrompts(mcpName string, prompts []*Prompt) { if len(prompts) == 0 { - client2Prompts.Del(mcpName) + clientPrompts.Del(mcpName) } else { - client2Prompts.Set(mcpName, prompts) + clientPrompts.Set(mcpName, prompts) } - for mcpName, prompts := range client2Prompts.Seq2() { + for mcpName, prompts := range clientPrompts.Seq2() { for _, p := range prompts { key := mcpName + ":" + p.Name allPrompts.Set(key, p) diff --git a/internal/agent/tools/mcp/tools.go b/internal/agent/tools/mcp/tools.go index a30b9de57b0e5ac35c41a666c105522121164f79..e39e999352373f7c2a140600a84c7ed70571e1b6 100644 --- a/internal/agent/tools/mcp/tools.go +++ b/internal/agent/tools/mcp/tools.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "iter" + "log/slog" "strings" "github.com/charmbracelet/crush/internal/csync" @@ -14,12 +15,12 @@ import ( type Tool = mcp.Tool var ( - allTools = csync.NewMap[string, *Tool]() - client2Tools = csync.NewMap[string, []*Tool]() + allTools = csync.NewMap[string, *Tool]() + clientTools = csync.NewMap[string, []*Tool]() ) -// GetTools returns all available MCP tools. -func GetTools() iter.Seq2[string, *Tool] { +// Tools returns all available MCP tools. +func Tools() iter.Seq2[string, *Tool] { return allTools.Seq2() } @@ -53,6 +54,30 @@ func RunTool(ctx context.Context, name, toolName string, input string) (string, return strings.Join(output, "\n"), nil } +// RefreshTools gets the updated list of tools from the MCP and updates the +// global state. +func RefreshTools(ctx context.Context, name string) { + session, ok := sessions.Get(name) + if !ok { + slog.Warn("refresh tools: no session", "name", name) + return + } + + tools, err := getTools(ctx, session) + if err != nil { + updateState(name, StateError, err, nil, Counts{}) + return + } + + updateTools(name, tools) + + prev, _ := states.Get(name) + updateState(name, StateConnected, nil, session, Counts{ + Tools: len(tools), + Prompts: prev.Counts.Prompts, + }) +} + func getTools(ctx context.Context, session *mcp.ClientSession) ([]*Tool, error) { if session.InitializeResult().Capabilities.Tools == nil { return nil, nil @@ -65,13 +90,13 @@ func getTools(ctx context.Context, session *mcp.ClientSession) ([]*Tool, error) } // updateTools updates the global mcpTools and mcpClient2Tools maps -func updateTools(mcpName string, tools []*Tool) { +func updateTools(name string, tools []*Tool) { if len(tools) == 0 { - client2Tools.Del(mcpName) + clientTools.Del(name) } else { - client2Tools.Set(mcpName, tools) + clientTools.Set(name, tools) } - for name, tools := range client2Tools.Seq2() { + for name, tools := range clientTools.Seq2() { for _, t := range tools { allTools.Set(name, t) } diff --git a/internal/app/app.go b/internal/app/app.go index fe0e2957dede4b410ab1db76e85c3bbc4bc2a49b..dc0d26a83a1d20aa8dac220a2b8451089a1a0e25 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -97,7 +97,7 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { }() // cleanup database upon app shutdown - app.cleanupFuncs = append(app.cleanupFuncs, conn.Close) + app.cleanupFuncs = append(app.cleanupFuncs, conn.Close, mcp.Close) // TODO: remove the concept of agent config, most likely. if !cfg.IsConfigured() { @@ -327,9 +327,6 @@ func (app *App) InitCoderAgent(ctx context.Context) error { slog.Error("Failed to create coder agent", "err", err) return err } - - // Add MCP client cleanup to shutdown process - app.cleanupFuncs = append(app.cleanupFuncs, mcp.Close) return nil } diff --git a/internal/tui/components/dialogs/commands/loader.go b/internal/tui/components/dialogs/commands/loader.go index ae52104b35e7614bb46bf6e30986fc90b43bb716..9e20aea7480ac58456fa22eb73057d9d6d1115e6 100644 --- a/internal/tui/components/dialogs/commands/loader.go +++ b/internal/tui/components/dialogs/commands/loader.go @@ -223,7 +223,7 @@ type CommandRunCustomMsg struct { func loadMCPPrompts() []Command { var commands []Command - for key, prompt := range mcp.GetPrompts() { + for key, prompt := range mcp.Prompts() { clientName, promptName, ok := strings.Cut(key, ":") if !ok { slog.Warn("prompt not found", "key", key) diff --git a/internal/tui/tui.go b/internal/tui/tui.go index ea0ed6649e5cc6de57597c589dc9785fb441bb72..f40b3bef6853e58482929727b7d845b89f438cba 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -10,6 +10,7 @@ import ( "github.com/charmbracelet/bubbles/v2/key" tea "github.com/charmbracelet/bubbletea/v2" + "github.com/charmbracelet/crush/internal/agent/tools/mcp" "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/event" @@ -140,6 +141,14 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.completions.Update(msg) return a, a.handleWindowResize(msg.Width, msg.Height) + case pubsub.Event[mcp.Event]: + switch msg.Payload.Type { + case mcp.EventPromptsListChanged: + + case mcp.EventToolsListChanged: + return a, a.handleMCPToolsEvent(context.Background(), msg.Payload.Name) + } + // Completions messages case completions.OpenCompletionsMsg, completions.FilterCompletionsMsg, completions.CloseCompletionsMsg, completions.RepositionCompletionsMsg: @@ -618,6 +627,20 @@ func (a *appModel) View() tea.View { return view } +func (a *appModel) handleMCPPromptsEvent(ctx context.Context, name string) tea.Cmd { + return func() tea.Msg { + mcp.RefreshPrompts(ctx, name) + return nil + } +} + +func (a *appModel) handleMCPToolsEvent(ctx context.Context, name string) tea.Cmd { + return func() tea.Msg { + mcp.RefreshTools(ctx, name) + return nil + } +} + // New creates and initializes a new TUI application model. func New(app *app.App) *appModel { chatPage := chat.New(app)