1package tools
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"io"
 10	"log/slog"
 11	"maps"
 12	"net/http"
 13	"os"
 14	"os/exec"
 15	"slices"
 16	"strings"
 17	"sync"
 18	"time"
 19
 20	"charm.land/fantasy"
 21	"github.com/charmbracelet/crush/internal/config"
 22	"github.com/charmbracelet/crush/internal/csync"
 23	"github.com/charmbracelet/crush/internal/home"
 24	"github.com/charmbracelet/crush/internal/permission"
 25	"github.com/charmbracelet/crush/internal/pubsub"
 26	"github.com/charmbracelet/crush/internal/version"
 27	"github.com/modelcontextprotocol/go-sdk/mcp"
 28)
 29
 30// MCPState represents the current state of an MCP client
 31type MCPState int
 32
 33const (
 34	MCPStateDisabled MCPState = iota
 35	MCPStateStarting
 36	MCPStateConnected
 37	MCPStateError
 38)
 39
 40func (s MCPState) String() string {
 41	switch s {
 42	case MCPStateDisabled:
 43		return "disabled"
 44	case MCPStateStarting:
 45		return "starting"
 46	case MCPStateConnected:
 47		return "connected"
 48	case MCPStateError:
 49		return "error"
 50	default:
 51		return "unknown"
 52	}
 53}
 54
// MCPEventType identifies the kind of change an MCPEvent reports.
type MCPEventType string

// MCPEventStateChanged is published by updateMCPState whenever a client's
// MCPState is recorded. MCPEventToolsListChanged is published when a
// connected server notifies the client that its tool list changed.
const (
	MCPEventStateChanged     MCPEventType = "state_changed"
	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
)

// MCPEvent is the payload delivered to SubscribeMCPEvents subscribers.
//
// Name is the MCP server the event concerns. For MCPEventStateChanged events,
// State, Error and ToolCount mirror the values just recorded for that server.
// MCPEventToolsListChanged events only set Type and Name.
type MCPEvent struct {
	Type      MCPEventType
	Name      string
	State     MCPState
	Error     error
	ToolCount int
}
 71
// MCPClientInfo is the last recorded state of one MCP client, as returned by
// GetMCPStates and GetMCPState.
//
// Name is the server name from the configuration. Error holds the failure
// that caused MCPStateError, if any. Client is the live session when the
// state was recorded with one (it is nil for the starting, disabled and error
// states set in this file). ToolCount is the number of tools the server
// exposed. ConnectedAt is stamped only when the state is MCPStateConnected
// and is the zero time otherwise.
type MCPClientInfo struct {
	Name        string
	State       MCPState
	Error       error
	Client      *mcp.ClientSession
	ToolCount   int
	ConnectedAt time.Time
}
 81
// Package-level MCP registry state:
//   - mcpToolsOnce guards the one-time client initialization in GetMCPTools.
//   - mcpTools maps a tool's namespaced name (McpTool.Name) to the tool.
//   - mcpClient2Tools maps an MCP server name to the tools it exposes.
//   - mcpClients maps an MCP server name to its current client session.
//   - mcpStates maps an MCP server name to its latest MCPClientInfo.
//   - mcpBroker fans MCPEvent notifications out to SubscribeMCPEvents callers.
var (
	mcpToolsOnce    sync.Once
	mcpTools        = csync.NewMap[string, *McpTool]()
	mcpClient2Tools = csync.NewMap[string, []*McpTool]()
	mcpClients      = csync.NewMap[string, *mcp.ClientSession]()
	mcpStates       = csync.NewMap[string, MCPClientInfo]()
	mcpBroker       = pubsub.NewBroker[MCPEvent]()
)
 90
