Detailed changes
@@ -18,6 +18,7 @@ import (
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/agent/prompt"
"github.com/charmbracelet/crush/internal/agent/tools"
+ "github.com/charmbracelet/crush/internal/agent/tools/mcp"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/history"
@@ -344,25 +345,23 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan
}
}
- mcpTools := tools.GetMCPTools(context.Background(), c.permissions, c.cfg)
-
- for _, mcpTool := range mcpTools {
+ for tool := range mcp.GetMCPTools() {
if agent.AllowedMCP == nil {
// No MCP restrictions
- filteredTools = append(filteredTools, mcpTool)
+ filteredTools = append(filteredTools, tool)
} else if len(agent.AllowedMCP) == 0 {
// no mcps allowed
break
}
for mcp, tools := range agent.AllowedMCP {
- if mcp == mcpTool.MCP() {
+ if mcp == tool.MCP() {
if len(tools) == 0 {
- filteredTools = append(filteredTools, mcpTool)
+ filteredTools = append(filteredTools, tool)
}
for _, t := range tools {
- if t == mcpTool.MCPToolName() {
- filteredTools = append(filteredTools, mcpTool)
+ if t == tool.MCPToolName() {
+ filteredTools = append(filteredTools, tool)
}
}
break
@@ -1,525 +0,0 @@
-package tools
-
-import (
- "cmp"
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "log/slog"
- "maps"
- "net/http"
- "os"
- "os/exec"
- "slices"
- "strings"
- "sync"
- "time"
-
- "charm.land/fantasy"
- "github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/csync"
- "github.com/charmbracelet/crush/internal/home"
- "github.com/charmbracelet/crush/internal/permission"
- "github.com/charmbracelet/crush/internal/pubsub"
- "github.com/charmbracelet/crush/internal/version"
- "github.com/modelcontextprotocol/go-sdk/mcp"
-)
-
-// MCPState represents the current state of an MCP client
-type MCPState int
-
-const (
- MCPStateDisabled MCPState = iota
- MCPStateStarting
- MCPStateConnected
- MCPStateError
-)
-
-func (s MCPState) String() string {
- switch s {
- case MCPStateDisabled:
- return "disabled"
- case MCPStateStarting:
- return "starting"
- case MCPStateConnected:
- return "connected"
- case MCPStateError:
- return "error"
- default:
- return "unknown"
- }
-}
-
-// MCPEventType represents the type of MCP event
-type MCPEventType string
-
-const (
- MCPEventStateChanged MCPEventType = "state_changed"
- MCPEventToolsListChanged MCPEventType = "tools_list_changed"
-)
-
-// MCPEvent represents an event in the MCP system
-type MCPEvent struct {
- Type MCPEventType
- Name string
- State MCPState
- Error error
- ToolCount int
-}
-
-// MCPClientInfo holds information about an MCP client's state
-type MCPClientInfo struct {
- Name string
- State MCPState
- Error error
- Client *mcp.ClientSession
- ToolCount int
- ConnectedAt time.Time
-}
-
-var (
- mcpToolsOnce sync.Once
- mcpTools = csync.NewMap[string, *McpTool]()
- mcpClient2Tools = csync.NewMap[string, []*McpTool]()
- mcpClients = csync.NewMap[string, *mcp.ClientSession]()
- mcpStates = csync.NewMap[string, MCPClientInfo]()
- mcpBroker = pubsub.NewBroker[MCPEvent]()
-)
-
-type McpTool struct {
- mcpName string
- tool *mcp.Tool
- permissions permission.Service
- workingDir string
- providerOptions fantasy.ProviderOptions
-}
-
-func (m *McpTool) SetProviderOptions(opts fantasy.ProviderOptions) {
- m.providerOptions = opts
-}
-
-func (m *McpTool) ProviderOptions() fantasy.ProviderOptions {
- return m.providerOptions
-}
-
-func (m *McpTool) Name() string {
- return fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name)
-}
-
-func (m *McpTool) MCP() string {
- return m.mcpName
-}
-
-func (m *McpTool) MCPToolName() string {
- return m.tool.Name
-}
-
-func (b *McpTool) Info() fantasy.ToolInfo {
- parameters := make(map[string]any)
- required := make([]string, 0)
-
- if input, ok := b.tool.InputSchema.(map[string]any); ok {
- if props, ok := input["properties"].(map[string]any); ok {
- parameters = props
- }
- if req, ok := input["required"].([]any); ok {
- // Convert []any -> []string when elements are strings
- for _, v := range req {
- if s, ok := v.(string); ok {
- required = append(required, s)
- }
- }
- } else if reqStr, ok := input["required"].([]string); ok {
- // Handle case where it's already []string
- required = reqStr
- }
- }
-
- return fantasy.ToolInfo{
- Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
- Description: b.tool.Description,
- Parameters: parameters,
- Required: required,
- }
-}
-
-func runTool(ctx context.Context, name, toolName string, input string) (fantasy.ToolResponse, error) {
- var args map[string]any
- if err := json.Unmarshal([]byte(input), &args); err != nil {
- return fantasy.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
- }
-
- c, err := getOrRenewClient(ctx, name)
- if err != nil {
- return fantasy.NewTextErrorResponse(err.Error()), nil
- }
- result, err := c.CallTool(ctx, &mcp.CallToolParams{
- Name: toolName,
- Arguments: args,
- })
- if err != nil {
- return fantasy.NewTextErrorResponse(err.Error()), nil
- }
-
- output := make([]string, 0, len(result.Content))
- for _, v := range result.Content {
- if vv, ok := v.(*mcp.TextContent); ok {
- output = append(output, vv.Text)
- } else {
- output = append(output, fmt.Sprintf("%v", v))
- }
- }
- return fantasy.NewTextResponse(strings.Join(output, "\n")), nil
-}
-
-func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
- sess, ok := mcpClients.Get(name)
- if !ok {
- return nil, fmt.Errorf("mcp '%s' not available", name)
- }
-
- cfg := config.Get()
- m := cfg.MCP[name]
- state, _ := mcpStates.Get(name)
-
- timeout := mcpTimeout(m)
- pingCtx, cancel := context.WithTimeout(ctx, timeout)
- defer cancel()
- err := sess.Ping(pingCtx, nil)
- if err == nil {
- return sess, nil
- }
- updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
-
- sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
- if err != nil {
- return nil, err
- }
-
- updateMCPState(name, MCPStateConnected, nil, sess, state.ToolCount)
- mcpClients.Set(name, sess)
- return sess, nil
-}
-
-func (m *McpTool) Run(ctx context.Context, params fantasy.ToolCall) (fantasy.ToolResponse, error) {
- sessionID := GetSessionFromContext(ctx)
- if sessionID == "" {
- return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
- }
- permissionDescription := fmt.Sprintf("execute %s with the following parameters:", m.Info().Name)
- p := m.permissions.Request(
- permission.CreatePermissionRequest{
- SessionID: sessionID,
- ToolCallID: params.ID,
- Path: m.workingDir,
- ToolName: m.Info().Name,
- Action: "execute",
- Description: permissionDescription,
- Params: params.Input,
- },
- )
- if !p {
- return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
- }
-
- return runTool(ctx, m.mcpName, m.tool.Name, params.Input)
-}
-
-func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]*McpTool, error) {
- result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
- if err != nil {
- return nil, err
- }
- mcpTools := make([]*McpTool, 0, len(result.Tools))
- for _, tool := range result.Tools {
- mcpTools = append(mcpTools, &McpTool{
- mcpName: name,
- tool: tool,
- permissions: permissions,
- workingDir: workingDir,
- })
- }
- return mcpTools, nil
-}
-
-// SubscribeMCPEvents returns a channel for MCP events
-func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
- return mcpBroker.Subscribe(ctx)
-}
-
-// GetMCPStates returns the current state of all MCP clients
-func GetMCPStates() map[string]MCPClientInfo {
- return maps.Collect(mcpStates.Seq2())
-}
-
-// GetMCPState returns the state of a specific MCP client
-func GetMCPState(name string) (MCPClientInfo, bool) {
- return mcpStates.Get(name)
-}
-
-// updateMCPState updates the state of an MCP client and publishes an event
-func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, toolCount int) {
- info := MCPClientInfo{
- Name: name,
- State: state,
- Error: err,
- Client: client,
- ToolCount: toolCount,
- }
- switch state {
- case MCPStateConnected:
- info.ConnectedAt = time.Now()
- case MCPStateError:
- updateMcpTools(name, nil)
- mcpClients.Del(name)
- }
- mcpStates.Set(name, info)
-
- // Publish state change event
- mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
- Type: MCPEventStateChanged,
- Name: name,
- State: state,
- Error: err,
- ToolCount: toolCount,
- })
-}
-
-// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
-func CloseMCPClients() error {
- var errs []error
- for name, c := range mcpClients.Seq2() {
- if err := c.Close(); err != nil &&
- !errors.Is(err, io.EOF) &&
- !errors.Is(err, context.Canceled) &&
- err.Error() != "signal: killed" {
- errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
- }
- }
- mcpBroker.Shutdown()
- return errors.Join(errs...)
-}
-
-func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []*McpTool {
- mcpToolsOnce.Do(func() {
- var wg sync.WaitGroup
- // Initialize states for all configured MCPs
- for name, m := range cfg.MCP {
- if m.Disabled {
- updateMCPState(name, MCPStateDisabled, nil, nil, 0)
- slog.Debug("skipping disabled mcp", "name", name)
- continue
- }
-
- // Set initial starting state
- updateMCPState(name, MCPStateStarting, nil, nil, 0)
-
- wg.Add(1)
- go func(name string, m config.MCPConfig) {
- defer func() {
- wg.Done()
- if r := recover(); r != nil {
- var err error
- switch v := r.(type) {
- case error:
- err = v
- case string:
- err = fmt.Errorf("panic: %s", v)
- default:
- err = fmt.Errorf("panic: %v", v)
- }
- updateMCPState(name, MCPStateError, err, nil, 0)
- slog.Error("panic in mcp client initialization", "error", err, "name", name)
- }
- }()
-
- ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
- defer cancel()
-
- c, err := createMCPSession(ctx, name, m, cfg.Resolver())
- if err != nil {
- return
- }
-
- mcpClients.Set(name, c)
-
- tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
- if err != nil {
- slog.Error("error listing tools", "error", err)
- updateMCPState(name, MCPStateError, err, nil, 0)
- c.Close()
- return
- }
-
- updateMcpTools(name, tools)
- mcpClients.Set(name, c)
- updateMCPState(name, MCPStateConnected, nil, c, len(tools))
- }(name, m)
- }
- wg.Wait()
- })
- return slices.Collect(mcpTools.Seq())
-}
-
-// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
-func updateMcpTools(mcpName string, tools []*McpTool) {
- if len(tools) == 0 {
- mcpClient2Tools.Del(mcpName)
- } else {
- mcpClient2Tools.Set(mcpName, tools)
- }
- for _, tools := range mcpClient2Tools.Seq2() {
- for _, t := range tools {
- mcpTools.Set(t.Name(), t)
- }
- }
-}
-
-func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
- timeout := mcpTimeout(m)
- mcpCtx, cancel := context.WithCancel(ctx)
- cancelTimer := time.AfterFunc(timeout, cancel)
-
- transport, err := createMCPTransport(mcpCtx, m, resolver)
- if err != nil {
- updateMCPState(name, MCPStateError, err, nil, 0)
- slog.Error("error creating mcp client", "error", err, "name", name)
- cancel()
- cancelTimer.Stop()
- return nil, err
- }
-
- client := mcp.NewClient(
- &mcp.Implementation{
- Name: "crush",
- Version: version.Version,
- Title: "Crush",
- },
- &mcp.ClientOptions{
- ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
- mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
- Type: MCPEventToolsListChanged,
- Name: name,
- })
- },
- KeepAlive: time.Minute * 10,
- },
- )
-
- session, err := client.Connect(mcpCtx, transport, nil)
- if err != nil {
- err = maybeStdioErr(err, transport)
- updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
- slog.Error("error starting mcp client", "error", err, "name", name)
- cancel()
- cancelTimer.Stop()
- return nil, err
- }
-
- cancelTimer.Stop()
- slog.Info("Initialized mcp client", "name", name)
- return session, nil
-}
-
-// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
-// to parse, and the cli will then close it, causing the EOF error.
-// so, if we got an EOF err, and the transport is STDIO, we try to exec it
-// again with a timeout and collect the output so we can add details to the
-// error.
-// this happens particularly when starting things with npx, e.g. if node can't
-// be found or some other error like that.
-func maybeStdioErr(err error, transport mcp.Transport) error {
- if !errors.Is(err, io.EOF) {
- return err
- }
- ct, ok := transport.(*mcp.CommandTransport)
- if !ok {
- return err
- }
- if err2 := stdioMCPCheck(ct.Command); err2 != nil {
- err = errors.Join(err, err2)
- }
- return err
-}
-
-func maybeTimeoutErr(err error, timeout time.Duration) error {
- if errors.Is(err, context.Canceled) {
- return fmt.Errorf("timed out after %s", timeout)
- }
- return err
-}
-
-func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
- switch m.Type {
- case config.MCPStdio:
- command, err := resolver.ResolveValue(m.Command)
- if err != nil {
- return nil, fmt.Errorf("invalid mcp command: %w", err)
- }
- if strings.TrimSpace(command) == "" {
- return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
- }
- cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
- cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
- return &mcp.CommandTransport{
- Command: cmd,
- }, nil
- case config.MCPHttp:
- if strings.TrimSpace(m.URL) == "" {
- return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
- }
- client := &http.Client{
- Transport: &headerRoundTripper{
- headers: m.ResolvedHeaders(),
- },
- }
- return &mcp.StreamableClientTransport{
- Endpoint: m.URL,
- HTTPClient: client,
- }, nil
- case config.MCPSSE:
- if strings.TrimSpace(m.URL) == "" {
- return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
- }
- client := &http.Client{
- Transport: &headerRoundTripper{
- headers: m.ResolvedHeaders(),
- },
- }
- return &mcp.SSEClientTransport{
- Endpoint: m.URL,
- HTTPClient: client,
- }, nil
- default:
- return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
- }
-}
-
-type headerRoundTripper struct {
- headers map[string]string
-}
-
-func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
- for k, v := range rt.headers {
- req.Header.Set(k, v)
- }
- return http.DefaultTransport.RoundTrip(req)
-}
-
-func mcpTimeout(m config.MCPConfig) time.Duration {
- return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
-}
-
-func stdioMCPCheck(old *exec.Cmd) error {
- ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
- defer cancel()
- cmd := exec.CommandContext(ctx, old.Path, old.Args...)
- cmd.Env = old.Env
- out, err := cmd.CombinedOutput()
- if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
- return nil
- }
- return fmt.Errorf("%w: %s", err, string(out))
-}
@@ -0,0 +1,400 @@
+package mcp
+
+import (
+ "cmp"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "log/slog"
+ "maps"
+ "net/http"
+ "os"
+ "os/exec"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/csync"
+ "github.com/charmbracelet/crush/internal/home"
+ "github.com/charmbracelet/crush/internal/permission"
+ "github.com/charmbracelet/crush/internal/pubsub"
+ "github.com/charmbracelet/crush/internal/version"
+ "github.com/modelcontextprotocol/go-sdk/mcp"
+)
+
+var (
+ sessions = csync.NewMap[string, *mcp.ClientSession]()
+ states = csync.NewMap[string, ClientInfo]()
+ broker = pubsub.NewBroker[Event]()
+)
+
+// State represents the current state of an MCP client
+type State int
+
+const (
+ StateDisabled State = iota
+ StateStarting
+ StateConnected
+ StateError
+)
+
+func (s State) String() string {
+ switch s {
+ case StateDisabled:
+ return "disabled"
+ case StateStarting:
+ return "starting"
+ case StateConnected:
+ return "connected"
+ case StateError:
+ return "error"
+ default:
+ return "unknown"
+ }
+}
+
+// EventType represents the type of MCP event
+type EventType string
+
+const (
+ EventStateChanged EventType = "state_changed"
+ EventToolsListChanged EventType = "tools_list_changed"
+ EventPromptsListChanged EventType = "prompts_list_changed"
+)
+
+// Event represents an event in the MCP system
+type Event struct {
+ Type EventType
+ Name string
+ State State
+ Error error
+ Counts Counts
+}
+
+// Counts number of available tools, prompts, etc.
+type Counts struct {
+ Tools int
+ Prompts int
+}
+
+// ClientInfo holds information about an MCP client's state
+type ClientInfo struct {
+ Name string
+ State State
+ Error error
+ Client *mcp.ClientSession
+ Counts Counts
+ ConnectedAt time.Time
+}
+
+// SubscribeEvents returns a channel for MCP events
+func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] {
+ return broker.Subscribe(ctx)
+}
+
+// GetStates returns the current state of all MCP clients
+func GetStates() map[string]ClientInfo {
+ return maps.Collect(states.Seq2())
+}
+
+// GetState returns the state of a specific MCP client
+func GetState(name string) (ClientInfo, bool) {
+ return states.Get(name)
+}
+
+// Close closes all MCP clients. This should be called during application shutdown.
+func Close() error {
+ var errs []error
+ for name, c := range sessions.Seq2() {
+ if err := c.Close(); err != nil &&
+ !errors.Is(err, io.EOF) &&
+ !errors.Is(err, context.Canceled) &&
+ err.Error() != "signal: killed" {
+ errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
+ }
+ }
+ broker.Shutdown()
+ return errors.Join(errs...)
+}
+
+// Initialize initializes MCP clients based on the provided configuration.
+func Initialize(ctx context.Context, permissions permission.Service, cfg *config.Config) {
+ var wg sync.WaitGroup
+ // Initialize states for all configured MCPs
+ for name, m := range cfg.MCP {
+ if m.Disabled {
+ updateState(name, StateDisabled, nil, nil, Counts{})
+ slog.Debug("skipping disabled mcp", "name", name)
+ continue
+ }
+
+ // Set initial starting state
+ updateState(name, StateStarting, nil, nil, Counts{})
+
+ wg.Add(1)
+ go func(name string, m config.MCPConfig) {
+ defer func() {
+ wg.Done()
+ if r := recover(); r != nil {
+ var err error
+ switch v := r.(type) {
+ case error:
+ err = v
+ case string:
+ err = fmt.Errorf("panic: %s", v)
+ default:
+ err = fmt.Errorf("panic: %v", v)
+ }
+ updateState(name, StateError, err, nil, Counts{})
+ slog.Error("panic in mcp client initialization", "error", err, "name", name)
+ }
+ }()
+
+ ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
+ defer cancel()
+
+ session, err := createSession(ctx, name, m, cfg.Resolver())
+ if err != nil {
+ return
+ }
+
+ tools, err := getTools(ctx, name, permissions, session, cfg.WorkingDir())
+ if err != nil {
+ slog.Error("error listing tools", "error", err)
+ updateState(name, StateError, err, nil, Counts{})
+ session.Close()
+ return
+ }
+
+ prompts, err := getPrompts(ctx, session)
+ if err != nil {
+ slog.Error("error listing prompts", "error", err)
+ updateState(name, StateError, err, nil, Counts{})
+ session.Close()
+ return
+ }
+
+ updateTools(name, tools)
+ updatePrompts(name, prompts)
+ sessions.Set(name, session)
+
+ updateState(name, StateConnected, nil, session, Counts{
+ Tools: len(tools),
+ Prompts: len(prompts),
+ })
+ }(name, m)
+ }
+ wg.Wait()
+}
+
+func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
+ sess, ok := sessions.Get(name)
+ if !ok {
+ return nil, fmt.Errorf("mcp '%s' not available", name)
+ }
+
+ cfg := config.Get()
+ m := cfg.MCP[name]
+ state, _ := states.Get(name)
+
+ timeout := mcpTimeout(m)
+ pingCtx, cancel := context.WithTimeout(ctx, timeout)
+ defer cancel()
+ err := sess.Ping(pingCtx, nil)
+ if err == nil {
+ return sess, nil
+ }
+ updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
+
+ sess, err = createSession(ctx, name, m, cfg.Resolver())
+ if err != nil {
+ return nil, err
+ }
+
+ updateState(name, StateConnected, nil, sess, state.Counts)
+ sessions.Set(name, sess)
+ return sess, nil
+}
+
+// updateState updates the state of an MCP client and publishes an event
+func updateState(name string, state State, err error, client *mcp.ClientSession, counts Counts) {
+ info := ClientInfo{
+ Name: name,
+ State: state,
+ Error: err,
+ Client: client,
+ Counts: counts,
+ }
+ switch state {
+ case StateConnected:
+ info.ConnectedAt = time.Now()
+ case StateError:
+ updateTools(name, nil)
+ sessions.Del(name)
+ }
+ states.Set(name, info)
+
+ // Publish state change event
+ broker.Publish(pubsub.UpdatedEvent, Event{
+ Type: EventStateChanged,
+ Name: name,
+ State: state,
+ Error: err,
+ Counts: counts,
+ })
+}
+
+func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
+ timeout := mcpTimeout(m)
+ mcpCtx, cancel := context.WithCancel(ctx)
+ cancelTimer := time.AfterFunc(timeout, cancel)
+
+ transport, err := createTransport(mcpCtx, m, resolver)
+ if err != nil {
+ updateState(name, StateError, err, nil, Counts{})
+ slog.Error("error creating mcp client", "error", err, "name", name)
+ cancel()
+ cancelTimer.Stop()
+ return nil, err
+ }
+
+ client := mcp.NewClient(
+ &mcp.Implementation{
+ Name: "crush",
+ Version: version.Version,
+ Title: "Crush",
+ },
+ &mcp.ClientOptions{
+ ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
+ broker.Publish(pubsub.UpdatedEvent, Event{
+ Type: EventToolsListChanged,
+ Name: name,
+ })
+ },
+ PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
+ broker.Publish(pubsub.UpdatedEvent, Event{
+ Type: EventPromptsListChanged,
+ Name: name,
+ })
+ },
+ KeepAlive: time.Minute * 10,
+ },
+ )
+
+ session, err := client.Connect(mcpCtx, transport, nil)
+ if err != nil {
+ err = maybeStdioErr(err, transport)
+ updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
+ slog.Error("error starting mcp client", "error", err, "name", name)
+ cancel()
+ cancelTimer.Stop()
+ return nil, err
+ }
+
+ cancelTimer.Stop()
+ slog.Info("Initialized mcp client", "name", name)
+ return session, nil
+}
+
+// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
+// to parse, and the cli will then close it, causing the EOF error.
+// so, if we got an EOF err, and the transport is STDIO, we try to exec it
+// again with a timeout and collect the output so we can add details to the
+// error.
+// this happens particularly when starting things with npx, e.g. if node can't
+// be found or some other error like that.
+func maybeStdioErr(err error, transport mcp.Transport) error {
+ if !errors.Is(err, io.EOF) {
+ return err
+ }
+ ct, ok := transport.(*mcp.CommandTransport)
+ if !ok {
+ return err
+ }
+ if err2 := stdioCheck(ct.Command); err2 != nil {
+ err = errors.Join(err, err2)
+ }
+ return err
+}
+
+func maybeTimeoutErr(err error, timeout time.Duration) error {
+ if errors.Is(err, context.Canceled) {
+ return fmt.Errorf("timed out after %s", timeout)
+ }
+ return err
+}
+
+func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
+ switch m.Type {
+ case config.MCPStdio:
+ command, err := resolver.ResolveValue(m.Command)
+ if err != nil {
+ return nil, fmt.Errorf("invalid mcp command: %w", err)
+ }
+ if strings.TrimSpace(command) == "" {
+ return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
+ }
+ cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
+ cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
+ return &mcp.CommandTransport{
+ Command: cmd,
+ }, nil
+ case config.MCPHttp:
+ if strings.TrimSpace(m.URL) == "" {
+ return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
+ }
+ client := &http.Client{
+ Transport: &headerRoundTripper{
+ headers: m.ResolvedHeaders(),
+ },
+ }
+ return &mcp.StreamableClientTransport{
+ Endpoint: m.URL,
+ HTTPClient: client,
+ }, nil
+ case config.MCPSSE:
+ if strings.TrimSpace(m.URL) == "" {
+ return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
+ }
+ client := &http.Client{
+ Transport: &headerRoundTripper{
+ headers: m.ResolvedHeaders(),
+ },
+ }
+ return &mcp.SSEClientTransport{
+ Endpoint: m.URL,
+ HTTPClient: client,
+ }, nil
+ default:
+ return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
+ }
+}
+
+type headerRoundTripper struct {
+ headers map[string]string
+}
+
+func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+ for k, v := range rt.headers {
+ req.Header.Set(k, v)
+ }
+ return http.DefaultTransport.RoundTrip(req)
+}
+
+func mcpTimeout(m config.MCPConfig) time.Duration {
+ return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
+}
+
+func stdioCheck(old *exec.Cmd) error {
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+ defer cancel()
+ cmd := exec.CommandContext(ctx, old.Path, old.Args...)
+ cmd.Env = old.Env
+ out, err := cmd.CombinedOutput()
+ if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
+ return nil
+ }
+ return fmt.Errorf("%w: %s", err, string(out))
+}
@@ -0,0 +1,83 @@
+package mcp
+
+import (
+ "context"
+ "iter"
+
+ "github.com/charmbracelet/crush/internal/csync"
+ "github.com/modelcontextprotocol/go-sdk/mcp"
+)
+
+type Prompt = mcp.Prompt
+
+var (
+ allPrompts = csync.NewMap[string, *Prompt]()
+ client2Prompts = csync.NewMap[string, []*Prompt]()
+)
+
+// GetPrompts returns all available MCP prompts.
+func GetPrompts() iter.Seq2[string, *Prompt] {
+ return allPrompts.Seq2()
+}
+
+// GetPrompt returns a specific MCP prompt by name.
+func GetPrompt(name string) (*Prompt, bool) {
+ return allPrompts.Get(name)
+}
+
+// GetPromptsByClient returns all prompts for a specific MCP client.
+func GetPromptsByClient(clientName string) ([]*Prompt, bool) {
+ return client2Prompts.Get(clientName)
+}
+
+// GetPromptMessages retrieves the content of an MCP prompt with the given arguments.
+func GetPromptMessages(ctx context.Context, clientName, promptName string, args map[string]string) ([]string, error) {
+ c, err := getOrRenewClient(ctx, clientName)
+ if err != nil {
+ return nil, err
+ }
+ result, err := c.GetPrompt(ctx, &mcp.GetPromptParams{
+ Name: promptName,
+ Arguments: args,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ var messages []string
+ for _, msg := range result.Messages {
+ if msg.Role != "user" {
+ continue
+ }
+ if textContent, ok := msg.Content.(*mcp.TextContent); ok {
+ messages = append(messages, textContent.Text)
+ }
+ }
+ return messages, nil
+}
+
+func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*Prompt, error) {
+ if c.InitializeResult().Capabilities.Prompts == nil {
+ return nil, nil
+ }
+ result, err := c.ListPrompts(ctx, &mcp.ListPromptsParams{})
+ if err != nil {
+ return nil, err
+ }
+ return result.Prompts, nil
+}
+
+// updatePrompts updates the global mcpPrompts and mcpClient2Prompts maps
+func updatePrompts(mcpName string, prompts []*Prompt) {
+ if len(prompts) == 0 {
+ client2Prompts.Del(mcpName)
+ } else {
+ client2Prompts.Set(mcpName, prompts)
+ }
+ for mcpName, prompts := range client2Prompts.Seq2() {
+ for _, p := range prompts {
+ key := mcpName + ":" + p.Name
+ allPrompts.Set(key, p)
+ }
+ }
+}
@@ -0,0 +1,169 @@
+package mcp
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "iter"
+ "strings"
+
+ "charm.land/fantasy"
+ "github.com/charmbracelet/crush/internal/agent/tools"
+ "github.com/charmbracelet/crush/internal/csync"
+ "github.com/charmbracelet/crush/internal/permission"
+ "github.com/modelcontextprotocol/go-sdk/mcp"
+)
+
+var (
+ allTools = csync.NewMap[string, *Tool]()
+ client2Tools = csync.NewMap[string, []*Tool]()
+)
+
+// GetMCPTools returns all available MCP tools.
+func GetMCPTools() iter.Seq[*Tool] {
+ return allTools.Seq()
+}
+
+type Tool struct {
+ mcpName string
+ tool *mcp.Tool
+ permissions permission.Service
+ workingDir string
+ providerOptions fantasy.ProviderOptions
+}
+
+func (m *Tool) SetProviderOptions(opts fantasy.ProviderOptions) {
+ m.providerOptions = opts
+}
+
+func (m *Tool) ProviderOptions() fantasy.ProviderOptions {
+ return m.providerOptions
+}
+
+func (m *Tool) Name() string {
+ return fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name)
+}
+
+func (m *Tool) MCP() string {
+ return m.mcpName
+}
+
+func (m *Tool) MCPToolName() string {
+ return m.tool.Name
+}
+
+func (m *Tool) Info() fantasy.ToolInfo {
+ parameters := make(map[string]any)
+ required := make([]string, 0)
+
+ if input, ok := m.tool.InputSchema.(map[string]any); ok {
+ if props, ok := input["properties"].(map[string]any); ok {
+ parameters = props
+ }
+ if req, ok := input["required"].([]any); ok {
+ // Convert []any -> []string when elements are strings
+ for _, v := range req {
+ if s, ok := v.(string); ok {
+ required = append(required, s)
+ }
+ }
+ } else if reqStr, ok := input["required"].([]string); ok {
+ // Handle case where it's already []string
+ required = reqStr
+ }
+ }
+
+ return fantasy.ToolInfo{
+ Name: fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name),
+ Description: m.tool.Description,
+ Parameters: parameters,
+ Required: required,
+ }
+}
+
+func (m *Tool) Run(ctx context.Context, params fantasy.ToolCall) (fantasy.ToolResponse, error) {
+ sessionID := tools.GetSessionFromContext(ctx)
+ if sessionID == "" {
+ return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
+ }
+ permissionDescription := fmt.Sprintf("execute %s with the following parameters:", m.Info().Name)
+ p := m.permissions.Request(
+ permission.CreatePermissionRequest{
+ SessionID: sessionID,
+ ToolCallID: params.ID,
+ Path: m.workingDir,
+ ToolName: m.Info().Name,
+ Action: "execute",
+ Description: permissionDescription,
+ Params: params.Input,
+ },
+ )
+ if !p {
+ return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
+ }
+
+ return runTool(ctx, m.mcpName, m.tool.Name, params.Input)
+}
+
+func runTool(ctx context.Context, name, toolName string, input string) (fantasy.ToolResponse, error) {
+ var args map[string]any
+ if err := json.Unmarshal([]byte(input), &args); err != nil {
+ return fantasy.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
+ }
+
+ c, err := getOrRenewClient(ctx, name)
+ if err != nil {
+ return fantasy.NewTextErrorResponse(err.Error()), nil
+ }
+ result, err := c.CallTool(ctx, &mcp.CallToolParams{
+ Name: toolName,
+ Arguments: args,
+ })
+ if err != nil {
+ return fantasy.NewTextErrorResponse(err.Error()), nil
+ }
+
+ output := make([]string, 0, len(result.Content))
+ for _, v := range result.Content {
+ if vv, ok := v.(*mcp.TextContent); ok {
+ output = append(output, vv.Text)
+ } else {
+ output = append(output, fmt.Sprintf("%v", v))
+ }
+ }
+ return fantasy.NewTextResponse(strings.Join(output, "\n")), nil
+}
+
+func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]*Tool, error) {
+ if c.InitializeResult().Capabilities.Tools == nil {
+ return nil, nil
+ }
+ result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
+ if err != nil {
+ return nil, err
+ }
+ mcpTools := make([]*Tool, 0, len(result.Tools))
+ for _, tool := range result.Tools {
+ mcpTools = append(mcpTools, &Tool{
+ mcpName: name,
+ tool: tool,
+ permissions: permissions,
+ workingDir: workingDir,
+ })
+ }
+ return mcpTools, nil
+}
+
+// updateTools updates the global mcpTools and mcpClient2Tools maps
+func updateTools(mcpName string, tools []*Tool) {
+ if len(tools) == 0 {
+ client2Tools.Del(mcpName)
+ } else {
+ client2Tools.Set(mcpName, tools)
+ }
+ for _, tools := range client2Tools.Seq2() {
+ for _, t := range tools {
+ allTools.Set(t.Name(), t)
+ }
+ }
+}
@@ -16,7 +16,7 @@ import (
"charm.land/fantasy"
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/crush/internal/agent"
- "github.com/charmbracelet/crush/internal/agent/tools"
+ "github.com/charmbracelet/crush/internal/agent/tools/mcp"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/db"
@@ -91,6 +91,11 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
// Initialize LSP clients in the background.
app.initLSPClients(ctx)
+ go func() {
+ slog.Info("Initializing MCP clients")
+ mcp.Initialize(ctx, app.Permissions, cfg)
+ }()
+
// cleanup database upon app shutdown
app.cleanupFuncs = append(app.cleanupFuncs, conn.Close)
@@ -260,7 +265,7 @@ func (app *App) setupEvents() {
setupSubscriber(ctx, app.serviceEventsWG, "permissions", app.Permissions.Subscribe, app.events)
setupSubscriber(ctx, app.serviceEventsWG, "permissions-notifications", app.Permissions.SubscribeNotifications, app.events)
setupSubscriber(ctx, app.serviceEventsWG, "history", app.History.Subscribe, app.events)
- setupSubscriber(ctx, app.serviceEventsWG, "mcp", tools.SubscribeMCPEvents, app.events)
+ setupSubscriber(ctx, app.serviceEventsWG, "mcp", mcp.SubscribeEvents, app.events)
setupSubscriber(ctx, app.serviceEventsWG, "lsp", SubscribeLSPEvents, app.events)
cleanupFunc := func() error {
cancel()
@@ -324,7 +329,7 @@ func (app *App) InitCoderAgent(ctx context.Context) error {
}
// Add MCP client cleanup to shutdown process
- app.cleanupFuncs = append(app.cleanupFuncs, tools.CloseMCPClients)
+ app.cleanupFuncs = append(app.cleanupFuncs, mcp.Close)
return nil
}
@@ -1,8 +1,7 @@
package commands
import (
- "fmt"
- "strings"
+ "cmp"
"github.com/charmbracelet/bubbles/v2/help"
"github.com/charmbracelet/bubbles/v2/key"
@@ -20,9 +19,10 @@ const (
// ShowArgumentsDialogMsg is a message that is sent to show the arguments dialog.
type ShowArgumentsDialogMsg struct {
- CommandID string
- Content string
- ArgNames []string
+ CommandID string
+ Description string
+ ArgNames []string
+ OnSubmit func(args map[string]string) tea.Cmd
}
// CloseArgumentsDialogMsg is a message that is sent when the arguments dialog is closed.
@@ -39,26 +39,39 @@ type CommandArgumentsDialog interface {
}
type commandArgumentsDialogCmp struct {
- width int
- wWidth int // Width of the terminal window
- wHeight int // Height of the terminal window
-
- inputs []textinput.Model
- focusIndex int
- keys ArgumentsDialogKeyMap
- commandID string
- content string
- argNames []string
- help help.Model
+ wWidth, wHeight int
+ width, height int
+
+ inputs []textinput.Model
+ focused int
+ keys ArgumentsDialogKeyMap
+ arguments []Argument
+ help help.Model
+
+ id string
+ title string
+ name string
+ description string
+
+ onSubmit func(args map[string]string) tea.Cmd
+}
+
+type Argument struct {
+ Name, Title, Description string
+ Required bool
}
-func NewCommandArgumentsDialog(commandID, content string, argNames []string) CommandArgumentsDialog {
+func NewCommandArgumentsDialog(
+ id, title, name, description string,
+ arguments []Argument,
+ onSubmit func(args map[string]string) tea.Cmd,
+) CommandArgumentsDialog {
t := styles.CurrentTheme()
- inputs := make([]textinput.Model, len(argNames))
+ inputs := make([]textinput.Model, len(arguments))
- for i, name := range argNames {
+ for i, arg := range arguments {
ti := textinput.New()
- ti.Placeholder = fmt.Sprintf("Enter value for %s...", name)
+ ti.Placeholder = cmp.Or(arg.Description, "Enter value for "+arg.Title)
ti.SetWidth(40)
ti.SetVirtualCursor(false)
ti.Prompt = ""
@@ -75,14 +88,16 @@ func NewCommandArgumentsDialog(commandID, content string, argNames []string) Com
}
return &commandArgumentsDialogCmp{
- inputs: inputs,
- keys: DefaultArgumentsDialogKeyMap(),
- commandID: commandID,
- content: content,
- argNames: argNames,
- focusIndex: 0,
- width: 60,
- help: help.New(),
+ inputs: inputs,
+ keys: DefaultArgumentsDialogKeyMap(),
+ id: id,
+ name: name,
+ title: title,
+ description: description,
+ arguments: arguments,
+ width: 60,
+ help: help.New(),
+ onSubmit: onSubmit,
}
}
@@ -97,47 +112,51 @@ func (c *commandArgumentsDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
case tea.WindowSizeMsg:
c.wWidth = msg.Width
c.wHeight = msg.Height
+ c.width = min(90, c.wWidth)
+ c.height = min(15, c.wHeight)
+ for i := range c.inputs {
+ c.inputs[i].SetWidth(c.width - (paddingHorizontal * 2))
+ }
case tea.KeyPressMsg:
switch {
+ case key.Matches(msg, c.keys.Close):
+ return c, util.CmdHandler(dialogs.CloseDialogMsg{})
case key.Matches(msg, c.keys.Confirm):
- if c.focusIndex == len(c.inputs)-1 {
- content := c.content
- for i, name := range c.argNames {
+ if c.focused == len(c.inputs)-1 {
+ args := make(map[string]string)
+ for i, arg := range c.arguments {
value := c.inputs[i].Value()
- placeholder := "$" + name
- content = strings.ReplaceAll(content, placeholder, value)
+ args[arg.Name] = value
}
return c, tea.Sequence(
util.CmdHandler(dialogs.CloseDialogMsg{}),
- util.CmdHandler(CommandRunCustomMsg{
- Content: content,
- }),
+ c.onSubmit(args),
)
}
// Otherwise, move to the next input
- c.inputs[c.focusIndex].Blur()
- c.focusIndex++
- c.inputs[c.focusIndex].Focus()
+ c.inputs[c.focused].Blur()
+ c.focused++
+ c.inputs[c.focused].Focus()
case key.Matches(msg, c.keys.Next):
// Move to the next input
- c.inputs[c.focusIndex].Blur()
- c.focusIndex = (c.focusIndex + 1) % len(c.inputs)
- c.inputs[c.focusIndex].Focus()
+ c.inputs[c.focused].Blur()
+ c.focused = (c.focused + 1) % len(c.inputs)
+ c.inputs[c.focused].Focus()
case key.Matches(msg, c.keys.Previous):
// Move to the previous input
- c.inputs[c.focusIndex].Blur()
- c.focusIndex = (c.focusIndex - 1 + len(c.inputs)) % len(c.inputs)
- c.inputs[c.focusIndex].Focus()
+ c.inputs[c.focused].Blur()
+ c.focused = (c.focused - 1 + len(c.inputs)) % len(c.inputs)
+ c.inputs[c.focused].Focus()
case key.Matches(msg, c.keys.Close):
return c, util.CmdHandler(dialogs.CloseDialogMsg{})
default:
var cmd tea.Cmd
- c.inputs[c.focusIndex], cmd = c.inputs[c.focusIndex].Update(msg)
+ c.inputs[c.focused], cmd = c.inputs[c.focused].Update(msg)
return c, cmd
}
case tea.PasteMsg:
var cmd tea.Cmd
- c.inputs[c.focusIndex], cmd = c.inputs[c.focusIndex].Update(msg)
+ c.inputs[c.focused], cmd = c.inputs[c.focused].Update(msg)
return c, cmd
}
return c, nil
@@ -152,26 +171,28 @@ func (c *commandArgumentsDialogCmp) View() string {
Foreground(t.Primary).
Bold(true).
Padding(0, 1).
- Render("Command Arguments")
+ Render(cmp.Or(c.title, c.name))
- explanation := t.S().Text.
+ promptName := t.S().Text.
Padding(0, 1).
- Render("This command requires arguments.")
+ Render(c.description)
- // Create input fields for each argument
inputFields := make([]string, len(c.inputs))
for i, input := range c.inputs {
- // Highlight the label of the focused input
- labelStyle := baseStyle.
- Padding(1, 1, 0, 1)
+ labelStyle := baseStyle.Padding(1, 1, 0, 1)
- if i == c.focusIndex {
+ if i == c.focused {
labelStyle = labelStyle.Foreground(t.FgBase).Bold(true)
} else {
labelStyle = labelStyle.Foreground(t.FgMuted)
}
- label := labelStyle.Render(c.argNames[i] + ":")
+ arg := c.arguments[i]
+ argName := cmp.Or(arg.Title, arg.Name)
+ if arg.Required {
+ argName += "*"
+ }
+ label := labelStyle.Render(argName + ":")
field := t.S().Text.
Padding(0, 1).
@@ -180,18 +201,14 @@ func (c *commandArgumentsDialogCmp) View() string {
inputFields[i] = lipgloss.JoinVertical(lipgloss.Left, label, field)
}
- // Join all elements vertically
- elements := []string{title, explanation}
+ elements := []string{title, promptName}
elements = append(elements, inputFields...)
c.help.ShowAll = false
helpText := baseStyle.Padding(0, 1).Render(c.help.View(c.keys))
elements = append(elements, "", helpText)
- content := lipgloss.JoinVertical(
- lipgloss.Left,
- elements...,
- )
+ content := lipgloss.JoinVertical(lipgloss.Left, elements...)
return baseStyle.Padding(1, 1, 0, 1).
Border(lipgloss.RoundedBorder()).
@@ -201,26 +218,33 @@ func (c *commandArgumentsDialogCmp) View() string {
}
func (c *commandArgumentsDialogCmp) Cursor() *tea.Cursor {
- cursor := c.inputs[c.focusIndex].Cursor()
+ if len(c.inputs) == 0 {
+ return nil
+ }
+ cursor := c.inputs[c.focused].Cursor()
if cursor != nil {
cursor = c.moveCursor(cursor)
}
return cursor
}
+const (
+ headerHeight = 3
+ itemHeight = 3
+ paddingHorizontal = 3
+)
+
func (c *commandArgumentsDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
row, col := c.Position()
- offset := row + 3 + (1+c.focusIndex)*3
+ offset := row + headerHeight + (1+c.focused)*itemHeight
cursor.Y += offset
- cursor.X = cursor.X + col + 3
+ cursor.X = cursor.X + col + paddingHorizontal
return cursor
}
func (c *commandArgumentsDialogCmp) Position() (int, int) {
- row := c.wHeight / 2
- row -= c.wHeight / 2
- col := c.wWidth / 2
- col -= c.width / 2
+ row := (c.wHeight / 2) - (c.height / 2)
+ col := (c.wWidth / 2) - (c.width / 2)
return row, col
}
@@ -2,6 +2,8 @@ package commands
import (
"os"
+ "slices"
+ "strings"
"github.com/charmbracelet/bubbles/v2/help"
"github.com/charmbracelet/bubbles/v2/key"
@@ -10,7 +12,10 @@ import (
"github.com/charmbracelet/lipgloss/v2"
"github.com/charmbracelet/crush/internal/agent"
+ "github.com/charmbracelet/crush/internal/agent/tools/mcp"
"github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/csync"
+ "github.com/charmbracelet/crush/internal/pubsub"
"github.com/charmbracelet/crush/internal/tui/components/chat"
"github.com/charmbracelet/crush/internal/tui/components/core"
"github.com/charmbracelet/crush/internal/tui/components/dialogs"
@@ -25,9 +30,14 @@ const (
defaultWidth int = 70
)
+type commandType uint
+
+func (c commandType) String() string { return []string{"System", "User", "MCP"}[c] }
+
const (
- SystemCommands int = iota
+ SystemCommands commandType = iota
UserCommands
+ MCPPrompts
)
type listModel = list.FilterableList[list.CompletionItem[Command]]
@@ -54,9 +64,10 @@ type commandDialogCmp struct {
commandList listModel
keyMap CommandsDialogKeyMap
help help.Model
- commandType int // SystemCommands or UserCommands
- userCommands []Command // User-defined commands
- sessionID string // Current session ID
+ selected commandType // Selected SystemCommands, UserCommands, or MCPPrompts
+ userCommands []Command // User-defined commands
+ mcpPrompts *csync.Slice[Command] // MCP prompts
+ sessionID string // Current session ID
}
type (
@@ -102,8 +113,9 @@ func NewCommandDialog(sessionID string) CommandsDialog {
width: defaultWidth,
keyMap: DefaultCommandsDialogKeyMap(),
help: help,
- commandType: SystemCommands,
+ selected: SystemCommands,
sessionID: sessionID,
+ mcpPrompts: csync.NewSlice[Command](),
}
}
@@ -113,7 +125,8 @@ func (c *commandDialogCmp) Init() tea.Cmd {
return util.ReportError(err)
}
c.userCommands = commands
- return c.SetCommandType(c.commandType)
+ c.mcpPrompts.SetSlice(loadMCPPrompts())
+ return c.setCommandType(c.selected)
}
func (c *commandDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
@@ -122,9 +135,19 @@ func (c *commandDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
c.wWidth = msg.Width
c.wHeight = msg.Height
return c, tea.Batch(
- c.SetCommandType(c.commandType),
+ c.setCommandType(c.selected),
c.commandList.SetSize(c.listWidth(), c.listHeight()),
)
+ case pubsub.Event[mcp.Event]:
+ // Reload MCP prompts when MCP state changes
+ if msg.Type == pubsub.UpdatedEvent {
+ c.mcpPrompts.SetSlice(loadMCPPrompts())
+ // If we're currently viewing MCP prompts, refresh the list
+ if c.selected == MCPPrompts {
+ return c, c.setCommandType(MCPPrompts)
+ }
+ return c, nil
+ }
case tea.KeyPressMsg:
switch {
case key.Matches(msg, c.keyMap.Select):
@@ -138,15 +161,10 @@ func (c *commandDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
command.Handler(command),
)
case key.Matches(msg, c.keyMap.Tab):
- if len(c.userCommands) == 0 {
+ if len(c.userCommands) == 0 && c.mcpPrompts.Len() == 0 {
return c, nil
}
- // Toggle command type between System and User commands
- if c.commandType == SystemCommands {
- return c, c.SetCommandType(UserCommands)
- } else {
- return c, c.SetCommandType(SystemCommands)
- }
+ return c, c.setCommandType(c.next())
case key.Matches(msg, c.keyMap.Close):
return c, util.CmdHandler(dialogs.CloseDialogMsg{})
default:
@@ -158,13 +176,35 @@ func (c *commandDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
return c, nil
}
+func (c *commandDialogCmp) next() commandType {
+ switch c.selected {
+ case SystemCommands:
+ if len(c.userCommands) > 0 {
+ return UserCommands
+ }
+ if c.mcpPrompts.Len() > 0 {
+ return MCPPrompts
+ }
+ fallthrough
+ case UserCommands:
+ if c.mcpPrompts.Len() > 0 {
+ return MCPPrompts
+ }
+ fallthrough
+ case MCPPrompts:
+ return SystemCommands
+ default:
+ return SystemCommands
+ }
+}
+
func (c *commandDialogCmp) View() string {
t := styles.CurrentTheme()
listView := c.commandList
radio := c.commandTypeRadio()
header := t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Commands", c.width-lipgloss.Width(radio)-5) + " " + radio)
- if len(c.userCommands) == 0 {
+ if len(c.userCommands) == 0 && c.mcpPrompts.Len() == 0 {
header = t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Commands", c.width-4))
}
content := lipgloss.JoinVertical(
@@ -190,27 +230,41 @@ func (c *commandDialogCmp) Cursor() *tea.Cursor {
func (c *commandDialogCmp) commandTypeRadio() string {
t := styles.CurrentTheme()
- choices := []string{"System", "User"}
- iconSelected := "◉"
- iconUnselected := "○"
- if c.commandType == SystemCommands {
- return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
+
+ fn := func(i commandType) string {
+ if i == c.selected {
+ return "◉ " + i.String()
+ }
+ return "○ " + i.String()
+ }
+
+ parts := []string{
+ fn(SystemCommands),
+ }
+ if len(c.userCommands) > 0 {
+ parts = append(parts, fn(UserCommands))
+ }
+ if c.mcpPrompts.Len() > 0 {
+ parts = append(parts, fn(MCPPrompts))
}
- return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
+ return t.S().Base.Foreground(t.FgHalfMuted).Render(strings.Join(parts, " "))
}
func (c *commandDialogCmp) listWidth() int {
return defaultWidth - 2 // 4 for padding
}
-func (c *commandDialogCmp) SetCommandType(commandType int) tea.Cmd {
- c.commandType = commandType
+func (c *commandDialogCmp) setCommandType(commandType commandType) tea.Cmd {
+ c.selected = commandType
var commands []Command
- if c.commandType == SystemCommands {
+ switch c.selected {
+ case SystemCommands:
commands = c.defaultCommands()
- } else {
+ case UserCommands:
commands = c.userCommands
+ case MCPPrompts:
+ commands = slices.Collect(c.mcpPrompts.Seq())
}
commandItems := []list.CompletionItem[Command]{}
@@ -1,22 +1,26 @@
package commands
import (
+ "context"
"fmt"
"io/fs"
+ "log/slog"
"os"
"path/filepath"
"regexp"
"strings"
tea "github.com/charmbracelet/bubbletea/v2"
+ "github.com/charmbracelet/crush/internal/agent/tools/mcp"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/home"
+ "github.com/charmbracelet/crush/internal/tui/components/chat"
"github.com/charmbracelet/crush/internal/tui/util"
)
const (
- UserCommandPrefix = "user:"
- ProjectCommandPrefix = "project:"
+ userCommandPrefix = "user:"
+ projectCommandPrefix = "project:"
)
var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)
@@ -50,7 +54,7 @@ func buildCommandSources(cfg *config.Config) []commandSource {
if dir := getXDGCommandsDir(); dir != "" {
sources = append(sources, commandSource{
path: dir,
- prefix: UserCommandPrefix,
+ prefix: userCommandPrefix,
})
}
@@ -58,14 +62,14 @@ func buildCommandSources(cfg *config.Config) []commandSource {
if home := home.Dir(); home != "" {
sources = append(sources, commandSource{
path: filepath.Join(home, ".crush", "commands"),
- prefix: UserCommandPrefix,
+ prefix: userCommandPrefix,
})
}
// Project directory
sources = append(sources, commandSource{
path: filepath.Join(cfg.Options.DataDirectory, "commands"),
- prefix: ProjectCommandPrefix,
+ prefix: projectCommandPrefix,
})
return sources
@@ -127,12 +131,13 @@ func (l *commandLoader) loadCommand(path, baseDir, prefix string) (Command, erro
}
id := buildCommandID(path, baseDir, prefix)
+ desc := fmt.Sprintf("Custom command from %s", filepath.Base(path))
return Command{
ID: id,
Title: id,
- Description: fmt.Sprintf("Custom command from %s", filepath.Base(path)),
- Handler: createCommandHandler(id, string(content)),
+ Description: desc,
+ Handler: createCommandHandler(id, desc, string(content)),
}, nil
}
@@ -149,21 +154,35 @@ func buildCommandID(path, baseDir, prefix string) string {
return prefix + strings.Join(parts, ":")
}
-func createCommandHandler(id string, content string) func(Command) tea.Cmd {
+func createCommandHandler(id, desc, content string) func(Command) tea.Cmd {
return func(cmd Command) tea.Cmd {
args := extractArgNames(content)
- if len(args) > 0 {
- return util.CmdHandler(ShowArgumentsDialogMsg{
- CommandID: id,
- Content: content,
- ArgNames: args,
+ if len(args) == 0 {
+ return util.CmdHandler(CommandRunCustomMsg{
+ Content: content,
})
}
+ return util.CmdHandler(ShowArgumentsDialogMsg{
+ CommandID: id,
+ Description: desc,
+ ArgNames: args,
+ OnSubmit: func(args map[string]string) tea.Cmd {
+ return execUserPrompt(content, args)
+ },
+ })
+ }
+}
- return util.CmdHandler(CommandRunCustomMsg{
+func execUserPrompt(content string, args map[string]string) tea.Cmd {
+ return func() tea.Msg {
+ for name, value := range args {
+ placeholder := "$" + name
+ content = strings.ReplaceAll(content, placeholder, value)
+ }
+ return CommandRunCustomMsg{
Content: content,
- })
+ }
}
}
@@ -201,3 +220,55 @@ func isMarkdownFile(name string) bool {
type CommandRunCustomMsg struct {
Content string
}
+
+func loadMCPPrompts() []Command {
+ var commands []Command
+ for key, prompt := range mcp.GetPrompts() {
+ clientName, promptName, ok := strings.Cut(key, ":")
+ if !ok {
+ slog.Warn("prompt not found", "key", key)
+ continue
+ }
+ commands = append(commands, Command{
+ ID: key,
+ Title: clientName + ":" + promptName,
+ Description: prompt.Description,
+ Handler: createMCPPromptHandler(clientName, promptName, prompt),
+ })
+ }
+
+ return commands
+}
+
+func createMCPPromptHandler(clientName, promptName string, prompt *mcp.Prompt) func(Command) tea.Cmd {
+ return func(cmd Command) tea.Cmd {
+ if len(prompt.Arguments) == 0 {
+ return execMCPPrompt(clientName, promptName, nil)
+ }
+ return util.CmdHandler(ShowMCPPromptArgumentsDialogMsg{
+ Prompt: prompt,
+ OnSubmit: func(args map[string]string) tea.Cmd {
+ return execMCPPrompt(clientName, promptName, args)
+ },
+ })
+ }
+}
+
+func execMCPPrompt(clientName, promptName string, args map[string]string) tea.Cmd {
+ return func() tea.Msg {
+ ctx := context.Background()
+ result, err := mcp.GetPromptMessages(ctx, clientName, promptName, args)
+ if err != nil {
+ return util.ReportError(err)
+ }
+
+ return chat.SendMsg{
+ Text: strings.Join(result, " "),
+ }
+ }
+}
+
+type ShowMCPPromptArgumentsDialogMsg struct {
+ Prompt *mcp.Prompt
+ OnSubmit func(arg map[string]string) tea.Cmd
+}
@@ -2,10 +2,11 @@ package mcp
import (
"fmt"
+ "strings"
"github.com/charmbracelet/lipgloss/v2"
- "github.com/charmbracelet/crush/internal/agent/tools"
+ "github.com/charmbracelet/crush/internal/agent/tools/mcp"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/tui/components/core"
"github.com/charmbracelet/crush/internal/tui/styles"
@@ -40,7 +41,7 @@ func RenderMCPList(opts RenderOptions) []string {
}
// Get MCP states
- mcpStates := tools.GetMCPStates()
+ mcpStates := mcp.GetStates()
// Determine how many items to show
maxItems := len(mcps)
@@ -56,21 +57,24 @@ func RenderMCPList(opts RenderOptions) []string {
// Determine icon and color based on state
icon := t.ItemOfflineIcon
description := ""
- extraContent := ""
+ extraContent := []string{}
if state, exists := mcpStates[l.Name]; exists {
switch state.State {
- case tools.MCPStateDisabled:
+ case mcp.StateDisabled:
description = t.S().Subtle.Render("disabled")
- case tools.MCPStateStarting:
+ case mcp.StateStarting:
icon = t.ItemBusyIcon
description = t.S().Subtle.Render("starting...")
- case tools.MCPStateConnected:
+ case mcp.StateConnected:
icon = t.ItemOnlineIcon
- if state.ToolCount > 0 {
- extraContent = t.S().Subtle.Render(fmt.Sprintf("%d tools", state.ToolCount))
+ if count := state.Counts.Tools; count > 0 {
+ extraContent = append(extraContent, t.S().Subtle.Render(fmt.Sprintf("%d tools", count)))
}
- case tools.MCPStateError:
+ if count := state.Counts.Prompts; count > 0 {
+ extraContent = append(extraContent, t.S().Subtle.Render(fmt.Sprintf("%d prompts", count)))
+ }
+ case mcp.StateError:
icon = t.ItemErrorIcon
if state.Error != nil {
description = t.S().Subtle.Render(fmt.Sprintf("error: %s", state.Error.Error()))
@@ -88,7 +92,7 @@ func RenderMCPList(opts RenderOptions) []string {
Icon: icon.String(),
Title: l.Name,
Description: description,
- ExtraContent: extraContent,
+ ExtraContent: strings.Join(extraContent, " "),
},
opts.MaxWidth,
),
@@ -33,6 +33,8 @@ import (
"github.com/charmbracelet/crush/internal/tui/styles"
"github.com/charmbracelet/crush/internal/tui/util"
"github.com/charmbracelet/lipgloss/v2"
+ "golang.org/x/text/cases"
+ "golang.org/x/text/language"
)
var lastMouseEvent time.Time
@@ -156,15 +158,44 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
a.dialog = u.(dialogs.DialogCmp)
return a, tea.Batch(completionCmd, dialogCmd)
case commands.ShowArgumentsDialogMsg:
+ var args []commands.Argument
+ for _, arg := range msg.ArgNames {
+ args = append(args, commands.Argument{
+ Name: arg,
+ Title: cases.Title(language.English).String(arg),
+ Required: true,
+ })
+ }
return a, util.CmdHandler(
dialogs.OpenDialogMsg{
Model: commands.NewCommandArgumentsDialog(
msg.CommandID,
- msg.Content,
- msg.ArgNames,
+ msg.CommandID,
+ msg.CommandID,
+ msg.Description,
+ args,
+ msg.OnSubmit,
),
},
)
+ case commands.ShowMCPPromptArgumentsDialogMsg:
+ args := make([]commands.Argument, 0, len(msg.Prompt.Arguments))
+ for _, arg := range msg.Prompt.Arguments {
+ args = append(args, commands.Argument(*arg))
+ }
+ dialog := commands.NewCommandArgumentsDialog(
+ msg.Prompt.Name,
+ msg.Prompt.Title,
+ msg.Prompt.Name,
+ msg.Prompt.Description,
+ args,
+ msg.OnSubmit,
+ )
+ return a, util.CmdHandler(
+ dialogs.OpenDialogMsg{
+ Model: dialog,
+ },
+ )
// Page change messages
case page.PageChangeMsg:
return a, a.moveToPage(msg.ID)