prompts.go

 1package mcp
 2
 3import (
 4	"context"
 5	"iter"
 6	"log/slog"
 7
 8	"github.com/charmbracelet/crush/internal/csync"
 9	"github.com/modelcontextprotocol/go-sdk/mcp"
10)
11
12type Prompt = mcp.Prompt
13
14var (
15	allPrompts    = csync.NewMap[string, *Prompt]()
16	clientPrompts = csync.NewMap[string, []*Prompt]()
17)
18
19// Prompts returns all available MCP prompts.
20func Prompts() iter.Seq2[string, *Prompt] {
21	return allPrompts.Seq2()
22}
23
24// GetPromptMessages retrieves the content of an MCP prompt with the given arguments.
25func GetPromptMessages(ctx context.Context, clientName, promptName string, args map[string]string) ([]string, error) {
26	c, err := getOrRenewClient(ctx, clientName)
27	if err != nil {
28		return nil, err
29	}
30	result, err := c.GetPrompt(ctx, &mcp.GetPromptParams{
31		Name:      promptName,
32		Arguments: args,
33	})
34	if err != nil {
35		return nil, err
36	}
37
38	var messages []string
39	for _, msg := range result.Messages {
40		if msg.Role != "user" {
41			continue
42		}
43		if textContent, ok := msg.Content.(*mcp.TextContent); ok {
44			messages = append(messages, textContent.Text)
45		}
46	}
47	return messages, nil
48}
49
50// RefreshPrompts gets the updated list of prompts from the MCP and updates the
51// global state.
52func RefreshPrompts(ctx context.Context, name string) {
53	session, ok := sessions.Get(name)
54	if !ok {
55		slog.Warn("refresh prompts: no session", "name", name)
56		return
57	}
58
59	prompts, err := getPrompts(ctx, session)
60	if err != nil {
61		updateState(name, StateError, err, nil, Counts{})
62		return
63	}
64
65	updatePrompts(name, prompts)
66
67	prev, _ := states.Get(name)
68	updateState(name, StateConnected, nil, session, Counts{
69		Prompts: len(prompts),
70		Tools:   prev.Counts.Tools,
71	})
72}
73
74func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*Prompt, error) {
75	if c.InitializeResult().Capabilities.Prompts == nil {
76		return nil, nil
77	}
78	result, err := c.ListPrompts(ctx, &mcp.ListPromptsParams{})
79	if err != nil {
80		return nil, err
81	}
82	return result.Prompts, nil
83}
84
85// updatePrompts updates the global mcpPrompts and mcpClient2Prompts maps
86func updatePrompts(mcpName string, prompts []*Prompt) {
87	if len(prompts) == 0 {
88		clientPrompts.Del(mcpName)
89	} else {
90		clientPrompts.Set(mcpName, prompts)
91	}
92	for mcpName, prompts := range clientPrompts.Seq2() {
93		for _, p := range prompts {
94			key := mcpName + ":" + p.Name
95			allPrompts.Set(key, p)
96		}
97	}
98}