mcp-tools.go

  1package agent
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"log/slog"
 10	"maps"
 11	"slices"
 12	"strings"
 13	"sync"
 14	"time"
 15
 16	"github.com/charmbracelet/crush/internal/config"
 17	"github.com/charmbracelet/crush/internal/csync"
 18	"github.com/charmbracelet/crush/internal/llm/tools"
 19	"github.com/charmbracelet/crush/internal/permission"
 20	"github.com/charmbracelet/crush/internal/pubsub"
 21	"github.com/charmbracelet/crush/internal/version"
 22	"github.com/mark3labs/mcp-go/client"
 23	"github.com/mark3labs/mcp-go/client/transport"
 24	"github.com/mark3labs/mcp-go/mcp"
 25)
 26
// MCPState represents the current state of an MCP client.
type MCPState int

const (
	// MCPStateDisabled is recorded for MCPs whose configuration marks them
	// as disabled; no client is created for them.
	MCPStateDisabled MCPState = iota
	// MCPStateStarting is recorded right before a connection attempt begins.
	MCPStateStarting
	// MCPStateConnected is recorded once a client has been created and
	// initialized.
	MCPStateConnected
	// MCPStateError is recorded when creating, starting, initializing,
	// pinging or listing tools on a client fails.
	MCPStateError
)
 36
 37func (s MCPState) String() string {
 38	switch s {
 39	case MCPStateDisabled:
 40		return "disabled"
 41	case MCPStateStarting:
 42		return "starting"
 43	case MCPStateConnected:
 44		return "connected"
 45	case MCPStateError:
 46		return "error"
 47	default:
 48		return "unknown"
 49	}
 50}
 51
// MCPEventType represents the type of MCP event.
type MCPEventType string

const (
	// MCPEventStateChanged is the event type published whenever an MCP
	// client's recorded state is updated.
	MCPEventStateChanged MCPEventType = "state_changed"
)
 58
// MCPEvent represents an event in the MCP system. It is the payload
// published on the MCP broker when a client's state changes.
type MCPEvent struct {
	Type      MCPEventType // kind of event
	Name      string       // name of the MCP the event refers to
	State     MCPState     // state recorded by the update
	Error     error        // error passed with the update; nil when none
	ToolCount int          // tool count passed with the update
}
 67
// MCPClientInfo holds information about an MCP client's state. One value
// per MCP name is kept in the package-level state map.
type MCPClientInfo struct {
	Name        string         // name of the MCP
	State       MCPState       // most recently recorded state
	Error       error          // error recorded with that state; nil when none
	Client      *client.Client // client passed with the update; nil for non-connected updates
	ToolCount   int            // tool count passed with the update
	ConnectedAt time.Time      // set only when State is MCPStateConnected; zero otherwise
}
 77
// Package-level MCP registry shared by every function in this file.
var (
	// mcpToolsOnce and mcpTools are not referenced in the code visible here.
	// NOTE(review): presumably they cache the result of a one-time tool
	// discovery — confirm against the caller.
	mcpToolsOnce sync.Once
	mcpTools     []tools.BaseTool
	// mcpClients maps an MCP name to its live client.
	mcpClients = csync.NewMap[string, *client.Client]()
	// mcpStates maps an MCP name to its most recently recorded state.
	mcpStates = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker publishes an MCPEvent on every state update.
	mcpBroker = pubsub.NewBroker[MCPEvent]()
)
 85
