1package agent
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"log/slog"
 10	"maps"
 11	"net/http"
 12	"os/exec"
 13	"strings"
 14	"sync"
 15	"time"
 16
 17	"github.com/charmbracelet/crush/internal/config"
 18	"github.com/charmbracelet/crush/internal/csync"
 19	"github.com/charmbracelet/crush/internal/home"
 20	"github.com/charmbracelet/crush/internal/llm/tools"
 21	"github.com/charmbracelet/crush/internal/permission"
 22	"github.com/charmbracelet/crush/internal/pubsub"
 23	"github.com/charmbracelet/crush/internal/version"
 24	"github.com/modelcontextprotocol/go-sdk/mcp"
 25)
 26
 27// MCPState represents the current state of an MCP client
 28type MCPState int
 29
 30const (
 31	MCPStateDisabled MCPState = iota
 32	MCPStateStarting
 33	MCPStateConnected
 34	MCPStateError
 35)
 36
 37func (s MCPState) String() string {
 38	switch s {
 39	case MCPStateDisabled:
 40		return "disabled"
 41	case MCPStateStarting:
 42		return "starting"
 43	case MCPStateConnected:
 44		return "connected"
 45	case MCPStateError:
 46		return "error"
 47	default:
 48		return "unknown"
 49	}
 50}
 51
 52// MCPEventType represents the type of MCP event
 53type MCPEventType string
 54
 55const (
 56	MCPEventToolsListChanged   MCPEventType = "tools_list_changed"
 57	MCPEventPromptsListChanged MCPEventType = "prompts_list_changed"
 58)
 59
 60// MCPEvent represents an event in the MCP system
 61type MCPEvent struct {
 62	Type   MCPEventType
 63	Name   string
 64	State  MCPState
 65	Error  error
 66	Counts MCPCounts
 67}
 68
 69// MCPCounts number of available tools, prompts, etc.
 70type MCPCounts struct {
 71	Tools   int
 72	Prompts int
 73}
 74
 75// MCPClientInfo holds information about an MCP client's state
 76type MCPClientInfo struct {
 77	Name        string
 78	State       MCPState
 79	Error       error
 80	Client      *mcp.ClientSession
 81	Counts      MCPCounts
 82	ConnectedAt time.Time
 83}
 84
 85var (
 86	mcpToolsOnce      sync.Once
 87	mcpTools          = csync.NewMap[string, tools.BaseTool]()
 88	mcpClient2Tools   = csync.NewMap[string, []tools.BaseTool]()
 89	mcpClients        = csync.NewMap[string, *mcp.ClientSession]()
 90	mcpStates         = csync.NewMap[string, MCPClientInfo]()
 91	mcpBroker         = pubsub.NewBroker[MCPEvent]()
 92	mcpPrompts        = csync.NewMap[string, *mcp.Prompt]()
 93	mcpClient2Prompts = csync.NewMap[string, []*mcp.Prompt]()
 94)
 95
 96type McpTool struct {
 97	mcpName     string
 98	tool        *mcp.Tool
 99	permissions permission.Service
100	workingDir  string
101}
102
103func (b *McpTool) Name() string {
104	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
105}
106
107func (b *McpTool) Info() tools.ToolInfo {
108	input := b.tool.InputSchema.(map[string]any)
109	required, _ := input["required"].([]string)
110	parameters, _ := input["properties"].(map[string]any)
111	return tools.ToolInfo{
112		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
113		Description: b.tool.Description,
114		Parameters:  parameters,
115		Required:    required,
116	}
117}
118
119func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
120	var args map[string]any
121	if err := json.Unmarshal([]byte(input), &args); err != nil {
122		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
123	}
124
125	c, err := getOrRenewClient(ctx, name)
126	if err != nil {
127		return tools.NewTextErrorResponse(err.Error()), nil
128	}
129	result, err := c.CallTool(ctx, &mcp.CallToolParams{
130		Name:      toolName,
131		Arguments: args,
132	})
133	if err != nil {
134		return tools.NewTextErrorResponse(err.Error()), nil
135	}
136
137	output := make([]string, 0, len(result.Content))
138	for _, v := range result.Content {
139		if vv, ok := v.(*mcp.TextContent); ok {
140			output = append(output, vv.Text)
141		} else {
142			output = append(output, fmt.Sprintf("%v", v))
143		}
144	}
145	return tools.NewTextResponse(strings.Join(output, "\n")), nil
146}
147
148func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
149	sess, ok := mcpClients.Get(name)
150	if !ok {
151		return nil, fmt.Errorf("mcp '%s' not available", name)
152	}
153
154	cfg := config.Get()
155	m := cfg.MCP[name]
156	state, _ := mcpStates.Get(name)
157
158	timeout := mcpTimeout(m)
159	pingCtx, cancel := context.WithTimeout(ctx, timeout)
160	defer cancel()
161	err := sess.Ping(pingCtx, nil)
162	if err == nil {
163		return sess, nil
164	}
165	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
166
167	sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
168	if err != nil {
169		return nil, err
170	}
171
172	updateMCPState(name, MCPStateConnected, nil, sess, state.Counts)
173	mcpClients.Set(name, sess)
174	return sess, nil
175}
176
177func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
178	sessionID, messageID := tools.GetContextValues(ctx)
179	if sessionID == "" || messageID == "" {
180		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
181	}
182	permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name)
183	p := b.permissions.Request(
184		permission.CreatePermissionRequest{
185			SessionID:   sessionID,
186			ToolCallID:  params.ID,
187			Path:        b.workingDir,
188			ToolName:    b.Info().Name,
189			Action:      "execute",
190			Description: permissionDescription,
191			Params:      params.Input,
192		},
193	)
194	if !p {
195		return tools.ToolResponse{}, permission.ErrorPermissionDenied
196	}
197
198	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
199}
200
201func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]tools.BaseTool, error) {
202	if c.InitializeResult().Capabilities.Tools == nil {
203		return nil, nil
204	}
205	result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
206	if err != nil {
207		return nil, err
208	}
209	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
210	for _, tool := range result.Tools {
211		mcpTools = append(mcpTools, &McpTool{
212			mcpName:     name,
213			tool:        tool,
214			permissions: permissions,
215			workingDir:  workingDir,
216		})
217	}
218	return mcpTools, nil
219}
220
221// SubscribeMCPEvents returns a channel for MCP events
222func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
223	return mcpBroker.Subscribe(ctx)
224}
225
226// GetMCPStates returns the current state of all MCP clients
227func GetMCPStates() map[string]MCPClientInfo {
228	return maps.Collect(mcpStates.Seq2())
229}
230
231// GetMCPState returns the state of a specific MCP client
232func GetMCPState(name string) (MCPClientInfo, bool) {
233	return mcpStates.Get(name)
234}
235
236// updateMCPState updates the state of an MCP client and publishes an event
237func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, counts MCPCounts) {
238	info := MCPClientInfo{
239		Name:   name,
240		State:  state,
241		Error:  err,
242		Client: client,
243		Counts: counts,
244	}
245	switch state {
246	case MCPStateConnected:
247		info.ConnectedAt = time.Now()
248	case MCPStateError:
249		updateMcpTools(name, nil)
250		updateMcpPrompts(name, nil)
251		mcpClients.Del(name)
252	}
253	mcpStates.Set(name, info)
254}
255
256// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
257func CloseMCPClients() error {
258	var errs []error
259	for name, c := range mcpClients.Seq2() {
260		if err := c.Close(); err != nil {
261			errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
262		}
263	}
264	mcpBroker.Shutdown()
265	return errors.Join(errs...)
266}
267
268func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
269	var wg sync.WaitGroup
270	// Initialize states for all configured MCPs
271	for name, m := range cfg.MCP {
272		if m.Disabled {
273			updateMCPState(name, MCPStateDisabled, nil, nil, MCPCounts{})
274			slog.Debug("skipping disabled mcp", "name", name)
275			continue
276		}
277
278		// Set initial starting state
279		updateMCPState(name, MCPStateStarting, nil, nil, MCPCounts{})
280
281		wg.Add(1)
282		go func(name string, m config.MCPConfig) {
283			defer func() {
284				wg.Done()
285				if r := recover(); r != nil {
286					var err error
287					switch v := r.(type) {
288					case error:
289						err = v
290					case string:
291						err = fmt.Errorf("panic: %s", v)
292					default:
293						err = fmt.Errorf("panic: %v", v)
294					}
295					updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
296					slog.Error("panic in mcp client initialization", "error", err, "name", name)
297				}
298			}()
299
300			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
301			defer cancel()
302
303			c, err := createMCPSession(ctx, name, m, cfg.Resolver())
304			if err != nil {
305				return
306			}
307
308			mcpClients.Set(name, c)
309
310			tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
311			if err != nil {
312				slog.Error("error listing tools", "error", err)
313				updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
314				c.Close()
315				return
316			}
317
318			prompts, err := getPrompts(ctx, c)
319			if err != nil {
320				slog.Error("error listing prompts", "error", err)
321				updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
322				c.Close()
323				return
324			}
325
326			updateMcpTools(name, tools)
327			updateMcpPrompts(name, prompts)
328			mcpClients.Set(name, c)
329			counts := MCPCounts{
330				Tools:   len(tools),
331				Prompts: len(prompts),
332			}
333			updateMCPState(name, MCPStateConnected, nil, c, counts)
334		}(name, m)
335	}
336	wg.Wait()
337}
338
339// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
340func updateMcpTools(mcpName string, tools []tools.BaseTool) {
341	if len(tools) == 0 {
342		mcpClient2Tools.Del(mcpName)
343	} else {
344		mcpClient2Tools.Set(mcpName, tools)
345	}
346	for _, tools := range mcpClient2Tools.Seq2() {
347		for _, t := range tools {
348			mcpTools.Set(t.Name(), t)
349		}
350	}
351}
352
353func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
354	timeout := mcpTimeout(m)
355	mcpCtx, cancel := context.WithCancel(ctx)
356	cancelTimer := time.AfterFunc(timeout, cancel)
357
358	transport, err := createMCPTransport(mcpCtx, m, resolver)
359	if err != nil {
360		updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
361		slog.Error("error creating mcp client", "error", err, "name", name)
362		return nil, err
363	}
364
365	client := mcp.NewClient(
366		&mcp.Implementation{
367			Name:    "crush",
368			Version: version.Version,
369			Title:   "Crush",
370		},
371		&mcp.ClientOptions{
372			ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
373				mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
374					Type: MCPEventToolsListChanged,
375					Name: name,
376				})
377			},
378			PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
379				mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
380					Type: MCPEventPromptsListChanged,
381					Name: name,
382				})
383			},
384			KeepAlive: time.Minute * 10,
385		},
386	)
387
388	session, err := client.Connect(mcpCtx, transport, nil)
389	if err != nil {
390		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, MCPCounts{})
391		slog.Error("error starting mcp client", "error", err, "name", name)
392		_ = session.Close()
393		cancel()
394		return nil, err
395	}
396
397	cancelTimer.Stop()
398	slog.Info("Initialized mcp client", "name", name)
399	return session, nil
400}
401
// maybeTimeoutErr turns a context cancellation or deadline error into a
// "timed out after <timeout>" error. In this file, both mean an MCP timeout:
// createMCPSession times out by cancelling its context, while pings use
// context.WithTimeout and therefore fail with context.DeadlineExceeded.
// Any other error, including nil, is returned unchanged.
// NOTE(review): a cancellation initiated by the caller is also reported as a
// timeout.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
408
409func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
410	switch m.Type {
411	case config.MCPStdio:
412		command, err := resolver.ResolveValue(m.Command)
413		if err != nil {
414			return nil, fmt.Errorf("invalid mcp command: %w", err)
415		}
416		if strings.TrimSpace(command) == "" {
417			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
418		}
419		cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
420		cmd.Env = m.ResolvedEnv()
421		return &mcp.CommandTransport{
422			Command: cmd,
423		}, nil
424	case config.MCPHttp:
425		if strings.TrimSpace(m.URL) == "" {
426			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
427		}
428		client := &http.Client{
429			Transport: &headerRoundTripper{
430				headers: m.ResolvedHeaders(),
431			},
432		}
433		return &mcp.StreamableClientTransport{
434			Endpoint:   m.URL,
435			HTTPClient: client,
436		}, nil
437	case config.MCPSSE:
438		if strings.TrimSpace(m.URL) == "" {
439			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
440		}
441		client := &http.Client{
442			Transport: &headerRoundTripper{
443				headers: m.ResolvedHeaders(),
444			},
445		}
446		return &mcp.SSEClientTransport{
447			Endpoint:   m.URL,
448			HTTPClient: client,
449		}, nil
450	default:
451		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
452	}
453}
454
// headerRoundTripper is an http.RoundTripper that adds a fixed set of headers
// to every outgoing request and then delegates to http.DefaultTransport.
type headerRoundTripper struct {
	headers map[string]string
}

