mcp-tools.go

  1package agent
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"maps"
 10	"slices"
 11	"strings"
 12	"sync"
 13	"time"
 14
 15	"github.com/charmbracelet/crush/internal/config"
 16	"github.com/charmbracelet/crush/internal/csync"
 17	"github.com/charmbracelet/crush/internal/llm/tools"
 18	"github.com/charmbracelet/crush/internal/permission"
 19	"github.com/charmbracelet/crush/internal/pubsub"
 20	"github.com/charmbracelet/crush/internal/version"
 21	"github.com/mark3labs/mcp-go/client"
 22	"github.com/mark3labs/mcp-go/client/transport"
 23	"github.com/mark3labs/mcp-go/mcp"
 24)
 25
 26// MCPState represents the current state of an MCP client
 27type MCPState int
 28
 29const (
 30	MCPStateDisabled MCPState = iota
 31	MCPStateStarting
 32	MCPStateConnected
 33	MCPStateError
 34)
 35
 36func (s MCPState) String() string {
 37	switch s {
 38	case MCPStateDisabled:
 39		return "disabled"
 40	case MCPStateStarting:
 41		return "starting"
 42	case MCPStateConnected:
 43		return "connected"
 44	case MCPStateError:
 45		return "error"
 46	default:
 47		return "unknown"
 48	}
 49}
 50
// MCPEventType identifies the kind of an MCPEvent.
type MCPEventType string

const (
	// MCPEventStateChanged is the type of the event published whenever the
	// recorded state of an MCP client is updated.
	MCPEventStateChanged MCPEventType = "state_changed"
)
 57
// MCPEvent represents an event in the MCP system. It is the payload published
// on the MCP event broker each time a client's state is updated.
type MCPEvent struct {
	// Type is the kind of event.
	Type      MCPEventType
	// Name is the name of the MCP the event is about.
	Name      string
	// State is the state the MCP was just set to.
	State     MCPState
	// Error is the error that accompanied the state update, if any.
	Error     error
	// ToolCount is the tool count recorded with the state update.
	ToolCount int
}
 66
// MCPClientInfo holds information about an MCP client's state. One value is
// stored per MCP name and replaced as a whole on every state update.
type MCPClientInfo struct {
	// Name is the name of the MCP this entry describes.
	Name        string
	// State is the most recently recorded state.
	State       MCPState
	// Error is the error recorded with the state, or nil.
	Error       error
	// Client is the client passed along with the state update; the updates
	// visible in this file pass nil for every state except connected.
	Client      *client.Client
	// ToolCount is the number of tools recorded with the state update.
	ToolCount   int
	// ConnectedAt is set to the current time when the state is updated to
	// MCPStateConnected and is the zero time for every other state.
	ConnectedAt time.Time
}
 76
// Package-level MCP registries shared by every function in this file.
var (
	// mcpToolsOnce and mcpTools are not referenced in the code visible here;
	// presumably they guard and cache the tool list built by doGetMCPTools —
	// TODO confirm against their users.
	mcpToolsOnce sync.Once
	mcpTools     []tools.BaseTool
	// mcpClients maps an MCP name to its live client.
	mcpClients   = csync.NewMap[string, *client.Client]()
	// mcpStates maps an MCP name to the last state recorded for it.
	mcpStates    = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker publishes an MCPEvent for every state update.
	mcpBroker    = pubsub.NewBroker[MCPEvent]()
)
 84
