mcp-tools.go

  1package agent
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"log/slog"
 10	"maps"
 11	"slices"
 12	"strings"
 13	"sync"
 14	"time"
 15
 16	"github.com/charmbracelet/crush/internal/config"
 17	"github.com/charmbracelet/crush/internal/csync"
 18	"github.com/charmbracelet/crush/internal/llm/tools"
 19	"github.com/charmbracelet/crush/internal/permission"
 20	"github.com/charmbracelet/crush/internal/pubsub"
 21	"github.com/charmbracelet/crush/internal/version"
 22	"github.com/mark3labs/mcp-go/client"
 23	"github.com/mark3labs/mcp-go/client/transport"
 24	"github.com/mark3labs/mcp-go/mcp"
 25)
 26
 27// MCPState represents the current state of an MCP client
 28type MCPState int
 29
 30const (
 31	MCPStateDisabled MCPState = iota
 32	MCPStateStarting
 33	MCPStateConnected
 34	MCPStateError
 35)
 36
 37func (s MCPState) MarshalText() ([]byte, error) {
 38	return []byte(s.String()), nil
 39}
 40
 41func (s *MCPState) UnmarshalText(data []byte) error {
 42	switch string(data) {
 43	case "disabled":
 44		*s = MCPStateDisabled
 45	case "starting":
 46		*s = MCPStateStarting
 47	case "connected":
 48		*s = MCPStateConnected
 49	case "error":
 50		*s = MCPStateError
 51	default:
 52		return fmt.Errorf("unknown mcp state: %s", data)
 53	}
 54	return nil
 55}
 56
 57func (s MCPState) String() string {
 58	switch s {
 59	case MCPStateDisabled:
 60		return "disabled"
 61	case MCPStateStarting:
 62		return "starting"
 63	case MCPStateConnected:
 64		return "connected"
 65	case MCPStateError:
 66		return "error"
 67	default:
 68		return "unknown"
 69	}
 70}
 71
 72// MCPEventType represents the type of MCP event
 73type MCPEventType string
 74
 75const (
 76	MCPEventStateChanged MCPEventType = "state_changed"
 77)
 78
 79func (t MCPEventType) MarshalText() ([]byte, error) {
 80	return []byte(t), nil
 81}
 82
 83func (t *MCPEventType) UnmarshalText(data []byte) error {
 84	*t = MCPEventType(data)
 85	return nil
 86}
 87
 88// MCPEvent represents an event in the MCP system
 89type MCPEvent struct {
 90	Type      MCPEventType `json:"type"`
 91	Name      string       `json:"name"`
 92	State     MCPState     `json:"state"`
 93	Error     error        `json:"error,omitempty"`
 94	ToolCount int          `json:"tool_count,omitempty"`
 95}
 96
 97// MCPClientInfo holds information about an MCP client's state
 98type MCPClientInfo struct {
 99	Name        string
100	State       MCPState
101	Error       error
102	Client      *client.Client
103	ToolCount   int
104	ConnectedAt time.Time
105}
106
107var (
108	mcpToolsOnce sync.Once
109	mcpTools     []tools.BaseTool
110	mcpClients   = csync.NewMap[string, *client.Client]()
111	mcpStates    = csync.NewMap[string, MCPClientInfo]()
112	mcpBroker    = pubsub.NewBroker[MCPEvent]()
113)
114
115type McpTool struct {
116	mcpName     string
117	tool        mcp.Tool
118	permissions permission.Service
119	workingDir  string
120}
121
122func (b *McpTool) Name() string {
123	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
124}
125
126func (b *McpTool) Info() tools.ToolInfo {
127	required := b.tool.InputSchema.Required
128	if required == nil {
129		required = make([]string, 0)
130	}
131	parameters := b.tool.InputSchema.Properties
132	if parameters == nil {
133		parameters = make(map[string]any)
134	}
135	return tools.ToolInfo{
136		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
137		Description: b.tool.Description,
138		Parameters:  parameters,
139		Required:    required,
140	}
141}
142
143func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
144	var args map[string]any
145	if err := json.Unmarshal([]byte(input), &args); err != nil {
146		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
147	}
148
149	c, err := getOrRenewClient(ctx, name)
150	if err != nil {
151		return tools.NewTextErrorResponse(err.Error()), nil
152	}
153	result, err := c.CallTool(ctx, mcp.CallToolRequest{
154		Params: mcp.CallToolParams{
155			Name:      toolName,
156			Arguments: args,
157		},
158	})
159	if err != nil {
160		return tools.NewTextErrorResponse(err.Error()), nil
161	}
162
163	output := make([]string, 0, len(result.Content))
164	for _, v := range result.Content {
165		if v, ok := v.(mcp.TextContent); ok {
166			output = append(output, v.Text)
167		} else {
168			output = append(output, fmt.Sprintf("%v", v))
169		}
170	}
171	return tools.NewTextResponse(strings.Join(output, "\n")), nil
172}
173
174func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
175	c, ok := mcpClients.Get(name)
176	if !ok {
177		return nil, fmt.Errorf("mcp '%s' not available", name)
178	}
179
180	m := config.Get().MCP[name]
181	state, _ := mcpStates.Get(name)
182
183	pingCtx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
184	defer cancel()
185	err := c.Ping(pingCtx)
186	if err == nil {
187		return c, nil
188	}
189	updateMCPState(name, MCPStateError, err, nil, state.ToolCount)
190
191	c, err = createAndInitializeClient(ctx, name, m)
192	if err != nil {
193		return nil, err
194	}
195
196	updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
197	mcpClients.Set(name, c)
198	return c, nil
199}
200
201func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
202	sessionID, messageID := tools.GetContextValues(ctx)
203	if sessionID == "" || messageID == "" {
204		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
205	}
206	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
207	p := b.permissions.Request(
208		permission.CreatePermissionRequest{
209			SessionID:   sessionID,
210			ToolCallID:  params.ID,
211			Path:        b.workingDir,
212			ToolName:    b.Info().Name,
213			Action:      "execute",
214			Description: permissionDescription,
215			Params:      params.Input,
216		},
217	)
218	if !p {
219		return tools.ToolResponse{}, permission.ErrorPermissionDenied
220	}
221
222	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
223}
224
225func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
226	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
227	if err != nil {
228		slog.Error("error listing tools", "error", err)
229		updateMCPState(name, MCPStateError, err, nil, 0)
230		c.Close()
231		mcpClients.Del(name)
232		return nil
233	}
234	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
235	for _, tool := range result.Tools {
236		mcpTools = append(mcpTools, &McpTool{
237			mcpName:     name,
238			tool:        tool,
239			permissions: permissions,
240			workingDir:  workingDir,
241		})
242	}
243	return mcpTools
244}
245
246// SubscribeMCPEvents returns a channel for MCP events
247func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
248	return mcpBroker.Subscribe(ctx)
249}
250
251// GetMCPStates returns the current state of all MCP clients
252func GetMCPStates() map[string]MCPClientInfo {
253	return maps.Collect(mcpStates.Seq2())
254}
255
256// GetMCPState returns the state of a specific MCP client
257func GetMCPState(name string) (MCPClientInfo, bool) {
258	return mcpStates.Get(name)
259}
260
261// updateMCPState updates the state of an MCP client and publishes an event
262func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
263	info := MCPClientInfo{
264		Name:      name,
265		State:     state,
266		Error:     err,
267		Client:    client,
268		ToolCount: toolCount,
269	}
270	if state == MCPStateConnected {
271		info.ConnectedAt = time.Now()
272	}
273	mcpStates.Set(name, info)
274
275	// Publish state change event
276	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
277		Type:      MCPEventStateChanged,
278		Name:      name,
279		State:     state,
280		Error:     err,
281		ToolCount: toolCount,
282	})
283}
284
285// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
286func CloseMCPClients() error {
287	var errs []error
288	for c := range mcpClients.Seq() {
289		if err := c.Close(); err != nil {
290			errs = append(errs, err)
291		}
292	}
293	mcpBroker.Shutdown()
294	return errors.Join(errs...)
295}
296
297var mcpInitRequest = mcp.InitializeRequest{
298	Params: mcp.InitializeParams{
299		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
300		ClientInfo: mcp.Implementation{
301			Name:    "Crush",
302			Version: version.Version,
303		},
304	},
305}
306
307func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
308	var wg sync.WaitGroup
309	result := csync.NewSlice[tools.BaseTool]()
310
311	// Initialize states for all configured MCPs
312	for name, m := range cfg.MCP {
313		if m.Disabled {
314			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
315			slog.Debug("skipping disabled mcp", "name", name)
316			continue
317		}
318
319		// Set initial starting state
320		updateMCPState(name, MCPStateStarting, nil, nil, 0)
321
322		wg.Add(1)
323		go func(name string, m config.MCPConfig) {
324			defer func() {
325				wg.Done()
326				if r := recover(); r != nil {
327					var err error
328					switch v := r.(type) {
329					case error:
330						err = v
331					case string:
332						err = fmt.Errorf("panic: %s", v)
333					default:
334						err = fmt.Errorf("panic: %v", v)
335					}
336					updateMCPState(name, MCPStateError, err, nil, 0)
337					slog.Error("panic in mcp client initialization", "error", err, "name", name)
338				}
339			}()
340
341			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
342			defer cancel()
343			c, err := createAndInitializeClient(ctx, name, m)
344			if err != nil {
345				return
346			}
347			mcpClients.Set(name, c)
348
349			tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
350			updateMCPState(name, MCPStateConnected, nil, c, len(tools))
351			result.Append(tools...)
352		}(name, m)
353	}
354	wg.Wait()
355	return slices.Collect(result.Seq())
356}
357
358func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig) (*client.Client, error) {
359	c, err := createMcpClient(name, m)
360	if err != nil {
361		updateMCPState(name, MCPStateError, err, nil, 0)
362		slog.Error("error creating mcp client", "error", err, "name", name)
363		return nil, err
364	}
365	// Only call Start() for non-stdio clients, as stdio clients auto-start
366	if m.Type != config.MCPStdio {
367		if err := c.Start(ctx); err != nil {
368			updateMCPState(name, MCPStateError, err, nil, 0)
369			slog.Error("error starting mcp client", "error", err, "name", name)
370			_ = c.Close()
371			return nil, err
372		}
373	}
374	if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
375		updateMCPState(name, MCPStateError, err, nil, 0)
376		slog.Error("error initializing mcp client", "error", err, "name", name)
377		_ = c.Close()
378		return nil, err
379	}
380
381	slog.Info("Initialized mcp client", "name", name)
382	return c, nil
383}
384
385func createMcpClient(name string, m config.MCPConfig) (*client.Client, error) {
386	switch m.Type {
387	case config.MCPStdio:
388		if strings.TrimSpace(m.Command) == "" {
389			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
390		}
391		return client.NewStdioMCPClientWithOptions(
392			m.Command,
393			m.ResolvedEnv(),
394			m.Args,
395			transport.WithCommandLogger(mcpLogger{name: name}),
396		)
397	case config.MCPHttp:
398		if strings.TrimSpace(m.URL) == "" {
399			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
400		}
401		return client.NewStreamableHttpClient(
402			m.URL,
403			transport.WithHTTPHeaders(m.ResolvedHeaders()),
404			transport.WithHTTPLogger(mcpLogger{name: name}),
405		)
406	case config.MCPSse:
407		if strings.TrimSpace(m.URL) == "" {
408			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
409		}
410		return client.NewSSEMCPClient(
411			m.URL,
412			client.WithHeaders(m.ResolvedHeaders()),
413			transport.WithSSELogger(mcpLogger{name: name}),
414		)
415	default:
416		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
417	}
418}
419
// mcpLogger is the logger handed to the MCP client transports. It forwards
// formatted messages to the default slog logger and tags each record with the
// MCP server name.
type mcpLogger struct{ name string }

// Errorf logs the formatted message at error level.
func (l mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg, slog.String("name", l.name))
}

// Infof logs the formatted message at info level.
func (l mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg, slog.String("name", l.name))
}
430
431func mcpTimeout(m config.MCPConfig) time.Duration {
432	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
433}