diff --git a/Taskfile.yaml b/Taskfile.yaml index 13c171ed2e67faa9aa87c6f9f7d0ec3b7018f382..1c4225158fc21508e8dccac8d6f47610f7d81faf 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -38,7 +38,7 @@ tasks: run: desc: Run build cmds: - - go run . + - go run . {{.CLI_ARGS}} test: desc: Run tests @@ -104,6 +104,6 @@ tasks: - git push origin --tags fetch-tags: - cmds: + cmds: - git tag -d nightly || true - git fetch --tags diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 8bb5df5ab51a585843734ca9cf750428513288d7..d58e5a4822dc6ef97b58ab68a9dfde320bf6ad10 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -1102,24 +1102,32 @@ func (a *agent) setupEvents(ctx context.Context) { return } switch event.Payload.Type { - case MCPEventToolsListChanged: + case MCPEventToolsListChanged, MCPEventPromptsListChanged: name := event.Payload.Name c, ok := mcpClients.Get(name) if !ok { - slog.Warn("MCP client not found for tools update", "name", name) + slog.Warn("MCP client not found for tools/prompts update", "name", name) continue } cfg := config.Get() tools, err := getTools(ctx, name, a.permissions, c, cfg.WorkingDir()) if err != nil { slog.Error("error listing tools", "error", err) - updateMCPState(name, MCPStateError, err, nil, 0) + updateMCPState(name, MCPStateError, err, nil, 0, 0) + _ = c.Close() + continue + } + prompts, err := getPrompts(ctx, c) + if err != nil { + slog.Error("error listing prompts", "error", err) + updateMCPState(name, MCPStateError, err, nil, 0, 0) _ = c.Close() continue } updateMcpTools(name, tools) + updateMcpPrompts(name, prompts) a.mcpTools.Reset(maps.Collect(mcpTools.Seq2())) - updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len()) + updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len(), len(prompts)) default: continue } diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index ae81a306b7981713b9faefc6cde860b640a2b5cf..aedf193b39624175cf99893636618d7c4b6cfb8b 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -53,17 +53,19 @@ func (s MCPState) String() string { type MCPEventType string const ( - MCPEventStateChanged MCPEventType = "state_changed" - MCPEventToolsListChanged MCPEventType = "tools_list_changed" + MCPEventStateChanged MCPEventType = "state_changed" + MCPEventToolsListChanged MCPEventType = "tools_list_changed" + MCPEventPromptsListChanged MCPEventType = "prompts_list_changed" ) // MCPEvent represents an event in the MCP system type MCPEvent struct { - Type MCPEventType - Name string - State MCPState - Error error - ToolCount int + Type MCPEventType + Name string + State MCPState + Error error + ToolCount int + PromptCount int } // MCPClientInfo holds information about an MCP client's state @@ -73,16 +75,19 @@ type MCPClientInfo struct { Error error Client *mcp.ClientSession ToolCount int + PromptCount int ConnectedAt time.Time } var ( - mcpToolsOnce sync.Once - mcpTools = csync.NewMap[string, tools.BaseTool]() - mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]() - mcpClients = csync.NewMap[string, *mcp.ClientSession]() - mcpStates = csync.NewMap[string, MCPClientInfo]() - mcpBroker = pubsub.NewBroker[MCPEvent]() + mcpToolsOnce sync.Once + mcpTools = csync.NewMap[string, tools.BaseTool]() + mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]() + mcpClients = csync.NewMap[string, *mcp.ClientSession]() + mcpStates = csync.NewMap[string, MCPClientInfo]() + mcpBroker = pubsub.NewBroker[MCPEvent]() + mcpPrompts = csync.NewMap[string, *mcp.Prompt]() + mcpClient2Prompts = csync.NewMap[string, []*mcp.Prompt]() ) type McpTool struct { @@ -154,14 +159,14 @@ func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, err if err == nil { return sess, nil } - updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount) + updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount, state.PromptCount) sess, err = createMCPSession(ctx, name, m, cfg.Resolver()) if err != nil { return nil, err } - updateMCPState(name, MCPStateConnected, nil, sess, state.ToolCount) + updateMCPState(name, MCPStateConnected, nil, sess, state.ToolCount, state.PromptCount) mcpClients.Set(name, sess) return sess, nil } @@ -191,6 +196,9 @@ func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes } func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]tools.BaseTool, error) { + if c.InitializeResult().Capabilities.Tools == nil { + return nil, nil + } result, err := c.ListTools(ctx, &mcp.ListToolsParams{}) if err != nil { return nil, err @@ -223,30 +231,33 @@ func GetMCPState(name string) (MCPClientInfo, bool) { } // updateMCPState updates the state of an MCP client and publishes an event -func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, toolCount int) { +func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, toolCount, promptCount int) { info := MCPClientInfo{ - Name: name, - State: state, - Error: err, - Client: client, - ToolCount: toolCount, + Name: name, + State: state, + Error: err, + Client: client, + ToolCount: toolCount, + PromptCount: promptCount, } switch state { case MCPStateConnected: info.ConnectedAt = time.Now() case MCPStateError: updateMcpTools(name, nil) + updateMcpPrompts(name, nil) mcpClients.Del(name) } mcpStates.Set(name, info) // Publish state change event mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{ - Type: MCPEventStateChanged, - Name: name, - State: state, - Error: err, - ToolCount: toolCount, + Type: MCPEventStateChanged, + Name: name, + State: state, + Error: err, + ToolCount: toolCount, + PromptCount: promptCount, }) } @@ -267,13 +278,13 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con // Initialize states for all configured MCPs for name, m := range cfg.MCP { if m.Disabled { - updateMCPState(name, MCPStateDisabled, nil, nil, 0) + updateMCPState(name, MCPStateDisabled, nil, nil, 0, 0) slog.Debug("skipping disabled mcp", "name", name) continue } // Set initial starting state - updateMCPState(name, MCPStateStarting, nil, nil, 0) + updateMCPState(name, MCPStateStarting, nil, nil, 0, 0) wg.Add(1) go func(name string, m config.MCPConfig) { @@ -289,7 +300,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con default: err = fmt.Errorf("panic: %v", v) } - updateMCPState(name, MCPStateError, err, nil, 0) + updateMCPState(name, MCPStateError, err, nil, 0, 0) slog.Error("panic in mcp client initialization", "error", err, "name", name) } }() @@ -307,14 +318,23 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir()) if err != nil { slog.Error("error listing tools", "error", err) - updateMCPState(name, MCPStateError, err, nil, 0) + updateMCPState(name, MCPStateError, err, nil, 0, 0) + c.Close() + return + } + + prompts, err := getPrompts(ctx, c) + if err != nil { + slog.Error("error listing prompts", "error", err) + updateMCPState(name, MCPStateError, err, nil, 0, 0) c.Close() return } updateMcpTools(name, tools) + updateMcpPrompts(name, prompts) mcpClients.Set(name, c) - updateMCPState(name, MCPStateConnected, nil, c, len(tools)) + updateMCPState(name, MCPStateConnected, nil, c, len(tools), len(prompts)) }(name, m) } wg.Wait() @@ -337,7 +357,7 @@ func updateMcpTools(mcpName string, tools []tools.BaseTool) { func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) { transport, err := createMCPTransport(m, resolver) if err != nil { - updateMCPState(name, MCPStateError, err, nil, 0) + updateMCPState(name, MCPStateError, err, nil, 0, 0) slog.Error("error creating mcp client", "error", err, "name", name) return nil, err } @@ -355,6 +375,12 @@ func createMCPSession(ctx context.Context, name string, m config.MCPConfig, reso Name: name, }) }, + PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) { + mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{ + Type: MCPEventPromptsListChanged, + Name: name, + }) + }, KeepAlive: time.Minute * 10, }, ) @@ -365,7 +391,7 @@ func createMCPSession(ctx context.Context, name string, m config.MCPConfig, reso session, err := client.Connect(mcpCtx, transport, nil) if err != nil { - updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0) + updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0, 0) slog.Error("error starting mcp client", "error", err, "name", name) _ = session.Close() cancel() @@ -444,3 +470,57 @@ func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error func mcpTimeout(m config.MCPConfig) time.Duration { return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second } + +func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*mcp.Prompt, error) { + if c.InitializeResult().Capabilities.Prompts == nil { + return nil, nil + } + result, err := c.ListPrompts(ctx, &mcp.ListPromptsParams{}) + if err != nil { + return nil, err + } + return result.Prompts, nil +} + +// updateMcpPrompts updates the global mcpPrompts and mcpClient2Prompts maps. +func updateMcpPrompts(mcpName string, prompts []*mcp.Prompt) { + if len(prompts) == 0 { + mcpClient2Prompts.Del(mcpName) + } else { + mcpClient2Prompts.Set(mcpName, prompts) + } + for clientName, prompts := range mcpClient2Prompts.Seq2() { + for _, p := range prompts { + key := clientName + ":" + p.Name + mcpPrompts.Set(key, p) + } + } +} + +// GetMCPPrompts returns all available MCP prompts. +func GetMCPPrompts() map[string]*mcp.Prompt { + return maps.Collect(mcpPrompts.Seq2()) +} + +// GetMCPPrompt returns a specific MCP prompt by name. +func GetMCPPrompt(name string) (*mcp.Prompt, bool) { + return mcpPrompts.Get(name) +} + +// GetMCPPromptsByClient returns all prompts for a specific MCP client. +func GetMCPPromptsByClient(clientName string) ([]*mcp.Prompt, bool) { + return mcpClient2Prompts.Get(clientName) +} + +// GetMCPPromptContent retrieves the content of an MCP prompt with the given arguments. +func GetMCPPromptContent(ctx context.Context, clientName, promptName string, args map[string]string) (*mcp.GetPromptResult, error) { + c, err := getOrRenewClient(ctx, clientName) + if err != nil { + return nil, err + } + + return c.GetPrompt(ctx, &mcp.GetPromptParams{ + Name: promptName, + Arguments: args, + }) +} diff --git a/internal/tui/components/core/core.go b/internal/tui/components/core/core.go index 18de56b17f08e4513bde34fe9fef7aaf4e08c09f..80c28ba1e11c4ddeb7e6da1f4802577d23e8b4dc 100644 --- a/internal/tui/components/core/core.go +++ b/internal/tui/components/core/core.go @@ -110,14 +110,17 @@ func Status(opts StatusOpts, width int) string { extraContentWidth += 1 } description = ansi.Truncate(description, width-lipgloss.Width(icon)-lipgloss.Width(title)-2-extraContentWidth, "…") + description = t.S().Base.Foreground(descriptionColor).Render(description) } - description = t.S().Base.Foreground(descriptionColor).Render(description) content := []string{} if icon != "" { content = append(content, icon) } - content = append(content, title, description) + content = append(content, title) + if description != "" { + content = append(content, description) + } if opts.ExtraContent != "" { content = append(content, opts.ExtraContent) } diff --git a/internal/tui/components/dialogs/commands/commands.go b/internal/tui/components/dialogs/commands/commands.go index 664158fc392a87d8a7725bfa964748f7ef4f8e67..5e9d1fa1b2e47b8284bbb65cd0a48d3230049a08 100644 --- a/internal/tui/components/dialogs/commands/commands.go +++ b/internal/tui/components/dialogs/commands/commands.go @@ -1,7 +1,9 @@ package commands import ( + "context" "os" + "strings" "github.com/charmbracelet/bubbles/v2/help" "github.com/charmbracelet/bubbles/v2/key" @@ -10,7 +12,9 @@ import ( "github.com/charmbracelet/lipgloss/v2" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/llm/agent" "github.com/charmbracelet/crush/internal/llm/prompt" + "github.com/charmbracelet/crush/internal/pubsub" "github.com/charmbracelet/crush/internal/tui/components/chat" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/components/dialogs" @@ -28,6 +32,7 @@ const ( const ( SystemCommands int = iota UserCommands + MCPPrompts ) type listModel = list.FilterableList[list.CompletionItem[Command]] @@ -54,9 +59,12 @@ type commandDialogCmp struct { commandList listModel keyMap CommandsDialogKeyMap help help.Model - commandType int // SystemCommands or UserCommands + commandType int // SystemCommands, UserCommands, or MCPPrompts userCommands []Command // User-defined commands + mcpPrompts []Command // MCP prompts sessionID string // Current session ID + ctx context.Context + cancel context.CancelFunc } type ( @@ -113,7 +121,26 @@ func (c *commandDialogCmp) Init() tea.Cmd { return util.ReportError(err) } c.userCommands = commands - return c.SetCommandType(c.commandType) + c.mcpPrompts = LoadMCPPrompts() + + // Subscribe to MCP events + c.ctx, c.cancel = context.WithCancel(context.Background()) + return tea.Batch( + c.SetCommandType(c.commandType), + c.subscribeMCPEvents(), + ) +} + +func (c *commandDialogCmp) subscribeMCPEvents() tea.Cmd { + return func() tea.Msg { + ch := agent.SubscribeMCPEvents(c.ctx) + for event := range ch { + if event.Type == pubsub.UpdatedEvent { + return event + } + } + return nil + } } func (c *commandDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { @@ -125,6 +152,19 @@ func (c *commandDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { c.SetCommandType(c.commandType), c.commandList.SetSize(c.listWidth(), c.listHeight()), ) + case pubsub.Event[agent.MCPEvent]: + // Reload MCP prompts when MCP state changes + if msg.Type == pubsub.UpdatedEvent { + c.mcpPrompts = LoadMCPPrompts() + // If we're currently viewing MCP prompts, refresh the list + if c.commandType == MCPPrompts { + return c, tea.Batch( + c.SetCommandType(MCPPrompts), + c.subscribeMCPEvents(), + ) + } + return c, c.subscribeMCPEvents() + } case tea.KeyPressMsg: switch { case key.Matches(msg, c.keyMap.Select): @@ -133,21 +173,38 @@ func (c *commandDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return c, nil // No item selected, do nothing } command := (*selectedItem).Value() + if c.cancel != nil { + c.cancel() + } return c, tea.Sequence( util.CmdHandler(dialogs.CloseDialogMsg{}), command.Handler(command), ) case key.Matches(msg, c.keyMap.Tab): - if len(c.userCommands) == 0 { + if len(c.userCommands) == 0 && len(c.mcpPrompts) == 0 { return c, nil } - // Toggle command type between System and User commands - if c.commandType == SystemCommands { - return c, c.SetCommandType(UserCommands) - } else { - return c, c.SetCommandType(SystemCommands) + // Cycle through command types: System -> User -> MCP -> System + nextType := (c.commandType + 1) % 3 + // Skip empty types + for { + if nextType == UserCommands && len(c.userCommands) == 0 { + nextType = (nextType + 1) % 3 + } else if nextType == MCPPrompts && len(c.mcpPrompts) == 0 { + nextType = (nextType + 1) % 3 + } else { + break + } + // Prevent infinite loop + if nextType == c.commandType { + return c, nil + } } + return c, c.SetCommandType(nextType) case key.Matches(msg, c.keyMap.Close): + if c.cancel != nil { + c.cancel() + } return c, util.CmdHandler(dialogs.CloseDialogMsg{}) default: u, cmd := c.commandList.Update(msg) @@ -164,7 +221,7 @@ func (c *commandDialogCmp) View() string { radio := c.commandTypeRadio() header := t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Commands", c.width-lipgloss.Width(radio)-5) + " " + radio) - if len(c.userCommands) == 0 { + if len(c.userCommands) == 0 && len(c.mcpPrompts) == 0 { header = t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Commands", c.width-4)) } content := lipgloss.JoinVertical( @@ -190,13 +247,25 @@ func (c *commandDialogCmp) Cursor() *tea.Cursor { func (c *commandDialogCmp) commandTypeRadio() string { t := styles.CurrentTheme() - choices := []string{"System", "User"} + choices := []string{"System", "User", "MCP"} iconSelected := "◉" iconUnselected := "○" - if c.commandType == SystemCommands { - return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1]) + + icons := make([]string, 3) + for i := range icons { + if i == c.commandType { + icons[i] = iconSelected + } else { + icons[i] = iconUnselected + } + } + + parts := make([]string, 0, 6) + for i, choice := range choices { + parts = append(parts, icons[i]+" "+choice) } - return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1]) + + return t.S().Base.Foreground(t.FgHalfMuted).Render(strings.Join(parts, " ")) } func (c *commandDialogCmp) listWidth() int { @@ -207,10 +276,13 @@ func (c *commandDialogCmp) SetCommandType(commandType int) tea.Cmd { c.commandType = commandType var commands []Command - if c.commandType == SystemCommands { + switch c.commandType { + case SystemCommands: commands = c.defaultCommands() - } else { + case UserCommands: commands = c.userCommands + case MCPPrompts: + commands = c.mcpPrompts } commandItems := []list.CompletionItem[Command]{} diff --git a/internal/tui/components/dialogs/commands/keys.go b/internal/tui/components/dialogs/commands/keys.go index 7b79a29c28a024154a3b4d8c763969585409fd00..32d34d3c26360f040ea8ecd7e9c1fa3c4d2d7a5a 100644 --- a/internal/tui/components/dialogs/commands/keys.go +++ b/internal/tui/components/dialogs/commands/keys.go @@ -76,6 +76,7 @@ type ArgumentsDialogKeyMap struct { Confirm key.Binding Next key.Binding Previous key.Binding + Cancel key.Binding } func DefaultArgumentsDialogKeyMap() ArgumentsDialogKeyMap { @@ -93,6 +94,10 @@ func DefaultArgumentsDialogKeyMap() ArgumentsDialogKeyMap { key.WithKeys("shift+tab", "up"), key.WithHelp("shift+tab/↑", "previous"), ), + Cancel: key.NewBinding( + key.WithKeys("esc"), + key.WithHelp("esc", "cancel"), + ), } } diff --git a/internal/tui/components/dialogs/commands/loader.go b/internal/tui/components/dialogs/commands/loader.go index 74d9c7e4baee2e2d19f8baca914942f0c0d34cd3..df5caa5e2350ec1ab9c7177fb7da83f7661acb66 100644 --- a/internal/tui/components/dialogs/commands/loader.go +++ b/internal/tui/components/dialogs/commands/loader.go @@ -1,6 +1,7 @@ package commands import ( + "context" "fmt" "io/fs" "os" @@ -9,14 +10,19 @@ import ( "strings" tea "github.com/charmbracelet/bubbletea/v2" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/home" + "github.com/charmbracelet/crush/internal/llm/agent" + "github.com/charmbracelet/crush/internal/tui/components/chat" "github.com/charmbracelet/crush/internal/tui/util" ) const ( UserCommandPrefix = "user:" ProjectCommandPrefix = "project:" + MCPPromptPrefix = "mcp:" ) var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`) @@ -201,3 +207,80 @@ func isMarkdownFile(name string) bool { type CommandRunCustomMsg struct { Content string } + +func LoadMCPPrompts() []Command { + prompts := agent.GetMCPPrompts() + commands := make([]Command, 0, len(prompts)) + + for key, prompt := range prompts { + p := prompt + // key format is "clientName:promptName" + parts := strings.SplitN(key, ":", 2) + if len(parts) != 2 { + continue + } + clientName, promptName := parts[0], parts[1] + + displayName := promptName + if p.Title != "" { + displayName = p.Title + } + + commands = append(commands, Command{ + ID: MCPPromptPrefix + key, + Title: displayName, + Description: fmt.Sprintf("[%s] %s", clientName, p.Description), + Handler: createMCPPromptHandler(key, promptName, p), + }) + } + + return commands +} + +func createMCPPromptHandler(key, promptName string, prompt *mcp.Prompt) func(Command) tea.Cmd { + return func(cmd Command) tea.Cmd { + if len(prompt.Arguments) == 0 { + return executeMCPPromptWithoutArgs(key, promptName) + } + return util.CmdHandler(ShowMCPPromptArgumentsDialogMsg{ + PromptID: cmd.ID, + PromptName: promptName, + }) + } +} + +func executeMCPPromptWithoutArgs(key, promptName string) tea.Cmd { + return func() tea.Msg { + // key format is "clientName:promptName" + parts := strings.SplitN(key, ":", 2) + if len(parts) != 2 { + return util.ReportError(fmt.Errorf("invalid prompt key: %s", key)) + } + clientName := parts[0] + + ctx := context.Background() + result, err := agent.GetMCPPromptContent(ctx, clientName, promptName, nil) + if err != nil { + return util.ReportError(err) + } + + var content strings.Builder + for _, msg := range result.Messages { + if msg.Role == "user" { + if textContent, ok := msg.Content.(*mcp.TextContent); ok { + content.WriteString(textContent.Text) + content.WriteString("\n") + } + } + } + + return chat.SendMsg{ + Text: content.String(), + } + } +} + +type ShowMCPPromptArgumentsDialogMsg struct { + PromptID string + PromptName string +} diff --git a/internal/tui/components/dialogs/commands/mcp_arguments.go b/internal/tui/components/dialogs/commands/mcp_arguments.go new file mode 100644 index 0000000000000000000000000000000000000000..4fd042f38059e57d7c51fa26038f4dcb343ec2ac --- /dev/null +++ b/internal/tui/components/dialogs/commands/mcp_arguments.go @@ -0,0 +1,262 @@ +package commands + +import ( + "cmp" + "context" + "fmt" + "log/slog" + "strings" + + "github.com/charmbracelet/bubbles/v2/help" + "github.com/charmbracelet/bubbles/v2/key" + "github.com/charmbracelet/bubbles/v2/textinput" + tea "github.com/charmbracelet/bubbletea/v2" + "github.com/charmbracelet/lipgloss/v2" + "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/charmbracelet/crush/internal/llm/agent" + "github.com/charmbracelet/crush/internal/tui/components/chat" + "github.com/charmbracelet/crush/internal/tui/components/dialogs" + "github.com/charmbracelet/crush/internal/tui/styles" + "github.com/charmbracelet/crush/internal/tui/util" +) + +const mcpArgumentsDialogID dialogs.DialogID = "mcp_arguments" + +type MCPPromptArgumentsDialog interface { + dialogs.DialogModel +} + +type mcpPromptArgumentsDialogCmp struct { + wWidth, wHeight int + width, height int + selected int + inputs []textinput.Model + keys ArgumentsDialogKeyMap + id string + prompt *mcp.Prompt + help help.Model +} + +func NewMCPPromptArgumentsDialog(id, name string) MCPPromptArgumentsDialog { + id = strings.TrimPrefix(id, MCPPromptPrefix) + prompt, ok := agent.GetMCPPrompt(id) + if !ok { + return nil + } + + t := styles.CurrentTheme() + inputs := make([]textinput.Model, len(prompt.Arguments)) + + for i, arg := range prompt.Arguments { + ti := textinput.New() + placeholder := fmt.Sprintf("Enter value for %s...", arg.Name) + if arg.Description != "" { + placeholder = arg.Description + } + ti.Placeholder = placeholder + ti.SetWidth(40) + ti.SetVirtualCursor(false) + ti.Prompt = "" + ti.SetStyles(t.S().TextInput) + + if i == 0 { + ti.Focus() + } else { + ti.Blur() + } + + inputs[i] = ti + } + + return &mcpPromptArgumentsDialogCmp{ + inputs: inputs, + keys: DefaultArgumentsDialogKeyMap(), + id: id, + prompt: prompt, + help: help.New(), + } +} + +func (c *mcpPromptArgumentsDialogCmp) Init() tea.Cmd { + return nil +} + +func (c *mcpPromptArgumentsDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + c.wWidth = msg.Width + c.wHeight = msg.Height + cmd := c.SetSize() + return c, cmd + case tea.KeyPressMsg: + switch { + case key.Matches(msg, c.keys.Cancel): + return c, util.CmdHandler(dialogs.CloseDialogMsg{}) + case key.Matches(msg, c.keys.Confirm): + if c.selected == len(c.inputs)-1 { + args := make(map[string]string) + for i, arg := range c.prompt.Arguments { + value := c.inputs[i].Value() + args[arg.Name] = value + } + return c, tea.Sequence( + util.CmdHandler(dialogs.CloseDialogMsg{}), + c.executeMCPPrompt(args), + ) + } + c.inputs[c.selected].Blur() + c.selected++ + c.inputs[c.selected].Focus() + case key.Matches(msg, c.keys.Next): + c.inputs[c.selected].Blur() + c.selected = (c.selected + 1) % len(c.inputs) + c.inputs[c.selected].Focus() + case key.Matches(msg, c.keys.Previous): + c.inputs[c.selected].Blur() + c.selected = (c.selected - 1 + len(c.inputs)) % len(c.inputs) + c.inputs[c.selected].Focus() + default: + var cmd tea.Cmd + c.inputs[c.selected], cmd = c.inputs[c.selected].Update(msg) + return c, cmd + } + } + return c, nil +} + +func (c *mcpPromptArgumentsDialogCmp) executeMCPPrompt(args map[string]string) tea.Cmd { + return func() tea.Msg { + parts := strings.SplitN(c.id, ":", 2) + if len(parts) != 2 { + return util.ReportError(fmt.Errorf("invalid prompt ID: %s", c.id)) + } + clientName := parts[0] + + ctx := context.Background() + slog.Warn("AQUI", "name", c.prompt.Name, "id", c.id) + result, err := agent.GetMCPPromptContent(ctx, clientName, c.prompt.Name, args) + if err != nil { + return util.ReportError(err) + } + + var content strings.Builder + for _, msg := range result.Messages { + if msg.Role == "user" { + if textContent, ok := msg.Content.(*mcp.TextContent); ok { + content.WriteString(textContent.Text) + content.WriteString("\n") + } + } + } + + return chat.SendMsg{ + Text: content.String(), + } + } +} + +func (c *mcpPromptArgumentsDialogCmp) View() string { + t := styles.CurrentTheme() + baseStyle := t.S().Base + + title := lipgloss.NewStyle(). + Foreground(t.Primary). + Bold(true). + Padding(0, 1). + Render(cmp.Or(c.prompt.Title, c.prompt.Name)) + + promptName := t.S().Text. + Padding(0, 1). + Render(c.prompt.Description) + + if c.prompt == nil { + return baseStyle.Padding(1, 1, 0, 1). + Border(lipgloss.RoundedBorder()). + BorderForeground(t.BorderFocus). + Width(c.width). + Render(lipgloss.JoinVertical(lipgloss.Left, title, promptName, "", "Prompt not found")) + } + + inputFields := make([]string, len(c.inputs)) + for i, input := range c.inputs { + labelStyle := baseStyle.Padding(1, 1, 0, 1) + + if i == c.selected { + labelStyle = labelStyle.Foreground(t.FgBase).Bold(true) + } else { + labelStyle = labelStyle.Foreground(t.FgMuted) + } + + argName := c.prompt.Arguments[i].Name + if c.prompt.Arguments[i].Required { + argName += " *" + } + label := labelStyle.Render(argName + ":") + + field := t.S().Text. + Padding(0, 1). + Render(input.View()) + + inputFields[i] = lipgloss.JoinVertical(lipgloss.Left, label, field) + } + + elements := []string{title, promptName} + elements = append(elements, inputFields...) + + c.help.ShowAll = false + helpText := baseStyle.Padding(0, 1).Render(c.help.View(c.keys)) + elements = append(elements, "", helpText) + + content := lipgloss.JoinVertical(lipgloss.Left, elements...) + + return baseStyle.Padding(1, 1, 0, 1). + Border(lipgloss.RoundedBorder()). + BorderForeground(t.BorderFocus). + Width(c.width). + Render(content) +} + +func (c *mcpPromptArgumentsDialogCmp) Cursor() *tea.Cursor { + if len(c.inputs) == 0 { + return nil + } + cursor := c.inputs[c.selected].Cursor() + if cursor != nil { + cursor = c.moveCursor(cursor) + } + return cursor +} + +const ( + headerHeight = 3 + itemHeight = 3 + paddingHorizontal = 3 +) + +func (c *mcpPromptArgumentsDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor { + row, col := c.Position() + offset := row + headerHeight + (1+c.selected)*itemHeight + cursor.Y += offset + cursor.X = cursor.X + col + paddingHorizontal + return cursor +} + +func (c *mcpPromptArgumentsDialogCmp) SetSize() tea.Cmd { + c.width = min(90, c.wWidth) + c.height = min(15, c.wHeight) + for i := range c.inputs { + c.inputs[i].SetWidth(c.width - (paddingHorizontal * 2)) + } + return nil +} + +func (c *mcpPromptArgumentsDialogCmp) Position() (int, int) { + row := (c.wHeight / 2) - (c.height / 2) + col := (c.wWidth / 2) - (c.width / 2) + return row, col +} + +func (c *mcpPromptArgumentsDialogCmp) ID() dialogs.DialogID { + return mcpArgumentsDialogID +} diff --git a/internal/tui/components/mcp/mcp.go b/internal/tui/components/mcp/mcp.go index d11826b77749ba65276b5336a5d88cdbc8552881..8a615ab6df856cd156addf05442779bc0f3c787b 100644 --- a/internal/tui/components/mcp/mcp.go +++ b/internal/tui/components/mcp/mcp.go @@ -2,6 +2,7 @@ package mcp import ( "fmt" + "strings" "github.com/charmbracelet/lipgloss/v2" @@ -55,8 +56,8 @@ func RenderMCPList(opts RenderOptions) []string { // Determine icon and color based on state icon := t.ItemOfflineIcon - description := l.MCP.Command - extraContent := "" + description := "" + extraContent := []string{} if state, exists := mcpStates[l.Name]; exists { switch state.State { @@ -68,7 +69,10 @@ func RenderMCPList(opts RenderOptions) []string { case agent.MCPStateConnected: icon = t.ItemOnlineIcon if state.ToolCount > 0 { - extraContent = t.S().Subtle.Render(fmt.Sprintf("%d tools", state.ToolCount)) + extraContent = append(extraContent, t.S().Subtle.Render(fmt.Sprintf("%d tools", state.ToolCount))) + } + if state.PromptCount > 0 { + extraContent = append(extraContent, t.S().Subtle.Render(fmt.Sprintf("%d prompts", state.PromptCount))) } case agent.MCPStateError: icon = t.ItemErrorIcon @@ -88,7 +92,7 @@ func RenderMCPList(opts RenderOptions) []string { Icon: icon.String(), Title: l.Name, Description: description, - ExtraContent: extraContent, + ExtraContent: strings.Join(extraContent, " "), }, opts.MaxWidth, ), diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 74d82e15514c70ee96b507a01b8f611d3ade6a4d..04e640141d91f8e48778bf39ca3f3f4e39102902 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -147,6 +147,17 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { ), }, ) + case commands.ShowMCPPromptArgumentsDialogMsg: + dialog := commands.NewMCPPromptArgumentsDialog(msg.PromptID, msg.PromptName) + if dialog == nil { + util.ReportWarn(fmt.Sprintf("Prompt %s not found", msg.PromptName)) + return a, nil + } + return a, util.CmdHandler( + dialogs.OpenDialogMsg{ + Model: dialog, + }, + ) // Page change messages case page.PageChangeMsg: return a, a.moveToPage(msg.ID)