mcp-tools.go

  1package agent
  2
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"slices"
	"strings"
	"sync"
	"time"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/pubsub"
	"github.com/charmbracelet/crush/internal/version"
	"github.com/mark3labs/mcp-go/client"
	"github.com/mark3labs/mcp-go/client/transport"
	"github.com/mark3labs/mcp-go/mcp"
)
 22
 23// MCPState represents the current state of an MCP client
 24type MCPState int
 25
 26const (
 27	MCPStateDisabled MCPState = iota
 28	MCPStateStarting
 29	MCPStateConnected
 30	MCPStateError
 31)
 32
 33func (s MCPState) String() string {
 34	switch s {
 35	case MCPStateDisabled:
 36		return "disabled"
 37	case MCPStateStarting:
 38		return "starting"
 39	case MCPStateConnected:
 40		return "connected"
 41	case MCPStateError:
 42		return "error"
 43	default:
 44		return "unknown"
 45	}
 46}
 47
 48// MCPEventType represents the type of MCP event
 49type MCPEventType string
 50
 51const (
 52	MCPEventStateChanged MCPEventType = "state_changed"
 53)
 54
 55// MCPEvent represents an event in the MCP system
 56type MCPEvent struct {
 57	Type      MCPEventType
 58	Name      string
 59	State     MCPState
 60	Error     error
 61	ToolCount int
 62}
 63
 64// MCPClientInfo holds information about an MCP client's state
 65type MCPClientInfo struct {
 66	Name        string
 67	State       MCPState
 68	Error       error
 69	Client      *client.Client
 70	ToolCount   int
 71	ConnectedAt time.Time
 72}
 73
 74var (
 75	mcpToolsOnce sync.Once
 76	mcpTools     []tools.BaseTool
 77	mcpClients   = csync.NewMap[string, *client.Client]()
 78	mcpStates    = csync.NewMap[string, MCPClientInfo]()
 79	mcpBroker    = pubsub.NewBroker[MCPEvent]()
 80)
 81
 82type McpTool struct {
 83	mcpName     string
 84	tool        mcp.Tool
 85	permissions permission.Service
 86	workingDir  string
 87}
 88
 89func (b *McpTool) Name() string {
 90	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 91}
 92
 93func (b *McpTool) Info() tools.ToolInfo {
 94	required := b.tool.InputSchema.Required
 95	if required == nil {
 96		required = make([]string, 0)
 97	}
 98	return tools.ToolInfo{
 99		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
100		Description: b.tool.Description,
101		Parameters:  b.tool.InputSchema.Properties,
102		Required:    required,
103	}
104}
105
// runTool calls toolName on the MCP server registered under name.
//
// input is decoded as a JSON object and passed as the tool arguments
// (presumably the model-produced arguments string — confirm with callers).
// These outcomes are returned as error responses with a nil Go error, so the
// caller can surface them rather than abort:
//   - undecodable input
//   - an unknown server
//   - a transport/call failure
//   - a result flagged IsError by the server
//
// Every content item of the result is included, one per line; non-text items
// are rendered with %v.
func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
	var args map[string]any
	if err := json.Unmarshal([]byte(input), &args); err != nil {
		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
	}
	c, ok := mcpClients.Get(name)
	if !ok {
		return tools.NewTextErrorResponse("mcp '" + name + "' not available"), nil
	}
	result, err := c.CallTool(ctx, mcp.CallToolRequest{
		Params: mcp.CallToolParams{
			Name:      toolName,
			Arguments: args,
		},
	})
	if err != nil {
		return tools.NewTextErrorResponse(err.Error()), nil
	}

	// A result may carry several content items; keep all of them instead of
	// only the last one.
	parts := make([]string, 0, len(result.Content))
	for _, content := range result.Content {
		switch v := content.(type) {
		case mcp.TextContent:
			parts = append(parts, v.Text)
		default:
			parts = append(parts, fmt.Sprintf("%v", v))
		}
	}
	output := strings.Join(parts, "\n")

	// Tool-level failures are reported in-band via IsError rather than as a
	// protocol error, so they must be flagged explicitly.
	if result.IsError {
		return tools.NewTextErrorResponse(output), nil
	}
	return tools.NewTextResponse(output), nil
}
136
// Run asks the permission service to approve executing this MCP tool with
// params.Input and, once approved, forwards the call to the MCP server via
// runTool.
//
// It returns an error when ctx carries no session or message ID, and
// permission.ErrorPermissionDenied when the request is not granted.
func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
	sessionID, messageID := tools.GetContextValues(ctx)
	if sessionID == "" || messageID == "" {
		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for running an MCP tool")
	}

	toolName := b.Name()
	granted := b.permissions.Request(
		permission.CreatePermissionRequest{
			SessionID:   sessionID,
			ToolCallID:  params.ID,
			Path:        b.workingDir,
			ToolName:    toolName,
			Action:      "execute",
			Description: fmt.Sprintf("execute %s with the following parameters: %s", toolName, params.Input),
			Params:      params.Input,
		},
	)
	if !granted {
		return tools.ToolResponse{}, permission.ErrorPermissionDenied
	}

	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
}
160
// getTools lists the tools exposed by client c and wraps each one in an
// McpTool bound to the server name, permission service and working dir.
//
// If listing fails, it:
//   - records MCPStateError for name,
//   - closes c and removes it from mcpClients,
//   - returns nil.
//
// On success the returned slice is non-nil, even when the server exposes no
// tools.
func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
	if err != nil {
		slog.Error("error listing tools", "error", err, "name", name)
		updateMCPState(name, MCPStateError, err, nil, 0)
		// Best-effort cleanup: the listing error is the one worth reporting.
		_ = c.Close()
		mcpClients.Del(name)
		return nil
	}

	baseTools := make([]tools.BaseTool, 0, len(result.Tools))
	for _, tool := range result.Tools {
		baseTools = append(baseTools, &McpTool{
			mcpName:     name,
			tool:        tool,
			permissions: permissions,
			workingDir:  workingDir,
		})
	}
	return baseTools
}
181
// SubscribeMCPEvents subscribes to the MCP event broker and returns the
// channel on which state-change events arrive. ctx is handed to the broker,
// presumably to end the subscription when it is cancelled.
func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
	events := mcpBroker.Subscribe(ctx)
	return events
}

// GetMCPStates returns a freshly built map holding the latest recorded state
// of every known MCP client, keyed by server name. Callers may modify the
// returned map freely.
func GetMCPStates() map[string]MCPClientInfo {
	snapshot := map[string]MCPClientInfo{}
	for serverName, info := range mcpStates.Seq2() {
		snapshot[serverName] = info
	}
	return snapshot
}

// GetMCPState reports the latest recorded state for the named MCP client and
// whether any state has been recorded for it.
func GetMCPState(name string) (MCPClientInfo, bool) {
	info, found := mcpStates.Get(name)
	return info, found
}
200
201// updateMCPState updates the state of an MCP client and publishes an event
202func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
203	info := MCPClientInfo{
204		Name:      name,
205		State:     state,
206		Error:     err,
207		Client:    client,
208		ToolCount: toolCount,
209	}
210	if state == MCPStateConnected {
211		info.ConnectedAt = time.Now()
212	}
213	mcpStates.Set(name, info)
214
215	// Publish state change event
216	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
217		Type:      MCPEventStateChanged,
218		Name:      name,
219		State:     state,
220		Error:     err,
221		ToolCount: toolCount,
222	})
223}
224
// CloseMCPClients closes every registered MCP client, then shuts down the MCP
// event broker. It should be called during application shutdown.
//
// Close failures are logged, not returned, so one misbehaving server cannot
// stop the remaining clients from being closed.
func CloseMCPClients() {
	for name, c := range mcpClients.Seq2() {
		if err := c.Close(); err != nil {
			slog.Warn("error closing mcp client", "name", name, "error", err)
		}
	}
	mcpBroker.Shutdown()
}
232
233var mcpInitRequest = mcp.InitializeRequest{
234	Params: mcp.InitializeParams{
235		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
236		ClientInfo: mcp.Implementation{
237			Name:    "Crush",
238			Version: version.Version,
239		},
240	},
241}
242
// doGetMCPTools starts every enabled MCP server in cfg.MCP concurrently and
// returns the combined list of tools they expose.
//
// Progress is reported through updateMCPState. Disabled servers are marked
// MCPStateDisabled. Every other server is marked MCPStateStarting, then ends in
// either MCPStateConnected or MCPStateError.
//
// A server that fails (including by panicking) contributes no tools and does
// not affect the others. Tools are collected concurrently, so their order in
// the result is not deterministic.
func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
	var wg sync.WaitGroup
	result := csync.NewSlice[tools.BaseTool]()

	for name, m := range cfg.MCP {
		if m.Disabled {
			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
			slog.Debug("skipping disabled mcp", "name", name)
			continue
		}

		updateMCPState(name, MCPStateStarting, nil, nil, 0)

		wg.Add(1)
		go func(name string, m config.MCPConfig) {
			// Deferred calls run last-in-first-out: the recover handler below
			// records the error state before wg.Done lets wg.Wait return.
			defer wg.Done()
			defer func() {
				r := recover()
				if r == nil {
					return
				}
				var err error
				switch v := r.(type) {
				case error:
					err = fmt.Errorf("panic: %w", v)
				case string:
					err = fmt.Errorf("panic: %s", v)
				default:
					err = fmt.Errorf("panic: %v", v)
				}
				updateMCPState(name, MCPStateError, err, nil, 0)
				slog.Error("panic in mcp client initialization", "error", err, "name", name)
			}()

			c, err := createMcpClient(m)
			if err != nil {
				updateMCPState(name, MCPStateError, err, nil, 0)
				slog.Error("error creating mcp client", "error", err, "name", name)
				return
			}
			if err := c.Start(ctx); err != nil {
				updateMCPState(name, MCPStateError, err, nil, 0)
				slog.Error("error starting mcp client", "error", err, "name", name)
				_ = c.Close()
				return
			}
			if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
				updateMCPState(name, MCPStateError, err, nil, 0)
				slog.Error("error initializing mcp client", "error", err, "name", name)
				_ = c.Close()
				return
			}

			slog.Info("Initialized mcp client", "name", name)
			mcpClients.Set(name, c)

			clientTools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
			// getTools reports its own failures. It records MCPStateError,
			// closes c and removes it from mcpClients, so don't overwrite that
			// error with a "connected" state. Each goroutine owns a distinct
			// name, so no other goroutine can re-register it meanwhile.
			if _, ok := mcpClients.Get(name); !ok {
				return
			}
			updateMCPState(name, MCPStateConnected, nil, c, len(clientTools))
			result.Append(clientTools...)
		}(name, m)
	}
	wg.Wait()
	return slices.Collect(result.Seq())
}
307
// createMcpClient builds, but does not start, an MCP client for m. The
// transport is chosen by m.Type:
//   - stdio: launches m.Command with m.Args and the resolved environment.
//   - sse / http: connects to m.URL with the resolved headers.
//
// Any other type yields an error.
func createMcpClient(m config.MCPConfig) (*client.Client, error) {
	switch m.Type {
	case config.MCPStdio:
		return client.NewStdioMCPClient(m.Command, m.ResolvedEnv(), m.Args...)

	case config.MCPSse:
		return client.NewSSEMCPClient(m.URL, client.WithHeaders(m.ResolvedHeaders()))

	case config.MCPHttp:
		// Streamable HTTP transport; its internal logging is routed to slog.
		return client.NewStreamableHttpClient(
			m.URL,
			transport.WithHTTPHeaders(m.ResolvedHeaders()),
			transport.WithLogger(mcpHTTPLogger{}),
		)
	}
	return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
}
331
332// for MCP's HTTP client.
333type mcpHTTPLogger struct{}
334
335func (l mcpHTTPLogger) Errorf(format string, v ...any) { slog.Error(fmt.Sprintf(format, v...)) }
336func (l mcpHTTPLogger) Infof(format string, v ...any)  { slog.Info(fmt.Sprintf(format, v...)) }