// RoundTrip sends req with the configured headers set, overriding any values
// req already has for those keys. The http.RoundTripper contract forbids
// modifying the caller's request, so the headers are set on a clone.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) == 0 {
		return http.DefaultTransport.RoundTrip(req)
	}
	out := req.Clone(req.Context())
	if out.Header == nil {
		out.Header = make(http.Header)
	}
	for k, v := range rt.headers {
		out.Header.Set(k, v)
	}
	return http.DefaultTransport.RoundTrip(out)
}
465
466func mcpTimeout(m config.MCPConfig) time.Duration {
467	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
468}
469
470func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*mcp.Prompt, error) {
471	if c.InitializeResult().Capabilities.Prompts == nil {
472		return nil, nil
473	}
474	result, err := c.ListPrompts(ctx, &mcp.ListPromptsParams{})
475	if err != nil {
476		return nil, err
477	}
478	return result.Prompts, nil
479}
480
481// updateMcpPrompts updates the global mcpPrompts and mcpClient2Prompts maps.
482func updateMcpPrompts(mcpName string, prompts []*mcp.Prompt) {
483	if len(prompts) == 0 {
484		mcpClient2Prompts.Del(mcpName)
485	} else {
486		mcpClient2Prompts.Set(mcpName, prompts)
487	}
488	for clientName, prompts := range mcpClient2Prompts.Seq2() {
489		for _, p := range prompts {
490			key := clientName + ":" + p.Name
491			mcpPrompts.Set(key, p)
492		}
493	}
494}
495
496// GetMCPPrompts returns all available MCP prompts.
497func GetMCPPrompts() map[string]*mcp.Prompt {
498	return maps.Collect(mcpPrompts.Seq2())
499}
500
501// GetMCPPrompt returns a specific MCP prompt by name.
502func GetMCPPrompt(name string) (*mcp.Prompt, bool) {
503	return mcpPrompts.Get(name)
504}
505
506// GetMCPPromptsByClient returns all prompts for a specific MCP client.
507func GetMCPPromptsByClient(clientName string) ([]*mcp.Prompt, bool) {
508	return mcpClient2Prompts.Get(clientName)
509}
510
511// GetMCPPromptContent retrieves the content of an MCP prompt with the given arguments.
512func GetMCPPromptContent(ctx context.Context, clientName, promptName string, args map[string]string) (*mcp.GetPromptResult, error) {
513	c, err := getOrRenewClient(ctx, clientName)
514	if err != nil {
515		return nil, err
516	}
517
518	return c.GetPrompt(ctx, &mcp.GetPromptParams{
519		Name:      promptName,
520		Arguments: args,
521	})
522}