// McpTool exposes one tool of an MCP server as a fantasy tool.
//
// mcpName is the configured server name and tool is the definition the
// server returned from tools/list. permissions is consulted before every call
// (see Run), and workingDir is reported as the path of that permission
// request. providerOptions is carried opaquely via SetProviderOptions and
// ProviderOptions.
type McpTool struct {
	mcpName         string
	tool            *mcp.Tool
	permissions     permission.Service
	workingDir      string
	providerOptions fantasy.ProviderOptions
}
 98
 99func (m *McpTool) SetProviderOptions(opts fantasy.ProviderOptions) {
100	m.providerOptions = opts
101}
102
103func (m *McpTool) ProviderOptions() fantasy.ProviderOptions {
104	return m.providerOptions
105}
106
107func (m *McpTool) Name() string {
108	return fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name)
109}
110
111func (m *McpTool) MCP() string {
112	return m.mcpName
113}
114
115func (m *McpTool) MCPToolName() string {
116	return m.tool.Name
117}
118
119func (b *McpTool) Info() fantasy.ToolInfo {
120	parameters := make(map[string]any)
121	required := make([]string, 0)
122
123	if input, ok := b.tool.InputSchema.(map[string]any); ok {
124		if props, ok := input["properties"].(map[string]any); ok {
125			parameters = props
126		}
127		if req, ok := input["required"].([]any); ok {
128			// Convert []any -> []string when elements are strings
129			for _, v := range req {
130				if s, ok := v.(string); ok {
131					required = append(required, s)
132				}
133			}
134		} else if reqStr, ok := input["required"].([]string); ok {
135			// Handle case where it's already []string
136			required = reqStr
137		}
138	}
139
140	return fantasy.ToolInfo{
141		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
142		Description: b.tool.Description,
143		Parameters:  parameters,
144		Required:    required,
145	}
146}
147
148func runTool(ctx context.Context, name, toolName string, input string) (fantasy.ToolResponse, error) {
149	var args map[string]any
150	if err := json.Unmarshal([]byte(input), &args); err != nil {
151		return fantasy.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
152	}
153
154	c, err := getOrRenewClient(ctx, name)
155	if err != nil {
156		return fantasy.NewTextErrorResponse(err.Error()), nil
157	}
158	result, err := c.CallTool(ctx, &mcp.CallToolParams{
159		Name:      toolName,
160		Arguments: args,
161	})
162	if err != nil {
163		return fantasy.NewTextErrorResponse(err.Error()), nil
164	}
165
166	output := make([]string, 0, len(result.Content))
167	for _, v := range result.Content {
168		if vv, ok := v.(*mcp.TextContent); ok {
169			output = append(output, vv.Text)
170		} else {
171			output = append(output, fmt.Sprintf("%v", v))
172		}
173	}
174	return fantasy.NewTextResponse(strings.Join(output, "\n")), nil
175}
176
177func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
178	sess, ok := mcpClients.Get(name)
179	if !ok {
180		return nil, fmt.Errorf("mcp '%s' not available", name)
181	}
182
183	cfg := config.Get()
184	m := cfg.MCP[name]
185	state, _ := mcpStates.Get(name)
186
187	timeout := mcpTimeout(m)
188	pingCtx, cancel := context.WithTimeout(ctx, timeout)
189	defer cancel()
190	err := sess.Ping(pingCtx, nil)
191	if err == nil {
192		return sess, nil
193	}
194	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
195
196	sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
197	if err != nil {
198		return nil, err
199	}
200
201	updateMCPState(name, MCPStateConnected, nil, sess, state.ToolCount)
202	mcpClients.Set(name, sess)
203	return sess, nil
204}
205
206func (m *McpTool) Run(ctx context.Context, params fantasy.ToolCall) (fantasy.ToolResponse, error) {
207	sessionID := GetSessionFromContext(ctx)
208	if sessionID == "" {
209		return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
210	}
211	permissionDescription := fmt.Sprintf("execute %s with the following parameters:", m.Info().Name)
212	p := m.permissions.Request(
213		permission.CreatePermissionRequest{
214			SessionID:   sessionID,
215			ToolCallID:  params.ID,
216			Path:        m.workingDir,
217			ToolName:    m.Info().Name,
218			Action:      "execute",
219			Description: permissionDescription,
220			Params:      params.Input,
221		},
222	)
223	if !p {
224		return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
225	}
226
227	return runTool(ctx, m.mcpName, m.tool.Name, params.Input)
228}
229
230func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]*McpTool, error) {
231	result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
232	if err != nil {
233		return nil, err
234	}
235	mcpTools := make([]*McpTool, 0, len(result.Tools))
236	for _, tool := range result.Tools {
237		mcpTools = append(mcpTools, &McpTool{
238			mcpName:     name,
239			tool:        tool,
240			permissions: permissions,
241			workingDir:  workingDir,
242		})
243	}
244	return mcpTools, nil
245}
246
247// SubscribeMCPEvents returns a channel for MCP events
248func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
249	return mcpBroker.Subscribe(ctx)
250}
251
252// GetMCPStates returns the current state of all MCP clients
253func GetMCPStates() map[string]MCPClientInfo {
254	return maps.Collect(mcpStates.Seq2())
255}
256
257// GetMCPState returns the state of a specific MCP client
258func GetMCPState(name string) (MCPClientInfo, bool) {
259	return mcpStates.Get(name)
260}
261
262// updateMCPState updates the state of an MCP client and publishes an event
263func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, toolCount int) {
264	info := MCPClientInfo{
265		Name:      name,
266		State:     state,
267		Error:     err,
268		Client:    client,
269		ToolCount: toolCount,
270	}
271	switch state {
272	case MCPStateConnected:
273		info.ConnectedAt = time.Now()
274	case MCPStateError:
275		updateMcpTools(name, nil)
276		mcpClients.Del(name)
277	}
278	mcpStates.Set(name, info)
279
280	// Publish state change event
281	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
282		Type:      MCPEventStateChanged,
283		Name:      name,
284		State:     state,
285		Error:     err,
286		ToolCount: toolCount,
287	})
288}
289
290// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
291func CloseMCPClients() error {
292	var errs []error
293	for name, c := range mcpClients.Seq2() {
294		if err := c.Close(); err != nil &&
295			!errors.Is(err, io.EOF) &&
296			!errors.Is(err, context.Canceled) &&
297			err.Error() != "signal: killed" {
298			errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
299		}
300	}
301	mcpBroker.Shutdown()
302	return errors.Join(errs...)
303}
304
305func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []*McpTool {
306	mcpToolsOnce.Do(func() {
307		var wg sync.WaitGroup
308		// Initialize states for all configured MCPs
309		for name, m := range cfg.MCP {
310			if m.Disabled {
311				updateMCPState(name, MCPStateDisabled, nil, nil, 0)
312				slog.Debug("skipping disabled mcp", "name", name)
313				continue
314			}
315
316			// Set initial starting state
317			updateMCPState(name, MCPStateStarting, nil, nil, 0)
318
319			wg.Add(1)
320			go func(name string, m config.MCPConfig) {
321				defer func() {
322					wg.Done()
323					if r := recover(); r != nil {
324						var err error
325						switch v := r.(type) {
326						case error:
327							err = v
328						case string:
329							err = fmt.Errorf("panic: %s", v)
330						default:
331							err = fmt.Errorf("panic: %v", v)
332						}
333						updateMCPState(name, MCPStateError, err, nil, 0)
334						slog.Error("panic in mcp client initialization", "error", err, "name", name)
335					}
336				}()
337
338				ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
339				defer cancel()
340
341				c, err := createMCPSession(ctx, name, m, cfg.Resolver())
342				if err != nil {
343					return
344				}
345
346				mcpClients.Set(name, c)
347
348				tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
349				if err != nil {
350					slog.Error("error listing tools", "error", err)
351					updateMCPState(name, MCPStateError, err, nil, 0)
352					c.Close()
353					return
354				}
355
356				updateMcpTools(name, tools)
357				mcpClients.Set(name, c)
358				updateMCPState(name, MCPStateConnected, nil, c, len(tools))
359			}(name, m)
360		}
361		wg.Wait()
362	})
363	return slices.Collect(mcpTools.Seq())
364}
365
366// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
367func updateMcpTools(mcpName string, tools []*McpTool) {
368	if len(tools) == 0 {
369		mcpClient2Tools.Del(mcpName)
370	} else {
371		mcpClient2Tools.Set(mcpName, tools)
372	}
373	for _, tools := range mcpClient2Tools.Seq2() {
374		for _, t := range tools {
375			mcpTools.Set(t.Name(), t)
376		}
377	}
378}
379
380func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
381	timeout := mcpTimeout(m)
382	mcpCtx, cancel := context.WithCancel(ctx)
383	cancelTimer := time.AfterFunc(timeout, cancel)
384
385	transport, err := createMCPTransport(mcpCtx, m, resolver)
386	if err != nil {
387		updateMCPState(name, MCPStateError, err, nil, 0)
388		slog.Error("error creating mcp client", "error", err, "name", name)
389		cancel()
390		cancelTimer.Stop()
391		return nil, err
392	}
393
394	client := mcp.NewClient(
395		&mcp.Implementation{
396			Name:    "crush",
397			Version: version.Version,
398			Title:   "Crush",
399		},
400		&mcp.ClientOptions{
401			ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
402				mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
403					Type: MCPEventToolsListChanged,
404					Name: name,
405				})
406			},
407			KeepAlive: time.Minute * 10,
408		},
409	)
410
411	session, err := client.Connect(mcpCtx, transport, nil)
412	if err != nil {
413		err = maybeStdioErr(err, transport)
414		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
415		slog.Error("error starting mcp client", "error", err, "name", name)
416		cancel()
417		cancelTimer.Stop()
418		return nil, err
419	}
420
421	cancelTimer.Stop()
422	slog.Info("Initialized mcp client", "name", name)
423	return session, nil
424}
425
426// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
427// to parse, and the cli will then close it, causing the EOF error.
428// so, if we got an EOF err, and the transport is STDIO, we try to exec it
429// again with a timeout and collect the output so we can add details to the
430// error.
431// this happens particularly when starting things with npx, e.g. if node can't
432// be found or some other error like that.
433func maybeStdioErr(err error, transport mcp.Transport) error {
434	if !errors.Is(err, io.EOF) {
435		return err
436	}
437	ct, ok := transport.(*mcp.CommandTransport)
438	if !ok {
439		return err
440	}
441	if err2 := stdioMCPCheck(ct.Command); err2 != nil {
442		err = errors.Join(err, err2)
443	}
444	return err
445}
446
// maybeTimeoutErr rewrites a context cancellation or deadline error as a
// human-readable "timed out after <timeout>" error. Any other error, including
// nil, is returned unchanged.
//
// Both cases occur in this file. Connect attempts are aborted by a timer that
// cancels the context (context.Canceled). Health-check pings run under
// context.WithTimeout (context.DeadlineExceeded).
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
453
454func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
455	switch m.Type {
456	case config.MCPStdio:
457		command, err := resolver.ResolveValue(m.Command)
458		if err != nil {
459			return nil, fmt.Errorf("invalid mcp command: %w", err)
460		}
461		if strings.TrimSpace(command) == "" {
462			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
463		}
464		cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
465		cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
466		return &mcp.CommandTransport{
467			Command: cmd,
468		}, nil
469	case config.MCPHttp:
470		if strings.TrimSpace(m.URL) == "" {
471			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
472		}
473		client := &http.Client{
474			Transport: &headerRoundTripper{
475				headers: m.ResolvedHeaders(),
476			},
477		}
478		return &mcp.StreamableClientTransport{
479			Endpoint:   m.URL,
480			HTTPClient: client,
481		}, nil
482	case config.MCPSSE:
483		if strings.TrimSpace(m.URL) == "" {
484			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
485		}
486		client := &http.Client{
487			Transport: &headerRoundTripper{
488				headers: m.ResolvedHeaders(),
489			},
490		}
491		return &mcp.SSEClientTransport{
492			Endpoint:   m.URL,
493			HTTPClient: client,
494		}, nil
495	default:
496		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
497	}
498}
499
// headerRoundTripper is an http.RoundTripper that adds a fixed set of headers
// to every outgoing request before handing it to an underlying transport.
//
// headers are set on each request, replacing any existing values for the same
// keys. base performs the actual round trip; when nil, http.DefaultTransport
// is used, so existing literals that only set headers keep working.
type headerRoundTripper struct {
	headers map[string]string
	base    http.RoundTripper
}

