mcp-tools.go

  1package tools
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"io"
 10	"log/slog"
 11	"maps"
 12	"net/http"
 13	"os/exec"
 14	"slices"
 15	"strings"
 16	"sync"
 17	"time"
 18
 19	"github.com/charmbracelet/crush/internal/config"
 20	"github.com/charmbracelet/crush/internal/csync"
 21	"github.com/charmbracelet/crush/internal/home"
 22	"github.com/charmbracelet/crush/internal/permission"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24	"github.com/charmbracelet/crush/internal/version"
 25	"github.com/charmbracelet/fantasy/ai"
 26	"github.com/modelcontextprotocol/go-sdk/mcp"
 27)
 28
 29// MCPState represents the current state of an MCP client
 30type MCPState int
 31
 32const (
 33	MCPStateDisabled MCPState = iota
 34	MCPStateStarting
 35	MCPStateConnected
 36	MCPStateError
 37)
 38
 39func (s MCPState) String() string {
 40	switch s {
 41	case MCPStateDisabled:
 42		return "disabled"
 43	case MCPStateStarting:
 44		return "starting"
 45	case MCPStateConnected:
 46		return "connected"
 47	case MCPStateError:
 48		return "error"
 49	default:
 50		return "unknown"
 51	}
 52}
 53
// MCPEventType identifies the kind of event carried by an MCPEvent.
type MCPEventType string

const (
	// MCPEventStateChanged marks events published by updateMCPState whenever
	// a client's recorded state is replaced.
	MCPEventStateChanged MCPEventType = "state_changed"
	// MCPEventToolsListChanged marks events published from the client's
	// ToolListChangedHandler.
	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
)
 61
// MCPEvent is the payload published on the MCP broker.
//
// State-changed events populate every field; tools-list-changed events, as
// published in this file, set only Type and Name.
type MCPEvent struct {
	Type      MCPEventType // kind of event
	Name      string       // name of the MCP the event refers to
	State     MCPState     // state recorded by updateMCPState
	Error     error        // error passed to updateMCPState, if any
	ToolCount int          // tool count passed to updateMCPState
}
 70
// MCPClientInfo is the per-MCP record stored in mcpStates. A fresh value is
// written on every updateMCPState call.
type MCPClientInfo struct {
	Name        string             // name of the MCP this record describes
	State       MCPState           // most recently recorded state
	Error       error              // error supplied with the state, nil if none
	Client      *mcp.ClientSession // session supplied with the state; may be nil
	ToolCount   int                // tool count supplied with the state
	ConnectedAt time.Time          // set only when State is MCPStateConnected; zero otherwise
}
 80
// Package-level MCP registries shared by every function in this file.
var (
	// mcpToolsOnce ensures the client start-up in GetMCPTools runs once.
	mcpToolsOnce sync.Once
	// mcpTools maps a tool's full name (McpTool.Name) to the tool.
	mcpTools = csync.NewMap[string, *McpTool]()
	// mcpClient2Tools maps an MCP name to the tools registered for it.
	mcpClient2Tools = csync.NewMap[string, []*McpTool]()
	// mcpClients maps an MCP name to its current client session.
	mcpClients = csync.NewMap[string, *mcp.ClientSession]()
	// mcpStates maps an MCP name to its most recently recorded state.
	mcpStates = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker fans MCPEvent values out to subscribers.
	mcpBroker = pubsub.NewBroker[MCPEvent]()
)
 89
// McpTool is a single tool exposed by an MCP server, together with what is
// needed to ask for permission and invoke it.
type McpTool struct {
	mcpName         string             // name of the MCP the tool was listed from
	tool            *mcp.Tool          // tool definition as returned by the server
	permissions     permission.Service // consulted by Run before executing
	workingDir      string             // sent as Path in the permission request
	providerOptions ai.ProviderOptions // opaque options stored via SetProviderOptions
}
 97
