@@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"log/slog"
+ "maps"
"slices"
"strings"
"sync"
@@ -110,9 +111,10 @@ func runTool(ctx context.Context, name, toolName string, input string) (tools.To
if err := json.Unmarshal([]byte(input), &args); err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
}
- c, ok := mcpClients.Get(name)
- if !ok {
- return tools.NewTextErrorResponse("mcp '" + name + "' not available"), nil
+
+ c, err := getOrRenewClient(ctx, name)
+ if err != nil {
+ return tools.NewTextErrorResponse(err.Error()), nil
}
result, err := c.CallTool(ctx, mcp.CallToolRequest{
Params: mcp.CallToolParams{
@@ -135,6 +137,33 @@ func runTool(ctx context.Context, name, toolName string, input string) (tools.To
return tools.NewTextResponse(strings.Join(output, "\n")), nil
}
+func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
+ c, ok := mcpClients.Get(name)
+ if !ok {
+ return nil, fmt.Errorf("mcp '%s' not available", name)
+ }
+
+ m := config.Get().MCP[name]
+ state, _ := mcpStates.Get(name)
+
+ pingCtx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
+ defer cancel()
+ err := c.Ping(pingCtx)
+ if err == nil {
+ return c, nil
+ }
+ updateMCPState(name, MCPStateError, err, nil, state.ToolCount)
+
+ c, err = createAndInitializeClient(ctx, name, m)
+ if err != nil {
+ return nil, err
+ }
+
+ updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
+ mcpClients.Set(name, c)
+ return c, nil
+}
+
func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
sessionID, messageID := tools.GetContextValues(ctx)
if sessionID == "" || messageID == "" {
@@ -187,11 +216,7 @@ func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
// GetMCPStates returns the current state of all MCP clients
func GetMCPStates() map[string]MCPClientInfo {
- states := make(map[string]MCPClientInfo)
- for name, info := range mcpStates.Seq2() {
- states[name] = info
- }
- return states
+ return maps.Collect(mcpStates.Seq2())
}
// GetMCPState returns the state of a specific MCP client
@@ -275,32 +300,12 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
}
}()
- timeout := time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
- ctx, cancel := context.WithTimeout(ctx, timeout)
+ ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
defer cancel()
- c, err := createMcpClient(m)
+ c, err := createAndInitializeClient(ctx, name, m)
if err != nil {
- updateMCPState(name, MCPStateError, err, nil, 0)
- slog.Error("error creating mcp client", "error", err, "name", name)
return
}
- // Only call Start() for non-stdio clients, as stdio clients auto-start
- if m.Type != config.MCPStdio {
- if err := c.Start(ctx); err != nil {
- updateMCPState(name, MCPStateError, err, nil, 0)
- slog.Error("error starting mcp client", "error", err, "name", name)
- _ = c.Close()
- return
- }
- }
- if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
- updateMCPState(name, MCPStateError, err, nil, 0)
- slog.Error("error initializing mcp client", "error", err, "name", name)
- _ = c.Close()
- return
- }
-
- slog.Info("Initialized mcp client", "name", name)
mcpClients.Set(name, c)
tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
@@ -312,6 +317,33 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
return slices.Collect(result.Seq())
}
+func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig) (*client.Client, error) {
+ c, err := createMcpClient(m)
+ if err != nil {
+ updateMCPState(name, MCPStateError, err, nil, 0)
+ slog.Error("error creating mcp client", "error", err, "name", name)
+ return nil, err
+ }
+ // Only call Start() for non-stdio clients, as stdio clients auto-start
+ if m.Type != config.MCPStdio {
+ if err := c.Start(ctx); err != nil {
+ updateMCPState(name, MCPStateError, err, nil, 0)
+ slog.Error("error starting mcp client", "error", err, "name", name)
+ _ = c.Close()
+ return nil, err
+ }
+ }
+ if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
+ updateMCPState(name, MCPStateError, err, nil, 0)
+ slog.Error("error initializing mcp client", "error", err, "name", name)
+ _ = c.Close()
+ return nil, err
+ }
+
+ slog.Info("Initialized mcp client", "name", name)
+ return c, nil
+}
+
func createMcpClient(m config.MCPConfig) (*client.Client, error) {
switch m.Type {
case config.MCPStdio:
@@ -343,3 +375,7 @@ type mcpLogger struct{}
func (l mcpLogger) Errorf(format string, v ...any) { slog.Error(fmt.Sprintf(format, v...)) }
func (l mcpLogger) Infof(format string, v ...any) { slog.Info(fmt.Sprintf(format, v...)) }
+
+func mcpTimeout(m config.MCPConfig) time.Duration {
+ return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
+}