prompts.go

 1package mcp
 2
import (
	"context"
	"fmt"
	"iter"
	"log/slog"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/modelcontextprotocol/go-sdk/mcp"
)
12
13type Prompt = mcp.Prompt
14
15var allPrompts = csync.NewMap[string, []*Prompt]()
16
17// Prompts returns all available MCP prompts.
18func Prompts() iter.Seq2[string, []*Prompt] {
19	return allPrompts.Seq2()
20}
21
22// GetPromptMessages retrieves the content of an MCP prompt with the given arguments.
23func GetPromptMessages(ctx context.Context, cfg *config.ConfigStore, clientName, promptName string, args map[string]string) ([]string, error) {
24	c, err := getOrRenewClient(ctx, cfg, clientName)
25	if err != nil {
26		return nil, err
27	}
28	result, err := c.GetPrompt(ctx, &mcp.GetPromptParams{
29		Name:      promptName,
30		Arguments: args,
31	})
32	if err != nil {
33		return nil, err
34	}
35
36	var messages []string
37	for _, msg := range result.Messages {
38		if msg.Role != "user" {
39			continue
40		}
41		if textContent, ok := msg.Content.(*mcp.TextContent); ok {
42			messages = append(messages, textContent.Text)
43		}
44	}
45	return messages, nil
46}
47
48// RefreshPrompts gets the updated list of prompts from the MCP and updates the
49// global state.
50func RefreshPrompts(ctx context.Context, name string) {
51	session, ok := sessions.Get(name)
52	if !ok {
53		slog.Warn("Refresh prompts: no session", "name", name)
54		return
55	}
56
57	prompts, err := getPrompts(ctx, session)
58	if err != nil {
59		updateState(name, StateError, err, nil, Counts{})
60		return
61	}
62
63	updatePrompts(name, prompts)
64
65	prev, _ := states.Get(name)
66	prev.Counts.Prompts = len(prompts)
67	updateState(name, StateConnected, nil, session, prev.Counts)
68}
69
70func getPrompts(ctx context.Context, c *ClientSession) ([]*Prompt, error) {
71	if c.InitializeResult().Capabilities.Prompts == nil {
72		return nil, nil
73	}
74	result, err := c.ListPrompts(ctx, &mcp.ListPromptsParams{})
75	if err != nil {
76		return nil, err
77	}
78	return result.Prompts, nil
79}
80
81// updatePrompts updates the global mcpPrompts and mcpClient2Prompts maps
82func updatePrompts(mcpName string, prompts []*Prompt) {
83	if len(prompts) == 0 {
84		allPrompts.Del(mcpName)
85		return
86	}
87	allPrompts.Set(mcpName, prompts)
88}