// SetProviderOptions stores opts on the tool, replacing any previous value.
func (m *McpTool) SetProviderOptions(opts ai.ProviderOptions) {
	m.providerOptions = opts
}
101
// ProviderOptions returns the options last stored with SetProviderOptions.
func (m *McpTool) ProviderOptions() ai.ProviderOptions {
	return m.providerOptions
}
105
106func (m *McpTool) Name() string {
107	return fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name)
108}
109
// MCP returns the name of the MCP this tool belongs to.
func (m *McpTool) MCP() string {
	return m.mcpName
}
113
// MCPToolName returns the tool's name as reported by the MCP server, without
// the "mcp_<mcp name>_" prefix added by Name.
func (m *McpTool) MCPToolName() string {
	return m.tool.Name
}
117
118func (b *McpTool) Info() ai.ToolInfo {
119	input := b.tool.InputSchema.(map[string]any)
120	required, _ := input["required"].([]string)
121	if required == nil {
122		required = make([]string, 0)
123	}
124	parameters, _ := input["properties"].(map[string]any)
125	if parameters == nil {
126		parameters = make(map[string]any)
127	}
128	return ai.ToolInfo{
129		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
130		Description: b.tool.Description,
131		Parameters:  parameters,
132		Required:    required,
133	}
134}
135
136func runTool(ctx context.Context, name, toolName string, input string) (ai.ToolResponse, error) {
137	var args map[string]any
138	if err := json.Unmarshal([]byte(input), &args); err != nil {
139		return ai.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
140	}
141
142	c, err := getOrRenewClient(ctx, name)
143	if err != nil {
144		return ai.NewTextErrorResponse(err.Error()), nil
145	}
146	result, err := c.CallTool(ctx, &mcp.CallToolParams{
147		Name:      toolName,
148		Arguments: args,
149	})
150	if err != nil {
151		return ai.NewTextErrorResponse(err.Error()), nil
152	}
153
154	output := make([]string, 0, len(result.Content))
155	for _, v := range result.Content {
156		if vv, ok := v.(*mcp.TextContent); ok {
157			output = append(output, vv.Text)
158		} else {
159			output = append(output, fmt.Sprintf("%v", v))
160		}
161	}
162	return ai.NewTextResponse(strings.Join(output, "\n")), nil
163}
164
165func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
166	sess, ok := mcpClients.Get(name)
167	if !ok {
168		return nil, fmt.Errorf("mcp '%s' not available", name)
169	}
170
171	cfg := config.Get()
172	m := cfg.MCP[name]
173	state, _ := mcpStates.Get(name)
174
175	timeout := mcpTimeout(m)
176	pingCtx, cancel := context.WithTimeout(ctx, timeout)
177	defer cancel()
178	err := sess.Ping(pingCtx, nil)
179	if err == nil {
180		return sess, nil
181	}
182	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
183
184	sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
185	if err != nil {
186		return nil, err
187	}
188
189	updateMCPState(name, MCPStateConnected, nil, sess, state.ToolCount)
190	mcpClients.Set(name, sess)
191	return sess, nil
192}
193
194func (m *McpTool) Run(ctx context.Context, params ai.ToolCall) (ai.ToolResponse, error) {
195	sessionID := GetSessionFromContext(ctx)
196	if sessionID == "" {
197		return ai.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
198	}
199	permissionDescription := fmt.Sprintf("execute %s with the following parameters:", m.Info().Name)
200	p := m.permissions.Request(
201		permission.CreatePermissionRequest{
202			SessionID:   sessionID,
203			ToolCallID:  params.ID,
204			Path:        m.workingDir,
205			ToolName:    m.Info().Name,
206			Action:      "execute",
207			Description: permissionDescription,
208			Params:      params.Input,
209		},
210	)
211	if !p {
212		return ai.ToolResponse{}, permission.ErrorPermissionDenied
213	}
214
215	return runTool(ctx, m.mcpName, m.tool.Name, params.Input)
216}
217
218func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]*McpTool, error) {
219	result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
220	if err != nil {
221		return nil, err
222	}
223	mcpTools := make([]*McpTool, 0, len(result.Tools))
224	for _, tool := range result.Tools {
225		mcpTools = append(mcpTools, &McpTool{
226			mcpName:     name,
227			tool:        tool,
228			permissions: permissions,
229			workingDir:  workingDir,
230		})
231	}
232	return mcpTools, nil
233}
234
// SubscribeMCPEvents subscribes to the MCP broker with ctx and returns the
// resulting receive-only channel of MCP events.
func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
	return mcpBroker.Subscribe(ctx)
}
239
// GetMCPStates returns the recorded state of every MCP client, keyed by MCP
// name. The result is a newly built map collected from mcpStates.
func GetMCPStates() map[string]MCPClientInfo {
	return maps.Collect(mcpStates.Seq2())
}
244
// GetMCPState returns the recorded state of the MCP client called name. The
// boolean is the lookup result from mcpStates.
func GetMCPState(name string) (MCPClientInfo, bool) {
	return mcpStates.Get(name)
}
249
250// updateMCPState updates the state of an MCP client and publishes an event
251func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, toolCount int) {
252	info := MCPClientInfo{
253		Name:      name,
254		State:     state,
255		Error:     err,
256		Client:    client,
257		ToolCount: toolCount,
258	}
259	switch state {
260	case MCPStateConnected:
261		info.ConnectedAt = time.Now()
262	case MCPStateError:
263		updateMcpTools(name, nil)
264		mcpClients.Del(name)
265	}
266	mcpStates.Set(name, info)
267
268	// Publish state change event
269	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
270		Type:      MCPEventStateChanged,
271		Name:      name,
272		State:     state,
273		Error:     err,
274		ToolCount: toolCount,
275	})
276}
277
278// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
279func CloseMCPClients() error {
280	var errs []error
281	for name, c := range mcpClients.Seq2() {
282		if err := c.Close(); err != nil &&
283			!errors.Is(err, io.EOF) &&
284			!errors.Is(err, context.Canceled) &&
285			err.Error() != "signal: killed" {
286			errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
287		}
288	}
289	mcpBroker.Shutdown()
290	return errors.Join(errs...)
291}
292
293func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []*McpTool {
294	mcpToolsOnce.Do(func() {
295		var wg sync.WaitGroup
296		// Initialize states for all configured MCPs
297		for name, m := range cfg.MCP {
298			if m.Disabled {
299				updateMCPState(name, MCPStateDisabled, nil, nil, 0)
300				slog.Debug("skipping disabled mcp", "name", name)
301				continue
302			}
303
304			// Set initial starting state
305			updateMCPState(name, MCPStateStarting, nil, nil, 0)
306
307			wg.Add(1)
308			go func(name string, m config.MCPConfig) {
309				defer func() {
310					wg.Done()
311					if r := recover(); r != nil {
312						var err error
313						switch v := r.(type) {
314						case error:
315							err = v
316						case string:
317							err = fmt.Errorf("panic: %s", v)
318						default:
319							err = fmt.Errorf("panic: %v", v)
320						}
321						updateMCPState(name, MCPStateError, err, nil, 0)
322						slog.Error("panic in mcp client initialization", "error", err, "name", name)
323					}
324				}()
325				c, err := createMCPSession(ctx, name, m, cfg.Resolver())
326				if err != nil {
327					return
328				}
329
330				mcpClients.Set(name, c)
331
332				tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
333				if err != nil {
334					slog.Error("error listing tools", "error", err)
335					updateMCPState(name, MCPStateError, err, nil, 0)
336					c.Close()
337					return
338				}
339
340				updateMcpTools(name, tools)
341				mcpClients.Set(name, c)
342				updateMCPState(name, MCPStateConnected, nil, c, len(tools))
343			}(name, m)
344		}
345		wg.Wait()
346	})
347	return slices.Collect(mcpTools.Seq())
348}
349
350// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
351func updateMcpTools(mcpName string, tools []*McpTool) {
352	if len(tools) == 0 {
353		mcpClient2Tools.Del(mcpName)
354	} else {
355		mcpClient2Tools.Set(mcpName, tools)
356	}
357	for _, tools := range mcpClient2Tools.Seq2() {
358		for _, t := range tools {
359			mcpTools.Set(t.Name(), t)
360		}
361	}
362}
363
364func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
365	timeout := mcpTimeout(m)
366	mcpCtx, cancel := context.WithCancel(ctx)
367	cancelTimer := time.AfterFunc(timeout, cancel)
368
369	transport, err := createMCPTransport(mcpCtx, m, resolver)
370	if err != nil {
371		updateMCPState(name, MCPStateError, err, nil, 0)
372		slog.Error("error creating mcp client", "error", err, "name", name)
373		return nil, err
374	}
375
376	client := mcp.NewClient(
377		&mcp.Implementation{
378			Name:    "crush",
379			Version: version.Version,
380			Title:   "Crush",
381		},
382		&mcp.ClientOptions{
383			ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
384				mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
385					Type: MCPEventToolsListChanged,
386					Name: name,
387				})
388			},
389			KeepAlive: time.Minute * 10,
390		},
391	)
392
393	session, err := client.Connect(mcpCtx, transport, nil)
394	if err != nil {
395		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
396		slog.Error("error starting mcp client", "error", err, "name", name)
397		cancel()
398		return nil, err
399	}
400
401	cancelTimer.Stop()
402	slog.Info("Initialized mcp client", "name", name)
403	return session, nil
404}
405
// maybeTimeoutErr rewrites a cancellation into a timeout message.
//
// When err is, or wraps, context.Canceled, a new error reading
// "timed out after <timeout>" is returned; it does not wrap err. Any other
// value, including nil, is returned unchanged.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if !errors.Is(err, context.Canceled) {
		return err
	}
	return fmt.Errorf("timed out after %s", timeout)
}
412
413func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
414	switch m.Type {
415	case config.MCPStdio:
416		command, err := resolver.ResolveValue(m.Command)
417		if err != nil {
418			return nil, fmt.Errorf("invalid mcp command: %w", err)
419		}
420		if strings.TrimSpace(command) == "" {
421			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
422		}
423		cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
424		cmd.Env = m.ResolvedEnv()
425		return &mcp.CommandTransport{
426			Command: cmd,
427		}, nil
428	case config.MCPHttp:
429		if strings.TrimSpace(m.URL) == "" {
430			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
431		}
432		client := &http.Client{
433			Transport: &headerRoundTripper{
434				headers: m.ResolvedHeaders(),
435			},
436		}
437		return &mcp.StreamableClientTransport{
438			Endpoint:   m.URL,
439			HTTPClient: client,
440		}, nil
441	case config.MCPSSE:
442		if strings.TrimSpace(m.URL) == "" {
443			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
444		}
445		client := &http.Client{
446			Transport: &headerRoundTripper{
447				headers: m.ResolvedHeaders(),
448			},
449		}
450		return &mcp.SSEClientTransport{
451			Endpoint:   m.URL,
452			HTTPClient: client,
453		}, nil
454	default:
455		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
456	}
457}
458
// headerRoundTripper is an http.RoundTripper that adds a fixed set of headers
// to every request and sends it through http.DefaultTransport.
type headerRoundTripper struct {
	headers map[string]string // header name -> value; overrides same-named request headers
}

// RoundTrip sends req with the configured headers applied.
//
// The headers are set on a clone of req, so the caller's request is left
// untouched. A RoundTripper must not modify the request it is given, and a
// caller may reuse or retry the same request.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) > 0 {
		req = req.Clone(req.Context())
		if req.Header == nil {
			req.Header = make(http.Header, len(rt.headers))
		}
		for k, v := range rt.headers {
			req.Header.Set(k, v)
		}
	}
	return http.DefaultTransport.RoundTrip(req)
}
469
470func mcpTimeout(m config.MCPConfig) time.Duration {
471	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
472}