prompts.go

 1package mcp
 2
 3import (
 4	"context"
 5	"iter"
 6
 7	"github.com/charmbracelet/crush/internal/csync"
 8	"github.com/modelcontextprotocol/go-sdk/mcp"
 9)
10
// Prompt is an alias for the go-sdk mcp.Prompt type, so the rest of this
// package can refer to prompts without naming the SDK package.
type Prompt = mcp.Prompt
12
var (
	// allPrompts is the flat prompt registry. updatePrompts stores entries
	// under the key "<client name>:<prompt name>".
	allPrompts     = csync.NewMap[string, *Prompt]()
	// client2Prompts maps an MCP client name to the prompts last registered
	// for that client via updatePrompts.
	client2Prompts = csync.NewMap[string, []*Prompt]()
)
17
18// GetPrompts returns all available MCP prompts.
19func GetPrompts() iter.Seq2[string, *Prompt] {
20	return allPrompts.Seq2()
21}
22
23// GetPrompt returns a specific MCP prompt by name.
24func GetPrompt(name string) (*Prompt, bool) {
25	return allPrompts.Get(name)
26}
27
28// GetPromptsByClient returns all prompts for a specific MCP client.
29func GetPromptsByClient(clientName string) ([]*Prompt, bool) {
30	return client2Prompts.Get(clientName)
31}
32
33// GetPromptMessages retrieves the content of an MCP prompt with the given arguments.
34func GetPromptMessages(ctx context.Context, clientName, promptName string, args map[string]string) ([]string, error) {
35	c, err := getOrRenewClient(ctx, clientName)
36	if err != nil {
37		return nil, err
38	}
39	result, err := c.GetPrompt(ctx, &mcp.GetPromptParams{
40		Name:      promptName,
41		Arguments: args,
42	})
43	if err != nil {
44		return nil, err
45	}
46
47	var messages []string
48	for _, msg := range result.Messages {
49		if msg.Role != "user" {
50			continue
51		}
52		if textContent, ok := msg.Content.(*mcp.TextContent); ok {
53			messages = append(messages, textContent.Text)
54		}
55	}
56	return messages, nil
57}
58
59func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*Prompt, error) {
60	if c.InitializeResult().Capabilities.Prompts == nil {
61		return nil, nil
62	}
63	result, err := c.ListPrompts(ctx, &mcp.ListPromptsParams{})
64	if err != nil {
65		return nil, err
66	}
67	return result.Prompts, nil
68}
69
70// updatePrompts updates the global mcpPrompts and mcpClient2Prompts maps
71func updatePrompts(mcpName string, prompts []*Prompt) {
72	if len(prompts) == 0 {
73		client2Prompts.Del(mcpName)
74	} else {
75		client2Prompts.Set(mcpName, prompts)
76	}
77	for mcpName, prompts := range client2Prompts.Seq2() {
78		for _, p := range prompts {
79			key := mcpName + ":" + p.Name
80			allPrompts.Set(key, p)
81		}
82	}
83}