1package mcp
2
3import (
4 "context"
5 "iter"
6 "log/slog"
7
8 "github.com/charmbracelet/crush/internal/csync"
9 "github.com/modelcontextprotocol/go-sdk/mcp"
10)
11
// Prompt is an alias for the MCP SDK's prompt type, so code using this package
// can refer to prompts without importing the SDK package itself.
type Prompt = mcp.Prompt
13
// Package-level prompt registries, maintained by updatePrompts.
var (
	// allPrompts indexes prompts under the key "<MCP name>:<prompt name>".
	allPrompts = csync.NewMap[string, *Prompt]()
	// clientPrompts holds, per MCP name, the prompt list most recently stored
	// for that MCP. MCPs whose list is empty have no entry.
	clientPrompts = csync.NewMap[string, []*Prompt]()
)
18
19// Prompts returns all available MCP prompts.
20func Prompts() iter.Seq2[string, *Prompt] {
21 return allPrompts.Seq2()
22}
23
24// GetPromptMessages retrieves the content of an MCP prompt with the given arguments.
25func GetPromptMessages(ctx context.Context, clientName, promptName string, args map[string]string) ([]string, error) {
26 c, err := getOrRenewClient(ctx, clientName)
27 if err != nil {
28 return nil, err
29 }
30 result, err := c.GetPrompt(ctx, &mcp.GetPromptParams{
31 Name: promptName,
32 Arguments: args,
33 })
34 if err != nil {
35 return nil, err
36 }
37
38 var messages []string
39 for _, msg := range result.Messages {
40 if msg.Role != "user" {
41 continue
42 }
43 if textContent, ok := msg.Content.(*mcp.TextContent); ok {
44 messages = append(messages, textContent.Text)
45 }
46 }
47 return messages, nil
48}
49
50// RefreshPrompts gets the updated list of prompts from the MCP and updates the
51// global state.
52func RefreshPrompts(ctx context.Context, name string) {
53 session, ok := sessions.Get(name)
54 if !ok {
55 slog.Warn("refresh prompts: no session", "name", name)
56 return
57 }
58
59 prompts, err := getPrompts(ctx, session)
60 if err != nil {
61 updateState(name, StateError, err, nil, Counts{})
62 return
63 }
64
65 updatePrompts(name, prompts)
66
67 prev, _ := states.Get(name)
68 updateState(name, StateConnected, nil, session, Counts{
69 Prompts: len(prompts),
70 Tools: prev.Counts.Tools,
71 })
72}
73
74func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*Prompt, error) {
75 if c.InitializeResult().Capabilities.Prompts == nil {
76 return nil, nil
77 }
78 result, err := c.ListPrompts(ctx, &mcp.ListPromptsParams{})
79 if err != nil {
80 return nil, err
81 }
82 return result.Prompts, nil
83}
84
85// updatePrompts updates the global mcpPrompts and mcpClient2Prompts maps
86func updatePrompts(mcpName string, prompts []*Prompt) {
87 if len(prompts) == 0 {
88 clientPrompts.Del(mcpName)
89 } else {
90 clientPrompts.Set(mcpName, prompts)
91 }
92 for mcpName, prompts := range clientPrompts.Seq2() {
93 for _, p := range prompts {
94 key := mcpName + ":" + p.Name
95 allPrompts.Set(key, p)
96 }
97 }
98}