1package mcp
2
3import (
4 "context"
5 "iter"
6 "log/slog"
7
8 "github.com/charmbracelet/crush/internal/config"
9 "github.com/charmbracelet/crush/internal/csync"
10 "github.com/modelcontextprotocol/go-sdk/mcp"
11)
12
13type Prompt = mcp.Prompt
14
15var allPrompts = csync.NewMap[string, []*Prompt]()
16
17// Prompts returns all available MCP prompts.
18func Prompts() iter.Seq2[string, []*Prompt] {
19 return allPrompts.Seq2()
20}
21
22// GetPromptMessages retrieves the content of an MCP prompt with the given arguments.
23func GetPromptMessages(ctx context.Context, cfg *config.ConfigStore, clientName, promptName string, args map[string]string) ([]string, error) {
24 c, err := getOrRenewClient(ctx, cfg, clientName)
25 if err != nil {
26 return nil, err
27 }
28 result, err := c.GetPrompt(ctx, &mcp.GetPromptParams{
29 Name: promptName,
30 Arguments: args,
31 })
32 if err != nil {
33 return nil, err
34 }
35
36 var messages []string
37 for _, msg := range result.Messages {
38 if msg.Role != "user" {
39 continue
40 }
41 if textContent, ok := msg.Content.(*mcp.TextContent); ok {
42 messages = append(messages, textContent.Text)
43 }
44 }
45 return messages, nil
46}
47
48// RefreshPrompts gets the updated list of prompts from the MCP and updates the
49// global state.
50func RefreshPrompts(ctx context.Context, name string) {
51 session, ok := sessions.Get(name)
52 if !ok {
53 slog.Warn("Refresh prompts: no session", "name", name)
54 return
55 }
56
57 prompts, err := getPrompts(ctx, session)
58 if err != nil {
59 updateState(name, StateError, err, nil, Counts{})
60 return
61 }
62
63 updatePrompts(name, prompts)
64
65 prev, _ := states.Get(name)
66 prev.Counts.Prompts = len(prompts)
67 updateState(name, StateConnected, nil, session, prev.Counts)
68}
69
70func getPrompts(ctx context.Context, c *ClientSession) ([]*Prompt, error) {
71 if c.InitializeResult().Capabilities.Prompts == nil {
72 return nil, nil
73 }
74 result, err := c.ListPrompts(ctx, &mcp.ListPromptsParams{})
75 if err != nil {
76 return nil, err
77 }
78 return result.Prompts, nil
79}
80
81// updatePrompts updates the global mcpPrompts and mcpClient2Prompts maps
82func updatePrompts(mcpName string, prompts []*Prompt) {
83 if len(prompts) == 0 {
84 allPrompts.Del(mcpName)
85 return
86 }
87 allPrompts.Set(mcpName, prompts)
88}