feat(mcp): ping and recreate mcp client if needed (#772)

Carlos Alexandro Becker created

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

Change summary

internal/llm/agent/mcp-tools.go | 96 ++++++++++++++++++++++++----------
1 file changed, 66 insertions(+), 30 deletions(-)

Detailed changes

internal/llm/agent/mcp-tools.go 🔗

@@ -6,6 +6,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"log/slog"
+	"maps"
 	"slices"
 	"strings"
 	"sync"
@@ -110,9 +111,10 @@ func runTool(ctx context.Context, name, toolName string, input string) (tools.To
 	if err := json.Unmarshal([]byte(input), &args); err != nil {
 		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 	}
-	c, ok := mcpClients.Get(name)
-	if !ok {
-		return tools.NewTextErrorResponse("mcp '" + name + "' not available"), nil
+
+	c, err := getOrRenewClient(ctx, name)
+	if err != nil {
+		return tools.NewTextErrorResponse(err.Error()), nil
 	}
 	result, err := c.CallTool(ctx, mcp.CallToolRequest{
 		Params: mcp.CallToolParams{
@@ -135,6 +137,33 @@ func runTool(ctx context.Context, name, toolName string, input string) (tools.To
 	return tools.NewTextResponse(strings.Join(output, "\n")), nil
 }
 
+// getOrRenewClient returns the client registered for the named MCP. It pings the
+// client first and, if the ping fails, replaces it with a newly initialized one.
+func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
+	c, ok := mcpClients.Get(name)
+	if !ok {
+		return nil, fmt.Errorf("mcp '%s' not available", name)
+	}
+	m := config.Get().MCP[name]
+	state, _ := mcpStates.Get(name)
+
+	pingCtx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
+	defer cancel()
+	err := c.Ping(pingCtx)
+	if err == nil {
+		return c, nil
+	}
+	updateMCPState(name, MCPStateError, err, nil, state.ToolCount)
+	fresh, err := createAndInitializeClient(ctx, name, m)
+	if err != nil {
+		return nil, err
+	}
+	updateMCPState(name, MCPStateConnected, nil, fresh, state.ToolCount)
+	mcpClients.Set(name, fresh)
+	_ = c.Close() // best effort: the client that failed the ping has been replaced
+	return fresh, nil
+}
+
 func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
 	sessionID, messageID := tools.GetContextValues(ctx)
 	if sessionID == "" || messageID == "" {
@@ -187,11 +216,7 @@ func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
 
 // GetMCPStates returns the current state of all MCP clients
 func GetMCPStates() map[string]MCPClientInfo {
-	states := make(map[string]MCPClientInfo)
-	for name, info := range mcpStates.Seq2() {
-		states[name] = info
-	}
-	return states
+	return maps.Collect(mcpStates.Seq2())
 }
 
 // GetMCPState returns the state of a specific MCP client
@@ -275,32 +300,12 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
 				}
 			}()
 
-			timeout := time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
-			ctx, cancel := context.WithTimeout(ctx, timeout)
+			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
 			defer cancel()
-			c, err := createMcpClient(m)
+			c, err := createAndInitializeClient(ctx, name, m)
 			if err != nil {
-				updateMCPState(name, MCPStateError, err, nil, 0)
-				slog.Error("error creating mcp client", "error", err, "name", name)
 				return
 			}
-			// Only call Start() for non-stdio clients, as stdio clients auto-start
-			if m.Type != config.MCPStdio {
-				if err := c.Start(ctx); err != nil {
-					updateMCPState(name, MCPStateError, err, nil, 0)
-					slog.Error("error starting mcp client", "error", err, "name", name)
-					_ = c.Close()
-					return
-				}
-			}
-			if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
-				updateMCPState(name, MCPStateError, err, nil, 0)
-				slog.Error("error initializing mcp client", "error", err, "name", name)
-				_ = c.Close()
-				return
-			}
-
-			slog.Info("Initialized mcp client", "name", name)
 			mcpClients.Set(name, c)
 
 			tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
@@ -312,6 +317,33 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
 	return slices.Collect(result.Seq())
 }
 
+// createAndInitializeClient creates the MCP client for m, starts it (non-stdio only) and initializes it.
+func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig) (*client.Client, error) {
+	fail := func(started *client.Client, msg string, err error) (*client.Client, error) {
+		updateMCPState(name, MCPStateError, err, nil, 0)
+		slog.Error(msg, "error", err, "name", name)
+		if started != nil {
+			_ = started.Close()
+		}
+		return nil, err
+	}
+	c, err := createMcpClient(m)
+	if err != nil {
+		return fail(nil, "error creating mcp client", err)
+	}
+	// Only call Start() for non-stdio clients, as stdio clients auto-start
+	if m.Type != config.MCPStdio {
+		if err := c.Start(ctx); err != nil {
+			return fail(c, "error starting mcp client", err)
+		}
+	}
+	if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
+		return fail(c, "error initializing mcp client", err)
+	}
+	slog.Info("Initialized mcp client", "name", name)
+	return c, nil
+}
+
 func createMcpClient(m config.MCPConfig) (*client.Client, error) {
 	switch m.Type {
 	case config.MCPStdio:
@@ -343,3 +375,7 @@ type mcpLogger struct{}
 
 func (l mcpLogger) Errorf(format string, v ...any) { slog.Error(fmt.Sprintf(format, v...)) }
 func (l mcpLogger) Infof(format string, v ...any)  { slog.Info(fmt.Sprintf(format, v...)) }
+
+func mcpTimeout(m config.MCPConfig) time.Duration {
+	return time.Second * time.Duration(cmp.Or(m.Timeout, 15)) // 15 seconds when m.Timeout is zero
+}