diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index 67d46b54b637af8d3dacef8149d34202157a9565..1043ca3b9820e72096a0aafe7cdb7868c8d29720 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -152,13 +152,14 @@ func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) m := config.Get().MCP[name] state, _ := mcpStates.Get(name) - pingCtx, cancel := context.WithTimeout(ctx, mcpTimeout(m)) + timeout := mcpTimeout(m) + pingCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() err := c.Ping(pingCtx) if err == nil { return c, nil } - updateMCPState(name, MCPStateError, err, nil, state.ToolCount) + updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount) c, err = createAndInitializeClient(ctx, name, m) if err != nil { @@ -334,17 +335,22 @@ func createAndInitializeClient(ctx context.Context, name string, m config.MCPCon slog.Error("error creating mcp client", "error", err, "name", name) return nil, err } + + timeout := mcpTimeout(m) + initCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + // Only call Start() for non-stdio clients, as stdio clients auto-start if m.Type != config.MCPStdio { - if err := c.Start(ctx); err != nil { - updateMCPState(name, MCPStateError, err, nil, 0) + if err := c.Start(initCtx); err != nil { + updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0) slog.Error("error starting mcp client", "error", err, "name", name) _ = c.Close() return nil, err } } - if _, err := c.Initialize(ctx, mcpInitRequest); err != nil { - updateMCPState(name, MCPStateError, err, nil, 0) + if _, err := c.Initialize(initCtx, mcpInitRequest); err != nil { + updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0) slog.Error("error initializing mcp client", "error", err, "name", name) _ = c.Close() return nil, err @@ -354,6 +360,13 @@ func createAndInitializeClient(ctx context.Context, name string, m config.MCPCon return c, nil } +func maybeTimeoutErr(err error, timeout time.Duration) error { + if errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("timed out after %s", timeout) + } + return err +} + func createMcpClient(name string, m config.MCPConfig) (*client.Client, error) { switch m.Type { case config.MCPStdio: