mcp-tools.go

  1package agent
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"strings"
 10	"sync"
 11	"time"
 12
 13	"github.com/charmbracelet/crush/internal/config"
 14	"github.com/charmbracelet/crush/internal/csync"
 15	"github.com/charmbracelet/crush/internal/llm/tools"
 16	"github.com/charmbracelet/crush/internal/permission"
 17	"github.com/charmbracelet/crush/internal/pubsub"
 18	"github.com/charmbracelet/crush/internal/version"
 19	"github.com/mark3labs/mcp-go/client"
 20	"github.com/mark3labs/mcp-go/client/transport"
 21	"github.com/mark3labs/mcp-go/mcp"
 22)
 23
 24// MCPState represents the current state of an MCP client
 25type MCPState int
 26
 27const (
 28	MCPStateDisabled MCPState = iota
 29	MCPStateStarting
 30	MCPStateConnected
 31	MCPStateError
 32)
 33
 34func (s MCPState) String() string {
 35	switch s {
 36	case MCPStateDisabled:
 37		return "disabled"
 38	case MCPStateStarting:
 39		return "starting"
 40	case MCPStateConnected:
 41		return "connected"
 42	case MCPStateError:
 43		return "error"
 44	default:
 45		return "unknown"
 46	}
 47}
 48
 49// MCPEventType represents the type of MCP event
 50type MCPEventType string
 51
 52const (
 53	MCPEventStateChanged MCPEventType = "state_changed"
 54)
 55
// MCPEvent represents an event in the MCP system. In this file events are
// published on the package-level broker by updateMCPState, which copies its
// arguments into the fields below.
type MCPEvent struct {
	Type      MCPEventType // kind of event; MCPEventStateChanged is the only type declared here
	Name      string       // name of the MCP the event is about
	State     MCPState     // state the MCP was set to
	Error     error        // error recorded with the state update, nil if none
	ToolCount int          // tool count recorded with the state update
}
 64
// MCPClientInfo holds information about an MCP client's state. It is the
// per-MCP record written by updateMCPState and returned by GetMCPState and
// GetMCPStates.
type MCPClientInfo struct {
	// Name is the name of the MCP this record describes.
	Name string
	// State is the most recently recorded state.
	State MCPState
	// Error is the error recorded together with State, nil if none.
	Error error
	// Client is the client handle recorded together with State. Within this
	// file it is only non-nil for the connected state.
	Client *client.Client
	// ToolCount is the number of tools recorded together with State.
	ToolCount int
	// ConnectedAt is set to the current time when State is
	// MCPStateConnected and is the zero time otherwise.
	ConnectedAt time.Time
}
 74
// Package-level MCP registry.
//
//   - mcpToolsOnce / mcpTools are not referenced in the part of the file shown
//     here; presumably they guard and cache a one-time tool discovery — verify
//     against their users.
//   - mcpClients maps an MCP name to its initialized client; runTool looks
//     clients up here.
//   - mcpStates maps an MCP name to the latest MCPClientInfo recorded by
//     updateMCPState.
//   - mcpBroker carries the MCPEvent notifications published by
//     updateMCPState to subscribers.
var (
	mcpToolsOnce sync.Once
	mcpTools     []tools.BaseTool
	mcpClients   = csync.NewMap[string, *client.Client]()
	mcpStates    = csync.NewMap[string, MCPClientInfo]()
	mcpBroker    = pubsub.NewBroker[MCPEvent]()
)
 82