// McpTool exposes a single tool offered by an MCP server through the
// Name, Info and Run methods defined on it.
type McpTool struct {
	mcpName     string             // name of the MCP that provides the tool
	tool        mcp.Tool           // tool definition as listed by the server
	permissions permission.Service // asked for approval before each run
	workingDir  string             // sent as the path of permission requests
}
 92
 93func (b *McpTool) Name() string {
 94	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 95}
 96
 97func (b *McpTool) Info() tools.ToolInfo {
 98	required := b.tool.InputSchema.Required
 99	if required == nil {
100		required = make([]string, 0)
101	}
102	parameters := b.tool.InputSchema.Properties
103	if parameters == nil {
104		parameters = make(map[string]any)
105	}
106	return tools.ToolInfo{
107		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
108		Description: b.tool.Description,
109		Parameters:  parameters,
110		Required:    required,
111	}
112}
113
114func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
115	var args map[string]any
116	if err := json.Unmarshal([]byte(input), &args); err != nil {
117		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
118	}
119
120	c, err := getOrRenewClient(ctx, name)
121	if err != nil {
122		return tools.NewTextErrorResponse(err.Error()), nil
123	}
124	result, err := c.CallTool(ctx, mcp.CallToolRequest{
125		Params: mcp.CallToolParams{
126			Name:      toolName,
127			Arguments: args,
128		},
129	})
130	if err != nil {
131		return tools.NewTextErrorResponse(err.Error()), nil
132	}
133
134	output := make([]string, 0, len(result.Content))
135	for _, v := range result.Content {
136		if v, ok := v.(mcp.TextContent); ok {
137			output = append(output, v.Text)
138		} else {
139			output = append(output, fmt.Sprintf("%v", v))
140		}
141	}
142	return tools.NewTextResponse(strings.Join(output, "\n")), nil
143}
144
145func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
146	c, ok := mcpClients.Get(name)
147	if !ok {
148		return nil, fmt.Errorf("mcp '%s' not available", name)
149	}
150
151	m := config.Get().MCP[name]
152	state, _ := mcpStates.Get(name)
153
154	pingCtx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
155	defer cancel()
156	err := c.Ping(pingCtx)
157	if err == nil {
158		return c, nil
159	}
160	updateMCPState(name, MCPStateError, err, nil, state.ToolCount)
161
162	c, err = createAndInitializeClient(ctx, name, m)
163	if err != nil {
164		return nil, err
165	}
166
167	updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
168	mcpClients.Set(name, c)
169	return c, nil
170}
171
172func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
173	sessionID, messageID := tools.GetContextValues(ctx)
174	if sessionID == "" || messageID == "" {
175		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
176	}
177	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
178	p := b.permissions.Request(
179		permission.CreatePermissionRequest{
180			SessionID:   sessionID,
181			ToolCallID:  params.ID,
182			Path:        b.workingDir,
183			ToolName:    b.Info().Name,
184			Action:      "execute",
185			Description: permissionDescription,
186			Params:      params.Input,
187		},
188	)
189	if !p {
190		return tools.ToolResponse{}, permission.ErrorPermissionDenied
191	}
192
193	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
194}
195
196func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
197	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
198	if err != nil {
199		slog.Error("error listing tools", "error", err)
200		updateMCPState(name, MCPStateError, err, nil, 0)
201		c.Close()
202		mcpClients.Del(name)
203		return nil
204	}
205	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
206	for _, tool := range result.Tools {
207		mcpTools = append(mcpTools, &McpTool{
208			mcpName:     name,
209			tool:        tool,
210			permissions: permissions,
211			workingDir:  workingDir,
212		})
213	}
214	return mcpTools
215}
216
217// SubscribeMCPEvents returns a channel for MCP events
218func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
219	return mcpBroker.Subscribe(ctx)
220}
221
222// GetMCPStates returns the current state of all MCP clients
223func GetMCPStates() map[string]MCPClientInfo {
224	return maps.Collect(mcpStates.Seq2())
225}
226
227// GetMCPState returns the state of a specific MCP client
228func GetMCPState(name string) (MCPClientInfo, bool) {
229	return mcpStates.Get(name)
230}
231
232// updateMCPState updates the state of an MCP client and publishes an event
233func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
234	info := MCPClientInfo{
235		Name:      name,
236		State:     state,
237		Error:     err,
238		Client:    client,
239		ToolCount: toolCount,
240	}
241	if state == MCPStateConnected {
242		info.ConnectedAt = time.Now()
243	}
244	mcpStates.Set(name, info)
245
246	// Publish state change event
247	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
248		Type:      MCPEventStateChanged,
249		Name:      name,
250		State:     state,
251		Error:     err,
252		ToolCount: toolCount,
253	})
254}
255
256// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
257func CloseMCPClients() error {
258	var errs []error
259	for c := range mcpClients.Seq() {
260		if err := c.Close(); err != nil {
261			errs = append(errs, err)
262		}
263	}
264	mcpBroker.Shutdown()
265	return errors.Join(errs...)
266}
267
// mcpInitRequest is the initialize request sent to every MCP server. It
// declares the latest protocol version known to the mcp package and
// identifies this client as "Crush" with the build's version string.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
277
278func doGetMCPTools(permissions permission.Service, cfg *config.Config) []tools.BaseTool {
279	var wg sync.WaitGroup
280	result := csync.NewSlice[tools.BaseTool]()
281
282	// Initialize states for all configured MCPs
283	for name, m := range cfg.MCP {
284		if m.Disabled {
285			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
286			slog.Debug("skipping disabled mcp", "name", name)
287			continue
288		}
289
290		// Set initial starting state
291		updateMCPState(name, MCPStateStarting, nil, nil, 0)
292
293		wg.Add(1)
294		go func(name string, m config.MCPConfig) {
295			defer func() {
296				wg.Done()
297				if r := recover(); r != nil {
298					var err error
299					switch v := r.(type) {
300					case error:
301						err = v
302					case string:
303						err = fmt.Errorf("panic: %s", v)
304					default:
305						err = fmt.Errorf("panic: %v", v)
306					}
307					updateMCPState(name, MCPStateError, err, nil, 0)
308					slog.Error("panic in mcp client initialization", "error", err, "name", name)
309				}
310			}()
311
312			ctx, cancel := context.WithTimeout(context.Background(), mcpTimeout(m))
313			defer cancel()
314			c, err := createAndInitializeClient(ctx, name, m)
315			if err != nil {
316				return
317			}
318			mcpClients.Set(name, c)
319
320			tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
321			updateMCPState(name, MCPStateConnected, nil, c, len(tools))
322			result.Append(tools...)
323		}(name, m)
324	}
325	wg.Wait()
326	return slices.Collect(result.Seq())
327}
328
329func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig) (*client.Client, error) {
330	c, err := createMcpClient(m)
331	if err != nil {
332		updateMCPState(name, MCPStateError, err, nil, 0)
333		slog.Error("error creating mcp client", "error", err, "name", name)
334		return nil, err
335	}
336	// Only call Start() for non-stdio clients, as stdio clients auto-start
337	if m.Type != config.MCPStdio {
338		if err := c.Start(ctx); err != nil {
339			updateMCPState(name, MCPStateError, err, nil, 0)
340			slog.Error("error starting mcp client", "error", err, "name", name)
341			_ = c.Close()
342			return nil, err
343		}
344	}
345	if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
346		updateMCPState(name, MCPStateError, err, nil, 0)
347		slog.Error("error initializing mcp client", "error", err, "name", name)
348		_ = c.Close()
349		return nil, err
350	}
351
352	slog.Info("Initialized mcp client", "name", name)
353	return c, nil
354}
355
356func createMcpClient(m config.MCPConfig) (*client.Client, error) {
357	switch m.Type {
358	case config.MCPStdio:
359		if strings.TrimSpace(m.Command) == "" {
360			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
361		}
362		return client.NewStdioMCPClientWithOptions(
363			m.Command,
364			m.ResolvedEnv(),
365			m.Args,
366			transport.WithCommandLogger(mcpLogger{}),
367		)
368	case config.MCPHttp:
369		if strings.TrimSpace(m.URL) == "" {
370			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
371		}
372		return client.NewStreamableHttpClient(
373			m.URL,
374			transport.WithHTTPHeaders(m.ResolvedHeaders()),
375			transport.WithHTTPLogger(mcpLogger{}),
376		)
377	case config.MCPSse:
378		if strings.TrimSpace(m.URL) == "" {
379			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
380		}
381		return client.NewSSEMCPClient(
382			m.URL,
383			client.WithHeaders(m.ResolvedHeaders()),
384			transport.WithSSELogger(mcpLogger{}),
385		)
386	default:
387		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
388	}
389}
390
// mcpLogger is the logger handed to the MCP client transports. It forwards
// their printf-style messages to the default slog logger.
type mcpLogger struct{}

// Errorf formats the message and logs it at error level.
func (l mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg)
}

// Infof formats the message and logs it at info level.
func (l mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg)
}
396
397func mcpTimeout(m config.MCPConfig) time.Duration {
398	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
399}