var _ http.RoundTripper = headerRoundTripper{}

// RoundTrip implements http.RoundTripper. As that interface requires, the
// caller's request is not modified. The headers are applied to a clone.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	base := rt.base
	if base == nil {
		base = http.DefaultTransport
	}
	if len(rt.headers) == 0 {
		return base.RoundTrip(req)
	}

	// Clone deep-copies the header map, so setting headers below cannot leak
	// into the caller's request (or into retries that reuse it).
	out := req.Clone(req.Context())
	if out.Header == nil {
		out.Header = make(http.Header)
	}
	for k, v := range rt.headers {
		out.Header.Set(k, v)
	}
	return base.RoundTrip(out)
}
510
511func mcpTimeout(m config.MCPConfig) time.Duration {
512	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
513}
514
// stdioMCPCheck re-runs the command of a stdio MCP server that failed to
// start, so its output can be attached to the error.
//
// The command is run with the same argv, environment and working directory,
// limited to 5 seconds. It returns nil when the command succeeds or is still
// running at the deadline, since that gives no extra information. Otherwise
// it returns the run error wrapped with the command's combined stdout and
// stderr.
func stdioMCPCheck(old *exec.Cmd) error {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	cmd := exec.CommandContext(ctx, old.Path)
	// old.Args is the full argv, with the program name as element 0. Copy it
	// as-is rather than passing it as arguments, which would repeat the
	// program name as the first argument.
	cmd.Args = old.Args
	cmd.Env = old.Env
	cmd.Dir = old.Dir

	out, err := cmd.CombinedOutput()
	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return nil
	}
	return fmt.Errorf("%w: %s", err, string(out))
}