1package mcp
2
3import (
4 "context"
5 "iter"
6 "log/slog"
7
8 "git.secluded.site/crush/internal/csync"
9 "github.com/modelcontextprotocol/go-sdk/mcp"
10)
11
// Prompt is an alias for the MCP SDK's prompt type, re-exported so callers
// can refer to prompts through this package.
type Prompt = mcp.Prompt

// allPrompts holds the prompts currently known for each MCP, keyed by MCP
// name. Entries are written and removed by updatePrompts.
var allPrompts = csync.NewMap[string, []*Prompt]()
15
16// Prompts returns all available MCP prompts.
17func Prompts() iter.Seq2[string, []*Prompt] {
18 return allPrompts.Seq2()
19}
20
21// GetPromptMessages retrieves the content of an MCP prompt with the given arguments.
22func GetPromptMessages(ctx context.Context, clientName, promptName string, args map[string]string) ([]string, error) {
23 c, err := getOrRenewClient(ctx, clientName)
24 if err != nil {
25 return nil, err
26 }
27 result, err := c.GetPrompt(ctx, &mcp.GetPromptParams{
28 Name: promptName,
29 Arguments: args,
30 })
31 if err != nil {
32 return nil, err
33 }
34
35 var messages []string
36 for _, msg := range result.Messages {
37 if msg.Role != "user" {
38 continue
39 }
40 if textContent, ok := msg.Content.(*mcp.TextContent); ok {
41 messages = append(messages, textContent.Text)
42 }
43 }
44 return messages, nil
45}
46
47// RefreshPrompts gets the updated list of prompts from the MCP and updates the
48// global state.
49func RefreshPrompts(ctx context.Context, name string) {
50 session, ok := sessions.Get(name)
51 if !ok {
52 slog.Warn("refresh prompts: no session", "name", name)
53 return
54 }
55
56 prompts, err := getPrompts(ctx, session)
57 if err != nil {
58 updateState(name, StateError, err, nil, Counts{})
59 return
60 }
61
62 updatePrompts(name, prompts)
63
64 prev, _ := states.Get(name)
65 prev.Counts.Prompts = len(prompts)
66 updateState(name, StateConnected, nil, session, prev.Counts)
67}
68
69func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*Prompt, error) {
70 if c.InitializeResult().Capabilities.Prompts == nil {
71 return nil, nil
72 }
73 result, err := c.ListPrompts(ctx, &mcp.ListPromptsParams{})
74 if err != nil {
75 return nil, err
76 }
77 return result.Prompts, nil
78}
79
80// updatePrompts updates the global mcpPrompts and mcpClient2Prompts maps
81func updatePrompts(mcpName string, prompts []*Prompt) {
82 if len(prompts) == 0 {
83 allPrompts.Del(mcpName)
84 return
85 }
86 allPrompts.Set(mcpName, prompts)
87}