mcp-tools.go

  1package tools
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"log/slog"
 10	"maps"
 11	"slices"
 12	"strings"
 13	"sync"
 14	"time"
 15
 16	"github.com/charmbracelet/crush/internal/config"
 17	"github.com/charmbracelet/crush/internal/csync"
 18	"github.com/charmbracelet/crush/internal/home"
 19	"github.com/charmbracelet/crush/internal/permission"
 20	"github.com/charmbracelet/crush/internal/pubsub"
 21	"github.com/charmbracelet/crush/internal/version"
 22	"github.com/charmbracelet/fantasy/ai"
 23	"github.com/mark3labs/mcp-go/client"
 24	"github.com/mark3labs/mcp-go/client/transport"
 25	"github.com/mark3labs/mcp-go/mcp"
 26)
 27
 28// MCPState represents the current state of an MCP client
 29type MCPState int
 30
 31const (
 32	MCPStateDisabled MCPState = iota
 33	MCPStateStarting
 34	MCPStateConnected
 35	MCPStateError
 36)
 37
 38func (s MCPState) String() string {
 39	switch s {
 40	case MCPStateDisabled:
 41		return "disabled"
 42	case MCPStateStarting:
 43		return "starting"
 44	case MCPStateConnected:
 45		return "connected"
 46	case MCPStateError:
 47		return "error"
 48	default:
 49		return "unknown"
 50	}
 51}
 52
// MCPEventType identifies the kind of event carried by an MCPEvent.
type MCPEventType string

const (
	// MCPEventStateChanged is the type set on the events that updateMCPState
	// publishes whenever it records a new state for an MCP client.
	MCPEventStateChanged MCPEventType = "state_changed"
)
 59
// MCPEvent is the payload published on the MCP broker. updateMCPState fills
// it with the same name, state, error and tool count it stores in the
// corresponding MCPClientInfo.
type MCPEvent struct {
	// Type is the kind of event; MCPEventStateChanged is the only type
	// declared in this file.
	Type      MCPEventType
	// Name is the name of the MCP the event refers to.
	Name      string
	// State is the state the MCP was moved to.
	State     MCPState
	// Error is the error passed along with the state change; nil when none
	// was given.
	Error     error
	// ToolCount is the tool count reported alongside the state change.
	ToolCount int
}
 68
// MCPClientInfo is the per-MCP record kept in mcpStates. A fresh value is
// stored by every updateMCPState call.
type MCPClientInfo struct {
	// Name is the name of the MCP this record belongs to.
	Name        string
	// State is the most recently recorded state.
	State       MCPState
	// Error is the error recorded with the state; nil when none was given.
	Error       error
	// Client is the client passed to updateMCPState; callers in this file
	// pass nil for every state except MCPStateConnected.
	Client      *client.Client
	// ToolCount is the tool count passed to updateMCPState.
	ToolCount   int
	// ConnectedAt is set to the current time when the state is
	// MCPStateConnected and is the zero time otherwise.
	ConnectedAt time.Time
}
 78
// Package-level MCP bookkeeping shared by the functions in this file.
var (
	// mcpToolsOnce and mcpTools are not referenced in the part of the file
	// visible here. NOTE(review): presumably they back a once-only cache of
	// the collected tools — confirm against their users.
	mcpToolsOnce sync.Once
	mcpTools     []ai.AgentTool
	// mcpClients maps an MCP name to its initialized client.
	mcpClients   = csync.NewMap[string, *client.Client]()
	// mcpStates maps an MCP name to the last MCPClientInfo recorded by
	// updateMCPState.
	mcpStates    = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker distributes MCPEvent values to SubscribeMCPEvents subscribers
	// and is shut down by CloseMCPClients.
	mcpBroker    = pubsub.NewBroker[MCPEvent]()
)
 86
 87type McpTool struct {
 88	mcpName         string
 89	tool            mcp.Tool
 90	permissions     permission.Service
 91	workingDir      string
 92	providerOptions ai.ProviderOptions
 93}
 94
 95func (m *McpTool) SetProviderOptions(opts ai.ProviderOptions) {
 96	m.providerOptions = opts
 97}
 98
 99func (m *McpTool) ProviderOptions() ai.ProviderOptions {
100	return m.providerOptions
101}
102
103func (b *McpTool) Name() string {
104	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
105}
106
107func (b *McpTool) Info() ai.ToolInfo {
108	required := b.tool.InputSchema.Required
109	if required == nil {
110		required = make([]string, 0)
111	}
112	parameters := b.tool.InputSchema.Properties
113	if parameters == nil {
114		parameters = make(map[string]any)
115	}
116	return ai.ToolInfo{
117		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
118		Description: b.tool.Description,
119		Parameters:  parameters,
120		Required:    required,
121	}
122}
123
124func runTool(ctx context.Context, name, toolName string, input string) (ai.ToolResponse, error) {
125	var args map[string]any
126	if err := json.Unmarshal([]byte(input), &args); err != nil {
127		return ai.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
128	}
129
130	c, err := getOrRenewClient(ctx, name)
131	if err != nil {
132		return ai.NewTextErrorResponse(err.Error()), nil
133	}
134	result, err := c.CallTool(ctx, mcp.CallToolRequest{
135		Params: mcp.CallToolParams{
136			Name:      toolName,
137			Arguments: args,
138		},
139	})
140	if err != nil {
141		return ai.NewTextErrorResponse(err.Error()), nil
142	}
143
144	output := make([]string, 0, len(result.Content))
145	for _, v := range result.Content {
146		if v, ok := v.(mcp.TextContent); ok {
147			output = append(output, v.Text)
148		} else {
149			output = append(output, fmt.Sprintf("%v", v))
150		}
151	}
152	return ai.NewTextResponse(strings.Join(output, "\n")), nil
153}
154
155func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
156	c, ok := mcpClients.Get(name)
157	if !ok {
158		return nil, fmt.Errorf("mcp '%s' not available", name)
159	}
160
161	cfg := config.Get()
162	m := cfg.MCP[name]
163	state, _ := mcpStates.Get(name)
164
165	timeout := mcpTimeout(m)
166	pingCtx, cancel := context.WithTimeout(ctx, timeout)
167	defer cancel()
168	err := c.Ping(pingCtx)
169	if err == nil {
170		return c, nil
171	}
172	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
173
174	c, err = createAndInitializeClient(ctx, name, m, cfg.Resolver())
175	if err != nil {
176		return nil, err
177	}
178
179	updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
180	mcpClients.Set(name, c)
181	return c, nil
182}
183
184func (b *McpTool) Run(ctx context.Context, params ai.ToolCall) (ai.ToolResponse, error) {
185	sessionID := GetSessionFromContext(ctx)
186	if sessionID == "" {
187		return ai.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
188	}
189	permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name)
190	p := b.permissions.Request(
191		permission.CreatePermissionRequest{
192			SessionID:   sessionID,
193			ToolCallID:  params.ID,
194			Path:        b.workingDir,
195			ToolName:    b.Info().Name,
196			Action:      "execute",
197			Description: permissionDescription,
198			Params:      params.Input,
199		},
200	)
201	if !p {
202		return ai.ToolResponse{}, permission.ErrorPermissionDenied
203	}
204
205	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
206}
207
208func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []ai.AgentTool {
209	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
210	if err != nil {
211		slog.Error("error listing tools", "error", err)
212		updateMCPState(name, MCPStateError, err, nil, 0)
213		c.Close()
214		mcpClients.Del(name)
215		return nil
216	}
217	mcpTools := make([]ai.AgentTool, 0, len(result.Tools))
218	for _, tool := range result.Tools {
219		mcpTools = append(mcpTools, &McpTool{
220			mcpName:     name,
221			tool:        tool,
222			permissions: permissions,
223			workingDir:  workingDir,
224		})
225	}
226	return mcpTools
227}
228
229// SubscribeMCPEvents returns a channel for MCP events
230func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
231	return mcpBroker.Subscribe(ctx)
232}
233
234// GetMCPStates returns the current state of all MCP clients
235func GetMCPStates() map[string]MCPClientInfo {
236	return maps.Collect(mcpStates.Seq2())
237}
238
239// GetMCPState returns the state of a specific MCP client
240func GetMCPState(name string) (MCPClientInfo, bool) {
241	return mcpStates.Get(name)
242}
243
244// updateMCPState updates the state of an MCP client and publishes an event
245func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
246	info := MCPClientInfo{
247		Name:      name,
248		State:     state,
249		Error:     err,
250		Client:    client,
251		ToolCount: toolCount,
252	}
253	if state == MCPStateConnected {
254		info.ConnectedAt = time.Now()
255	}
256	mcpStates.Set(name, info)
257
258	// Publish state change event
259	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
260		Type:      MCPEventStateChanged,
261		Name:      name,
262		State:     state,
263		Error:     err,
264		ToolCount: toolCount,
265	})
266}
267
268// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
269func CloseMCPClients() error {
270	var errs []error
271	for name, c := range mcpClients.Seq2() {
272		if err := c.Close(); err != nil {
273			errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
274		}
275	}
276	mcpBroker.Shutdown()
277	return errors.Join(errs...)
278}
279
// mcpInitRequest is the initialize request sent to every MCP by
// createAndInitializeClient. It advertises the latest protocol version known
// to the mcp package and identifies this client as "Crush" together with the
// application version.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
289
290func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []ai.AgentTool {
291	var wg sync.WaitGroup
292	result := csync.NewSlice[ai.AgentTool]()
293
294	// Initialize states for all configured MCPs
295	for name, m := range cfg.MCP {
296		if m.Disabled {
297			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
298			slog.Debug("skipping disabled mcp", "name", name)
299			continue
300		}
301
302		// Set initial starting state
303		updateMCPState(name, MCPStateStarting, nil, nil, 0)
304
305		wg.Add(1)
306		go func(name string, m config.MCPConfig) {
307			defer func() {
308				wg.Done()
309				if r := recover(); r != nil {
310					var err error
311					switch v := r.(type) {
312					case error:
313						err = v
314					case string:
315						err = fmt.Errorf("panic: %s", v)
316					default:
317						err = fmt.Errorf("panic: %v", v)
318					}
319					updateMCPState(name, MCPStateError, err, nil, 0)
320					slog.Error("panic in mcp client initialization", "error", err, "name", name)
321				}
322			}()
323
324			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
325			defer cancel()
326			c, err := createAndInitializeClient(ctx, name, m, cfg.Resolver())
327			if err != nil {
328				return
329			}
330			mcpClients.Set(name, c)
331
332			tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
333			updateMCPState(name, MCPStateConnected, nil, c, len(tools))
334			result.Append(tools...)
335		}(name, m)
336	}
337	wg.Wait()
338	return slices.Collect(result.Seq())
339}
340
341func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) {
342	c, err := createMcpClient(name, m, resolver)
343	if err != nil {
344		updateMCPState(name, MCPStateError, err, nil, 0)
345		slog.Error("error creating mcp client", "error", err, "name", name)
346		return nil, err
347	}
348
349	timeout := mcpTimeout(m)
350	initCtx, cancel := context.WithTimeout(ctx, timeout)
351	defer cancel()
352
353	// Only call Start() for non-stdio clients, as stdio clients auto-start
354	if m.Type != config.MCPStdio {
355		if err := c.Start(initCtx); err != nil {
356			updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
357			slog.Error("error starting mcp client", "error", err, "name", name)
358			_ = c.Close()
359			return nil, err
360		}
361	}
362	if _, err := c.Initialize(initCtx, mcpInitRequest); err != nil {
363		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
364		slog.Error("error initializing mcp client", "error", err, "name", name)
365		_ = c.Close()
366		return nil, err
367	}
368
369	slog.Info("Initialized mcp client", "name", name)
370	return c, nil
371}
372
// maybeTimeoutErr turns a deadline error into a readable timeout message.
// When err is, or wraps, context.DeadlineExceeded a new error reading
// "timed out after <timeout>" is returned; any other error, including nil,
// is returned unchanged.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if !errors.Is(err, context.DeadlineExceeded) {
		return err
	}
	return fmt.Errorf("timed out after %s", timeout)
}
379
380func createMcpClient(name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) {
381	switch m.Type {
382	case config.MCPStdio:
383		command, err := resolver.ResolveValue(m.Command)
384		if err != nil {
385			return nil, fmt.Errorf("invalid mcp command: %w", err)
386		}
387		if strings.TrimSpace(command) == "" {
388			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
389		}
390		return client.NewStdioMCPClientWithOptions(
391			home.Long(command),
392			m.ResolvedEnv(),
393			m.Args,
394			transport.WithCommandLogger(mcpLogger{name: name}),
395		)
396	case config.MCPHttp:
397		if strings.TrimSpace(m.URL) == "" {
398			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
399		}
400		return client.NewStreamableHttpClient(
401			m.URL,
402			transport.WithHTTPHeaders(m.ResolvedHeaders()),
403			transport.WithHTTPLogger(mcpLogger{name: name}),
404		)
405	case config.MCPSse:
406		if strings.TrimSpace(m.URL) == "" {
407			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
408		}
409		return client.NewSSEMCPClient(
410			m.URL,
411			client.WithHeaders(m.ResolvedHeaders()),
412			transport.WithSSELogger(mcpLogger{name: name}),
413		)
414	default:
415		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
416	}
417}
418
// mcpLogger is the logger handed to the MCP client transports. It forwards
// formatted messages to slog and tags each record with the MCP's name.
type mcpLogger struct{ name string }

// Errorf formats the message and logs it at error level.
func (l mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg, "name", l.name)
}

// Infof formats the message and logs it at info level.
func (l mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg, "name", l.name)
}
429
430func mcpTimeout(m config.MCPConfig) time.Duration {
431	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
432}