prompts.go

 1package mcp
 2
 3import (
 4	"context"
 5	"iter"
 6	"log/slog"
 7
 8	"git.secluded.site/crush/internal/csync"
 9	"github.com/modelcontextprotocol/go-sdk/mcp"
10)
11
// Prompt is an alias for the go-sdk mcp.Prompt type, presumably so callers
// of this package can refer to prompts without importing the SDK directly.
type Prompt = mcp.Prompt

// allPrompts holds the most recently fetched prompt list for each MCP client,
// keyed by client name. It is written by updatePrompts (which deletes the
// entry when a client reports no prompts) and exposed through Prompts.
var allPrompts = csync.NewMap[string, []*Prompt]()
15
16// Prompts returns all available MCP prompts.
17func Prompts() iter.Seq2[string, []*Prompt] {
18	return allPrompts.Seq2()
19}
20
21// GetPromptMessages retrieves the content of an MCP prompt with the given arguments.
22func GetPromptMessages(ctx context.Context, clientName, promptName string, args map[string]string) ([]string, error) {
23	c, err := getOrRenewClient(ctx, clientName)
24	if err != nil {
25		return nil, err
26	}
27	result, err := c.GetPrompt(ctx, &mcp.GetPromptParams{
28		Name:      promptName,
29		Arguments: args,
30	})
31	if err != nil {
32		return nil, err
33	}
34
35	var messages []string
36	for _, msg := range result.Messages {
37		if msg.Role != "user" {
38			continue
39		}
40		if textContent, ok := msg.Content.(*mcp.TextContent); ok {
41			messages = append(messages, textContent.Text)
42		}
43	}
44	return messages, nil
45}
46
47// RefreshPrompts gets the updated list of prompts from the MCP and updates the
48// global state.
49func RefreshPrompts(ctx context.Context, name string) {
50	session, ok := sessions.Get(name)
51	if !ok {
52		slog.Warn("refresh prompts: no session", "name", name)
53		return
54	}
55
56	prompts, err := getPrompts(ctx, session)
57	if err != nil {
58		updateState(name, StateError, err, nil, Counts{})
59		return
60	}
61
62	updatePrompts(name, prompts)
63
64	prev, _ := states.Get(name)
65	prev.Counts.Prompts = len(prompts)
66	updateState(name, StateConnected, nil, session, prev.Counts)
67}
68
69func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*Prompt, error) {
70	if c.InitializeResult().Capabilities.Prompts == nil {
71		return nil, nil
72	}
73	result, err := c.ListPrompts(ctx, &mcp.ListPromptsParams{})
74	if err != nil {
75		return nil, err
76	}
77	return result.Prompts, nil
78}
79
80// updatePrompts updates the global mcpPrompts and mcpClient2Prompts maps
81func updatePrompts(mcpName string, prompts []*Prompt) {
82	if len(prompts) == 0 {
83		allPrompts.Del(mcpName)
84		return
85	}
86	allPrompts.Set(mcpName, prompts)
87}