1package agent
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"log/slog"
 10	"maps"
 11	"slices"
 12	"strings"
 13	"sync"
 14	"time"
 15
 16	"github.com/charmbracelet/crush/internal/config"
 17	"github.com/charmbracelet/crush/internal/csync"
 18	"github.com/charmbracelet/crush/internal/home"
 19	"github.com/charmbracelet/crush/internal/llm/tools"
 20	"github.com/charmbracelet/crush/internal/permission"
 21	"github.com/charmbracelet/crush/internal/pubsub"
 22	"github.com/charmbracelet/crush/internal/version"
 23	"github.com/mark3labs/mcp-go/client"
 24	"github.com/mark3labs/mcp-go/client/transport"
 25	"github.com/mark3labs/mcp-go/mcp"
 26)
 27
 28// MCPState represents the current state of an MCP client
 29type MCPState int
 30
 31const (
 32	MCPStateDisabled MCPState = iota
 33	MCPStateStarting
 34	MCPStateConnected
 35	MCPStateError
 36)
 37
 38func (s MCPState) MarshalText() ([]byte, error) {
 39	return []byte(s.String()), nil
 40}
 41
 42func (s *MCPState) UnmarshalText(data []byte) error {
 43	switch string(data) {
 44	case "disabled":
 45		*s = MCPStateDisabled
 46	case "starting":
 47		*s = MCPStateStarting
 48	case "connected":
 49		*s = MCPStateConnected
 50	case "error":
 51		*s = MCPStateError
 52	default:
 53		return fmt.Errorf("unknown mcp state: %s", data)
 54	}
 55	return nil
 56}
 57
 58func (s MCPState) String() string {
 59	switch s {
 60	case MCPStateDisabled:
 61		return "disabled"
 62	case MCPStateStarting:
 63		return "starting"
 64	case MCPStateConnected:
 65		return "connected"
 66	case MCPStateError:
 67		return "error"
 68	default:
 69		return "unknown"
 70	}
 71}
 72
 73// MCPEventType represents the type of MCP event
 74type MCPEventType string
 75
 76const (
 77	MCPEventStateChanged MCPEventType = "state_changed"
 78)
 79
 80func (t MCPEventType) MarshalText() ([]byte, error) {
 81	return []byte(t), nil
 82}
 83
 84func (t *MCPEventType) UnmarshalText(data []byte) error {
 85	*t = MCPEventType(data)
 86	return nil
 87}
 88
// MCPEvent represents an event in the MCP system. updateMCPState publishes
// one on the MCP broker every time it records a client's state.
type MCPEvent struct {
	// Type is the kind of event.
	Type      MCPEventType `json:"type"`
	// Name is the name of the MCP the event refers to.
	Name      string       `json:"name"`
	// State is the state that was just recorded for the MCP.
	State     MCPState     `json:"state"`
	// Error is the error recorded together with the state, if any.
	// NOTE(review): plain error values usually expose no exported fields, so
	// encoding/json may render this as {} — confirm how consumers read it.
	Error     error        `json:"error,omitempty"`
	// ToolCount is the tool count recorded together with the state.
	ToolCount int          `json:"tool_count,omitempty"`
}
 97