// McpTool wraps a single tool listed by an MCP server so it can be described
// (Name, Info) and executed (Run) through the methods defined on it.
type McpTool struct {
	mcpName     string             // name of the MCP the tool belongs to; used to look up its client
	tool        mcp.Tool           // tool definition as listed by the MCP server
	permissions permission.Service // asked for approval before each run
	workingDir  string             // sent as Path in permission requests
}
 89
 90func (b *McpTool) Name() string {
 91	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 92}
 93
 94func (b *McpTool) Info() tools.ToolInfo {
 95	required := b.tool.InputSchema.Required
 96	if required == nil {
 97		required = make([]string, 0)
 98	}
 99	return tools.ToolInfo{
100		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
101		Description: b.tool.Description,
102		Parameters:  b.tool.InputSchema.Properties,
103		Required:    required,
104	}
105}
106
107func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
108	var args map[string]any
109	if err := json.Unmarshal([]byte(input), &args); err != nil {
110		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
111	}
112	c, ok := mcpClients.Get(name)
113	if !ok {
114		return tools.NewTextErrorResponse("mcp '" + name + "' not available"), nil
115	}
116	result, err := c.CallTool(ctx, mcp.CallToolRequest{
117		Params: mcp.CallToolParams{
118			Name:      toolName,
119			Arguments: args,
120		},
121	})
122	if err != nil {
123		return tools.NewTextErrorResponse(err.Error()), nil
124	}
125
126	var output strings.Builder
127	for _, v := range result.Content {
128		if v, ok := v.(mcp.TextContent); ok {
129			output.WriteString(v.Text)
130		} else {
131			_, _ = fmt.Fprintf(&output, "%v: ", v)
132		}
133	}
134
135	return tools.NewTextResponse(output.String()), nil
136}
137
138func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
139	sessionID, messageID := tools.GetContextValues(ctx)
140	if sessionID == "" || messageID == "" {
141		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
142	}
143	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
144	p := b.permissions.Request(
145		permission.CreatePermissionRequest{
146			SessionID:   sessionID,
147			ToolCallID:  params.ID,
148			Path:        b.workingDir,
149			ToolName:    b.Info().Name,
150			Action:      "execute",
151			Description: permissionDescription,
152			Params:      params.Input,
153		},
154	)
155	if !p {
156		return tools.ToolResponse{}, permission.ErrorPermissionDenied
157	}
158
159	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
160}
161
162func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
163	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
164	if err != nil {
165		slog.Error("error listing tools", "error", err)
166		updateMCPState(name, MCPStateError, err, nil, 0)
167		c.Close()
168		mcpClients.Del(name)
169		return nil
170	}
171	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
172	for _, tool := range result.Tools {
173		mcpTools = append(mcpTools, &McpTool{
174			mcpName:     name,
175			tool:        tool,
176			permissions: permissions,
177			workingDir:  workingDir,
178		})
179	}
180	return mcpTools
181}
182
183// SubscribeMCPEvents returns a channel for MCP events
184func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
185	return mcpBroker.Subscribe(ctx)
186}
187
188// GetMCPStates returns the current state of all MCP clients
189func GetMCPStates() map[string]MCPClientInfo {
190	states := make(map[string]MCPClientInfo)
191	for name, info := range mcpStates.Seq2() {
192		states[name] = info
193	}
194	return states
195}
196
197// GetMCPState returns the state of a specific MCP client
198func GetMCPState(name string) (MCPClientInfo, bool) {
199	return mcpStates.Get(name)
200}
201
202// updateMCPState updates the state of an MCP client and publishes an event
203func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
204	info := MCPClientInfo{
205		Name:      name,
206		State:     state,
207		Error:     err,
208		Client:    client,
209		ToolCount: toolCount,
210	}
211	if state == MCPStateConnected {
212		info.ConnectedAt = time.Now()
213	}
214	mcpStates.Set(name, info)
215
216	// Publish state change event
217	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
218		Type:      MCPEventStateChanged,
219		Name:      name,
220		State:     state,
221		Error:     err,
222		ToolCount: toolCount,
223	})
224}
225
226// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
227func CloseMCPClients() {
228	for c := range mcpClients.Seq() {
229		_ = c.Close()
230	}
231	mcpBroker.Shutdown()
232}
233
// mcpInitRequest is the initialize request doGetMCPTools sends to each MCP
// server. It carries the mcp package's LATEST_PROTOCOL_VERSION constant and
// identifies this client as "Crush" with the application's version.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
243
244func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
245	var wg sync.WaitGroup
246	result := csync.NewSlice[tools.BaseTool]()
247
248	// Initialize states for all configured MCPs
249	for name, m := range cfg.MCP {
250		if m.Disabled {
251			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
252			slog.Debug("skipping disabled mcp", "name", name)
253			continue
254		}
255
256		// Set initial starting state
257		updateMCPState(name, MCPStateStarting, nil, nil, 0)
258
259		wg.Add(1)
260		go func(name string, m config.MCPConfig) {
261			defer func() {
262				wg.Done()
263				if r := recover(); r != nil {
264					var err error
265					switch v := r.(type) {
266					case error:
267						err = v
268					case string:
269						err = fmt.Errorf("panic: %s", v)
270					default:
271						err = fmt.Errorf("panic: %v", v)
272					}
273					updateMCPState(name, MCPStateError, err, nil, 0)
274					slog.Error("panic in mcp client initialization", "error", err, "name", name)
275				}
276			}()
277
278			ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
279			defer cancel()
280			c, err := createMcpClient(m)
281			if err != nil {
282				updateMCPState(name, MCPStateError, err, nil, 0)
283				slog.Error("error creating mcp client", "error", err, "name", name)
284				return
285			}
286			if err := c.Start(ctx); err != nil {
287				updateMCPState(name, MCPStateError, err, nil, 0)
288				slog.Error("error starting mcp client", "error", err, "name", name)
289				_ = c.Close()
290				return
291			}
292			if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
293				updateMCPState(name, MCPStateError, err, nil, 0)
294				slog.Error("error initializing mcp client", "error", err, "name", name)
295				_ = c.Close()
296				return
297			}
298
299			slog.Info("Initialized mcp client", "name", name)
300			mcpClients.Set(name, c)
301
302			tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
303			updateMCPState(name, MCPStateConnected, nil, c, len(tools))
304			result.Append(tools...)
305		}(name, m)
306	}
307	wg.Wait()
308	return slices.Collect(result.Seq())
309}
310
311func createMcpClient(m config.MCPConfig) (*client.Client, error) {
312	switch m.Type {
313	case config.MCPStdio:
314		return client.NewStdioMCPClientWithOptions(
315			m.Command,
316			m.ResolvedEnv(),
317			m.Args,
318			transport.WithCommandLogger(mcpLogger{}),
319		)
320	case config.MCPHttp:
321		return client.NewStreamableHttpClient(
322			m.URL,
323			transport.WithHTTPHeaders(m.ResolvedHeaders()),
324			transport.WithHTTPLogger(mcpLogger{}),
325		)
326	case config.MCPSse:
327		return client.NewSSEMCPClient(
328			m.URL,
329			client.WithHeaders(m.ResolvedHeaders()),
330			transport.WithSSELogger(mcpLogger{}),
331		)
332	default:
333		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
334	}
335}
336
// mcpLogger forwards the printf-style log calls made by the MCP client
// transports to the default slog logger.
type mcpLogger struct{}

// Errorf formats the message printf-style and logs it at error level.
func (l mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg)
}

// Infof formats the message printf-style and logs it at info level.
func (l mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg)
}