1package mcp
2
3import (
4 "context"
5 "iter"
6
7 "github.com/charmbracelet/crush/internal/csync"
8 "github.com/modelcontextprotocol/go-sdk/mcp"
9)
10
11type Prompt = mcp.Prompt
12
13var (
14 allPrompts = csync.NewMap[string, *Prompt]()
15 client2Prompts = csync.NewMap[string, []*Prompt]()
16)
17
18// GetPrompts returns all available MCP prompts.
19func GetPrompts() iter.Seq2[string, *Prompt] {
20 return allPrompts.Seq2()
21}
22
23// GetPrompt returns a specific MCP prompt by name.
24func GetPrompt(name string) (*Prompt, bool) {
25 return allPrompts.Get(name)
26}
27
28// GetPromptsByClient returns all prompts for a specific MCP client.
29func GetPromptsByClient(clientName string) ([]*Prompt, bool) {
30 return client2Prompts.Get(clientName)
31}
32
33// GetPromptMessages retrieves the content of an MCP prompt with the given arguments.
34func GetPromptMessages(ctx context.Context, clientName, promptName string, args map[string]string) ([]string, error) {
35 c, err := getOrRenewClient(ctx, clientName)
36 if err != nil {
37 return nil, err
38 }
39 result, err := c.GetPrompt(ctx, &mcp.GetPromptParams{
40 Name: promptName,
41 Arguments: args,
42 })
43 if err != nil {
44 return nil, err
45 }
46
47 var messages []string
48 for _, msg := range result.Messages {
49 if msg.Role != "user" {
50 continue
51 }
52 if textContent, ok := msg.Content.(*mcp.TextContent); ok {
53 messages = append(messages, textContent.Text)
54 }
55 }
56 return messages, nil
57}
58
59func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*Prompt, error) {
60 if c.InitializeResult().Capabilities.Prompts == nil {
61 return nil, nil
62 }
63 result, err := c.ListPrompts(ctx, &mcp.ListPromptsParams{})
64 if err != nil {
65 return nil, err
66 }
67 return result.Prompts, nil
68}
69
70// updatePrompts updates the global mcpPrompts and mcpClient2Prompts maps
71func updatePrompts(mcpName string, prompts []*Prompt) {
72 if len(prompts) == 0 {
73 client2Prompts.Del(mcpName)
74 } else {
75 client2Prompts.Set(mcpName, prompts)
76 }
77 for mcpName, prompts := range client2Prompts.Seq2() {
78 for _, p := range prompts {
79 key := mcpName + ":" + p.Name
80 allPrompts.Set(key, p)
81 }
82 }
83}