diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index beb35085a5206ba55d81d927c3fb84d7120bdff1..a1a2134da517291423d056ae7674f28b200ac7a0 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "log/slog" + "maps" "slices" "strings" "sync" @@ -110,9 +111,10 @@ func runTool(ctx context.Context, name, toolName string, input string) (tools.To if err := json.Unmarshal([]byte(input), &args); err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil } - c, ok := mcpClients.Get(name) - if !ok { - return tools.NewTextErrorResponse("mcp '" + name + "' not available"), nil + + c, err := getOrRenewClient(ctx, name) + if err != nil { + return tools.NewTextErrorResponse(err.Error()), nil } result, err := c.CallTool(ctx, mcp.CallToolRequest{ Params: mcp.CallToolParams{ @@ -135,6 +137,33 @@ func runTool(ctx context.Context, name, toolName string, input string) (tools.To return tools.NewTextResponse(strings.Join(output, "\n")), nil } +func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) { + c, ok := mcpClients.Get(name) + if !ok { + return nil, fmt.Errorf("mcp '%s' not available", name) + } + + m := config.Get().MCP[name] + state, _ := mcpStates.Get(name) + + pingCtx, cancel := context.WithTimeout(ctx, mcpTimeout(m)) + defer cancel() + err := c.Ping(pingCtx) + if err == nil { + return c, nil + } + updateMCPState(name, MCPStateError, err, nil, state.ToolCount) + + c, err = createAndInitializeClient(ctx, name, m) + if err != nil { + return nil, err + } + + updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount) + mcpClients.Set(name, c) + return c, nil +} + func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) { sessionID, messageID := tools.GetContextValues(ctx) if sessionID == "" || messageID == "" { @@ -187,11 +216,7 @@ func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] { // GetMCPStates returns the current state of all MCP clients func GetMCPStates() map[string]MCPClientInfo { - states := make(map[string]MCPClientInfo) - for name, info := range mcpStates.Seq2() { - states[name] = info - } - return states + return maps.Collect(mcpStates.Seq2()) } // GetMCPState returns the state of a specific MCP client @@ -275,32 +300,12 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con } }() - timeout := time.Duration(cmp.Or(m.Timeout, 15)) * time.Second - ctx, cancel := context.WithTimeout(ctx, timeout) + ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m)) defer cancel() - c, err := createMcpClient(m) + c, err := createAndInitializeClient(ctx, name, m) if err != nil { - updateMCPState(name, MCPStateError, err, nil, 0) - slog.Error("error creating mcp client", "error", err, "name", name) return } - // Only call Start() for non-stdio clients, as stdio clients auto-start - if m.Type != config.MCPStdio { - if err := c.Start(ctx); err != nil { - updateMCPState(name, MCPStateError, err, nil, 0) - slog.Error("error starting mcp client", "error", err, "name", name) - _ = c.Close() - return - } - } - if _, err := c.Initialize(ctx, mcpInitRequest); err != nil { - updateMCPState(name, MCPStateError, err, nil, 0) - slog.Error("error initializing mcp client", "error", err, "name", name) - _ = c.Close() - return - } - - slog.Info("Initialized mcp client", "name", name) mcpClients.Set(name, c) tools := getTools(ctx, name, permissions, c, cfg.WorkingDir()) @@ -312,6 +317,33 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con return slices.Collect(result.Seq()) } +func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig) (*client.Client, error) { + c, err := createMcpClient(m) + if err != nil { + updateMCPState(name, MCPStateError, err, nil, 0) + slog.Error("error creating mcp client", "error", err, "name", name) + return nil, err + } + // Only call Start() for non-stdio clients, as stdio clients auto-start + if m.Type != config.MCPStdio { + if err := c.Start(ctx); err != nil { + updateMCPState(name, MCPStateError, err, nil, 0) + slog.Error("error starting mcp client", "error", err, "name", name) + _ = c.Close() + return nil, err + } + } + if _, err := c.Initialize(ctx, mcpInitRequest); err != nil { + updateMCPState(name, MCPStateError, err, nil, 0) + slog.Error("error initializing mcp client", "error", err, "name", name) + _ = c.Close() + return nil, err + } + + slog.Info("Initialized mcp client", "name", name) + return c, nil +} + func createMcpClient(m config.MCPConfig) (*client.Client, error) { switch m.Type { case config.MCPStdio: @@ -343,3 +375,7 @@ type mcpLogger struct{} func (l mcpLogger) Errorf(format string, v ...any) { slog.Error(fmt.Sprintf(format, v...)) } func (l mcpLogger) Infof(format string, v ...any) { slog.Info(fmt.Sprintf(format, v...)) } + +func mcpTimeout(m config.MCPConfig) time.Duration { + return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second +}