// MCPClientInfo holds information about an MCP client's state. updateMCPState
// stores a freshly built value for the MCP on every state change.
type MCPClientInfo struct {
	// Name is the name of the MCP.
	Name        string
	// State is the most recently recorded state.
	State       MCPState
	// Error is the error recorded with the state; nil when none was given.
	Error       error
	// Client is the client recorded with the state; callers in this file pass
	// nil for every state other than MCPStateConnected.
	Client      *client.Client
	// ToolCount is the number of tools recorded with the state.
	ToolCount   int
	// ConnectedAt is set to the current time when the recorded state is
	// MCPStateConnected and is the zero time otherwise.
	ConnectedAt time.Time
}
107
// Package-level MCP registries shared by the functions in this file.
var (
	// mcpToolsOnce and mcpTools are not referenced in the code visible here.
	// NOTE(review): presumably they cache the tool list behind a one-time
	// load — confirm against their users.
	mcpToolsOnce sync.Once
	mcpTools     []tools.BaseTool
	// mcpClients maps an MCP name to its current client. Entries are stored
	// after a client initializes, replaced when a client is renewed, removed
	// when listing tools fails, and all closed by CloseMCPClients.
	mcpClients   = csync.NewMap[string, *client.Client]()
	// mcpStates maps an MCP name to the last state stored by updateMCPState.
	mcpStates    = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker carries the MCPEvent published on every state update.
	mcpBroker    = pubsub.NewBroker[MCPEvent]()
)
115
// McpTool exposes a single tool offered by an MCP as a tool that can be run
// through the Name, Info and Run methods.
type McpTool struct {
	// mcpName is the name of the MCP that provides the tool.
	mcpName     string
	// tool is the tool description as listed by the MCP.
	tool        mcp.Tool
	// permissions is asked for approval before the tool is executed.
	permissions permission.Service
	// workingDir is reported as the path in permission requests.
	workingDir  string
}
122
123func (b *McpTool) Name() string {
124	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
125}
126
127func (b *McpTool) Info() tools.ToolInfo {
128	required := b.tool.InputSchema.Required
129	if required == nil {
130		required = make([]string, 0)
131	}
132	parameters := b.tool.InputSchema.Properties
133	if parameters == nil {
134		parameters = make(map[string]any)
135	}
136	return tools.ToolInfo{
137		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
138		Description: b.tool.Description,
139		Parameters:  parameters,
140		Required:    required,
141	}
142}
143
144func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
145	var args map[string]any
146	if err := json.Unmarshal([]byte(input), &args); err != nil {
147		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
148	}
149
150	c, err := getOrRenewClient(ctx, name)
151	if err != nil {
152		return tools.NewTextErrorResponse(err.Error()), nil
153	}
154	result, err := c.CallTool(ctx, mcp.CallToolRequest{
155		Params: mcp.CallToolParams{
156			Name:      toolName,
157			Arguments: args,
158		},
159	})
160	if err != nil {
161		return tools.NewTextErrorResponse(err.Error()), nil
162	}
163
164	output := make([]string, 0, len(result.Content))
165	for _, v := range result.Content {
166		if v, ok := v.(mcp.TextContent); ok {
167			output = append(output, v.Text)
168		} else {
169			output = append(output, fmt.Sprintf("%v", v))
170		}
171	}
172	return tools.NewTextResponse(strings.Join(output, "\n")), nil
173}
174
175func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
176	c, ok := mcpClients.Get(name)
177	if !ok {
178		return nil, fmt.Errorf("mcp '%s' not available", name)
179	}
180
181	m := config.Get().MCP[name]
182	state, _ := mcpStates.Get(name)
183
184	timeout := mcpTimeout(m)
185	pingCtx, cancel := context.WithTimeout(ctx, timeout)
186	defer cancel()
187	err := c.Ping(pingCtx)
188	if err == nil {
189		return c, nil
190	}
191	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
192
193	c, err = createAndInitializeClient(ctx, name, m)
194	if err != nil {
195		return nil, err
196	}
197
198	updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
199	mcpClients.Set(name, c)
200	return c, nil
201}
202
203func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
204	sessionID, messageID := tools.GetContextValues(ctx)
205	if sessionID == "" || messageID == "" {
206		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
207	}
208	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
209	p := b.permissions.Request(
210		permission.CreatePermissionRequest{
211			SessionID:   sessionID,
212			ToolCallID:  params.ID,
213			Path:        b.workingDir,
214			ToolName:    b.Info().Name,
215			Action:      "execute",
216			Description: permissionDescription,
217			Params:      params.Input,
218		},
219	)
220	if !p {
221		return tools.ToolResponse{}, permission.ErrorPermissionDenied
222	}
223
224	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
225}
226
227func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
228	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
229	if err != nil {
230		slog.Error("error listing tools", "error", err)
231		updateMCPState(name, MCPStateError, err, nil, 0)
232		c.Close()
233		mcpClients.Del(name)
234		return nil
235	}
236	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
237	for _, tool := range result.Tools {
238		mcpTools = append(mcpTools, &McpTool{
239			mcpName:     name,
240			tool:        tool,
241			permissions: permissions,
242			workingDir:  workingDir,
243		})
244	}
245	return mcpTools
246}
247
248// SubscribeMCPEvents returns a channel for MCP events
249func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
250	return mcpBroker.Subscribe(ctx)
251}
252
253// GetMCPStates returns the current state of all MCP clients
254func GetMCPStates() map[string]MCPClientInfo {
255	return maps.Collect(mcpStates.Seq2())
256}
257
258// GetMCPState returns the state of a specific MCP client
259func GetMCPState(name string) (MCPClientInfo, bool) {
260	return mcpStates.Get(name)
261}
262
263// updateMCPState updates the state of an MCP client and publishes an event
264func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
265	info := MCPClientInfo{
266		Name:      name,
267		State:     state,
268		Error:     err,
269		Client:    client,
270		ToolCount: toolCount,
271	}
272	if state == MCPStateConnected {
273		info.ConnectedAt = time.Now()
274	}
275	mcpStates.Set(name, info)
276
277	// Publish state change event
278	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
279		Type:      MCPEventStateChanged,
280		Name:      name,
281		State:     state,
282		Error:     err,
283		ToolCount: toolCount,
284	})
285}
286
287// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
288func CloseMCPClients() error {
289	var errs []error
290	for c := range mcpClients.Seq() {
291		if err := c.Close(); err != nil {
292			errs = append(errs, err)
293		}
294	}
295	mcpBroker.Shutdown()
296	return errors.Join(errs...)
297}
298
// mcpInitRequest is the initialize request that createAndInitializeClient
// sends to every MCP client. It asks for the latest protocol version known to
// the mcp package and identifies this application as "Crush" at its build
// version.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
308
309func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
310	var wg sync.WaitGroup
311	result := csync.NewSlice[tools.BaseTool]()
312
313	// Initialize states for all configured MCPs
314	for name, m := range cfg.MCP {
315		if m.Disabled {
316			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
317			slog.Debug("skipping disabled mcp", "name", name)
318			continue
319		}
320
321		// Set initial starting state
322		updateMCPState(name, MCPStateStarting, nil, nil, 0)
323
324		wg.Add(1)
325		go func(name string, m config.MCPConfig) {
326			defer func() {
327				wg.Done()
328				if r := recover(); r != nil {
329					var err error
330					switch v := r.(type) {
331					case error:
332						err = v
333					case string:
334						err = fmt.Errorf("panic: %s", v)
335					default:
336						err = fmt.Errorf("panic: %v", v)
337					}
338					updateMCPState(name, MCPStateError, err, nil, 0)
339					slog.Error("panic in mcp client initialization", "error", err, "name", name)
340				}
341			}()
342
343			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
344			defer cancel()
345			c, err := createAndInitializeClient(ctx, name, m)
346			if err != nil {
347				return
348			}
349			mcpClients.Set(name, c)
350
351			tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
352			updateMCPState(name, MCPStateConnected, nil, c, len(tools))
353			result.Append(tools...)
354		}(name, m)
355	}
356	wg.Wait()
357	return slices.Collect(result.Seq())
358}
359
360func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig) (*client.Client, error) {
361	c, err := createMcpClient(name, m)
362	if err != nil {
363		updateMCPState(name, MCPStateError, err, nil, 0)
364		slog.Error("error creating mcp client", "error", err, "name", name)
365		return nil, err
366	}
367
368	timeout := mcpTimeout(m)
369	initCtx, cancel := context.WithTimeout(ctx, timeout)
370	defer cancel()
371
372	// Only call Start() for non-stdio clients, as stdio clients auto-start
373	if m.Type != config.MCPStdio {
374		if err := c.Start(initCtx); err != nil {
375			updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
376			slog.Error("error starting mcp client", "error", err, "name", name)
377			_ = c.Close()
378			return nil, err
379		}
380	}
381	if _, err := c.Initialize(initCtx, mcpInitRequest); err != nil {
382		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
383		slog.Error("error initializing mcp client", "error", err, "name", name)
384		_ = c.Close()
385		return nil, err
386	}
387
388	slog.Info("Initialized mcp client", "name", name)
389	return c, nil
390}
391
// maybeTimeoutErr replaces an error caused by an exceeded context deadline
// with a new "timed out after <timeout>" error. Any other error, including
// nil, is returned unchanged.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if !errors.Is(err, context.DeadlineExceeded) {
		return err
	}
	return fmt.Errorf("timed out after %s", timeout)
}
398
399func createMcpClient(name string, m config.MCPConfig) (*client.Client, error) {
400	switch m.Type {
401	case config.MCPStdio:
402		if strings.TrimSpace(m.Command) == "" {
403			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
404		}
405		return client.NewStdioMCPClientWithOptions(
406			home.Long(m.Command),
407			m.ResolvedEnv(),
408			m.Args,
409			transport.WithCommandLogger(mcpLogger{name: name}),
410		)
411	case config.MCPHttp:
412		if strings.TrimSpace(m.URL) == "" {
413			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
414		}
415		return client.NewStreamableHttpClient(
416			m.URL,
417			transport.WithHTTPHeaders(m.ResolvedHeaders()),
418			transport.WithHTTPLogger(mcpLogger{name: name}),
419		)
420	case config.MCPSse:
421		if strings.TrimSpace(m.URL) == "" {
422			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
423		}
424		return client.NewSSEMCPClient(
425			m.URL,
426			client.WithHeaders(m.ResolvedHeaders()),
427			transport.WithSSELogger(mcpLogger{name: name}),
428		)
429	default:
430		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
431	}
432}
433
// mcpLogger is the printf-style logger handed to MCP clients. It forwards
// every message to the default slog logger and tags it with the MCP's name.
type mcpLogger struct{ name string }

// Errorf formats the message and logs it at error level.
func (l mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg, slog.String("name", l.name))
}

// Infof formats the message and logs it at info level.
func (l mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg, slog.String("name", l.name))
}
444
445func mcpTimeout(m config.MCPConfig) time.Duration {
446	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
447}