// McpTool exposes a single tool offered by an MCP server as a tool that can be
// run after a permission check.
type McpTool struct {
	// mcpName is the name of the MCP the tool belongs to; it is used to look
	// up the client when the tool runs.
	mcpName     string
	// tool is the tool definition as listed by the MCP server.
	tool        mcp.Tool
	// permissions is asked for approval before every run.
	permissions permission.Service
	// workingDir is reported as the path of the permission request.
	workingDir  string
}
 91
 92func (b *McpTool) Name() string {
 93	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 94}
 95
 96func (b *McpTool) Info() tools.ToolInfo {
 97	required := b.tool.InputSchema.Required
 98	if required == nil {
 99		required = make([]string, 0)
100	}
101	parameters := b.tool.InputSchema.Properties
102	if parameters == nil {
103		parameters = make(map[string]any)
104	}
105	return tools.ToolInfo{
106		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
107		Description: b.tool.Description,
108		Parameters:  parameters,
109		Required:    required,
110	}
111}
112
113func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
114	var args map[string]any
115	if err := json.Unmarshal([]byte(input), &args); err != nil {
116		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
117	}
118
119	c, err := getOrRenewClient(ctx, name)
120	if err != nil {
121		return tools.NewTextErrorResponse(err.Error()), nil
122	}
123	result, err := c.CallTool(ctx, mcp.CallToolRequest{
124		Params: mcp.CallToolParams{
125			Name:      toolName,
126			Arguments: args,
127		},
128	})
129	if err != nil {
130		return tools.NewTextErrorResponse(err.Error()), nil
131	}
132
133	output := make([]string, 0, len(result.Content))
134	for _, v := range result.Content {
135		if v, ok := v.(mcp.TextContent); ok {
136			output = append(output, v.Text)
137		} else {
138			output = append(output, fmt.Sprintf("%v", v))
139		}
140	}
141	return tools.NewTextResponse(strings.Join(output, "\n")), nil
142}
143
144func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
145	c, ok := mcpClients.Get(name)
146	if !ok {
147		return nil, fmt.Errorf("mcp '%s' not available", name)
148	}
149
150	m := config.Get().MCP[name]
151	state, _ := mcpStates.Get(name)
152
153	pingCtx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
154	defer cancel()
155	err := c.Ping(pingCtx)
156	if err == nil {
157		return c, nil
158	}
159	updateMCPState(name, MCPStateError, err, nil, state.ToolCount)
160
161	c, err = createAndInitializeClient(ctx, name, m)
162	if err != nil {
163		return nil, err
164	}
165
166	updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
167	mcpClients.Set(name, c)
168	return c, nil
169}
170
171func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
172	sessionID, messageID := tools.GetContextValues(ctx)
173	if sessionID == "" || messageID == "" {
174		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
175	}
176	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
177	p := b.permissions.Request(
178		permission.CreatePermissionRequest{
179			SessionID:   sessionID,
180			ToolCallID:  params.ID,
181			Path:        b.workingDir,
182			ToolName:    b.Info().Name,
183			Action:      "execute",
184			Description: permissionDescription,
185			Params:      params.Input,
186		},
187	)
188	if !p {
189		return tools.ToolResponse{}, permission.ErrorPermissionDenied
190	}
191
192	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
193}
194
195func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
196	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
197	if err != nil {
198		slog.Error("error listing tools", "error", err)
199		updateMCPState(name, MCPStateError, err, nil, 0)
200		c.Close()
201		mcpClients.Del(name)
202		return nil
203	}
204	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
205	for _, tool := range result.Tools {
206		mcpTools = append(mcpTools, &McpTool{
207			mcpName:     name,
208			tool:        tool,
209			permissions: permissions,
210			workingDir:  workingDir,
211		})
212	}
213	return mcpTools
214}
215
216// SubscribeMCPEvents returns a channel for MCP events
217func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
218	return mcpBroker.Subscribe(ctx)
219}
220
221// GetMCPStates returns the current state of all MCP clients
222func GetMCPStates() map[string]MCPClientInfo {
223	return maps.Collect(mcpStates.Seq2())
224}
225
226// GetMCPState returns the state of a specific MCP client
227func GetMCPState(name string) (MCPClientInfo, bool) {
228	return mcpStates.Get(name)
229}
230
231// updateMCPState updates the state of an MCP client and publishes an event
232func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
233	info := MCPClientInfo{
234		Name:      name,
235		State:     state,
236		Error:     err,
237		Client:    client,
238		ToolCount: toolCount,
239	}
240	if state == MCPStateConnected {
241		info.ConnectedAt = time.Now()
242	}
243	mcpStates.Set(name, info)
244
245	// Publish state change event
246	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
247		Type:      MCPEventStateChanged,
248		Name:      name,
249		State:     state,
250		Error:     err,
251		ToolCount: toolCount,
252	})
253}
254
255// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
256func CloseMCPClients() {
257	for c := range mcpClients.Seq() {
258		_ = c.Close()
259	}
260	mcpBroker.Shutdown()
261}
262
// mcpInitRequest is the initialize request sent to every MCP client right
// after it is created. It announces the mcp package's latest protocol version
// and identifies this client as "Crush" at the current build version.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
272
273func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
274	var wg sync.WaitGroup
275	result := csync.NewSlice[tools.BaseTool]()
276
277	// Initialize states for all configured MCPs
278	for name, m := range cfg.MCP {
279		if m.Disabled {
280			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
281			slog.Debug("skipping disabled mcp", "name", name)
282			continue
283		}
284
285		// Set initial starting state
286		updateMCPState(name, MCPStateStarting, nil, nil, 0)
287
288		wg.Add(1)
289		go func(name string, m config.MCPConfig) {
290			defer func() {
291				wg.Done()
292				if r := recover(); r != nil {
293					var err error
294					switch v := r.(type) {
295					case error:
296						err = v
297					case string:
298						err = fmt.Errorf("panic: %s", v)
299					default:
300						err = fmt.Errorf("panic: %v", v)
301					}
302					updateMCPState(name, MCPStateError, err, nil, 0)
303					slog.Error("panic in mcp client initialization", "error", err, "name", name)
304				}
305			}()
306
307			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
308			defer cancel()
309			c, err := createAndInitializeClient(ctx, name, m)
310			if err != nil {
311				return
312			}
313			mcpClients.Set(name, c)
314
315			tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
316			updateMCPState(name, MCPStateConnected, nil, c, len(tools))
317			result.Append(tools...)
318		}(name, m)
319	}
320	wg.Wait()
321	return slices.Collect(result.Seq())
322}
323
324func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig) (*client.Client, error) {
325	c, err := createMcpClient(m)
326	if err != nil {
327		updateMCPState(name, MCPStateError, err, nil, 0)
328		slog.Error("error creating mcp client", "error", err, "name", name)
329		return nil, err
330	}
331	// Only call Start() for non-stdio clients, as stdio clients auto-start
332	if m.Type != config.MCPStdio {
333		if err := c.Start(ctx); err != nil {
334			updateMCPState(name, MCPStateError, err, nil, 0)
335			slog.Error("error starting mcp client", "error", err, "name", name)
336			_ = c.Close()
337			return nil, err
338		}
339	}
340	if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
341		updateMCPState(name, MCPStateError, err, nil, 0)
342		slog.Error("error initializing mcp client", "error", err, "name", name)
343		_ = c.Close()
344		return nil, err
345	}
346
347	slog.Info("Initialized mcp client", "name", name)
348	return c, nil
349}
350
351func createMcpClient(m config.MCPConfig) (*client.Client, error) {
352	switch m.Type {
353	case config.MCPStdio:
354		if strings.TrimSpace(m.Command) == "" {
355			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
356		}
357		return client.NewStdioMCPClientWithOptions(
358			m.Command,
359			m.ResolvedEnv(),
360			m.Args,
361			transport.WithCommandLogger(mcpLogger{}),
362		)
363	case config.MCPHttp:
364		if strings.TrimSpace(m.URL) == "" {
365			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
366		}
367		return client.NewStreamableHttpClient(
368			m.URL,
369			transport.WithHTTPHeaders(m.ResolvedHeaders()),
370			transport.WithHTTPLogger(mcpLogger{}),
371		)
372	case config.MCPSse:
373		if strings.TrimSpace(m.URL) == "" {
374			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
375		}
376		return client.NewSSEMCPClient(
377			m.URL,
378			client.WithHeaders(m.ResolvedHeaders()),
379			transport.WithSSELogger(mcpLogger{}),
380		)
381	default:
382		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
383	}
384}
385
// mcpLogger is the logger handed to the MCP client transports. It formats
// each message and forwards it to the default slog logger.
type mcpLogger struct{}

// Errorf formats the message and logs it at error level.
func (l mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg)
}

// Infof formats the message and logs it at info level.
func (l mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg)
}
391
392func mcpTimeout(m config.MCPConfig) time.Duration {
393	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
394}