diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index a2e6b912ab503c61522501ad522a9f0a65fc37b0..f5125c8b89f2dda534396f3c51df3839390022ce 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -196,14 +196,10 @@ func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes return runTool(ctx, b.mcpName, b.tool.Name, params.Input) } -func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool { +func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) ([]tools.BaseTool, error) { result, err := c.ListTools(ctx, mcp.ListToolsRequest{}) if err != nil { - slog.Error("error listing tools", "error", err) - updateMCPState(name, MCPStateError, err, nil, 0) - c.Close() - mcpClients.Del(name) - return nil + return nil, err } mcpTools := make([]tools.BaseTool, 0, len(result.Tools)) for _, tool := range result.Tools { @@ -214,7 +210,7 @@ func getTools(ctx context.Context, name string, permissions permission.Service, workingDir: workingDir, }) } - return mcpTools + return mcpTools, nil } // SubscribeMCPEvents returns a channel for MCP events @@ -314,13 +310,21 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m)) defer cancel() + c, err := createAndInitializeClient(ctx, name, m, cfg.Resolver()) if err != nil { return } - mcpClients.Set(name, c) - tools := getTools(ctx, name, permissions, c, cfg.WorkingDir()) + tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir()) + if err != nil { + slog.Error("error listing tools", "error", err) + updateMCPState(name, MCPStateError, err, nil, 0) + c.Close() + return + } + + mcpClients.Set(name, c) updateMCPState(name, MCPStateConnected, nil, c, len(tools)) result.Append(tools...) }(name, m) @@ -341,14 +345,11 @@ func createAndInitializeClient(ctx context.Context, name string, m config.MCPCon initCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - // Only call Start() for non-stdio clients, as stdio clients auto-start - if m.Type != config.MCPStdio { - if err := c.Start(initCtx); err != nil { - updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0) - slog.Error("error starting mcp client", "error", err, "name", name) - _ = c.Close() - return nil, err - } + if err := c.Start(initCtx); err != nil { + updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0) + slog.Error("error starting mcp client", "error", err, "name", name) + _ = c.Close() + return nil, err } if _, err := c.Initialize(initCtx, mcpInitRequest); err != nil { updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)