diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index ae81a306b7981713b9faefc6cde860b640a2b5cf..7ef24148c93c7c9d08156e25f59b55e1327c3534 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -335,7 +335,11 @@ func updateMcpTools(mcpName string, tools []tools.BaseTool) { } func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) { - transport, err := createMCPTransport(m, resolver) + timeout := mcpTimeout(m) + mcpCtx, cancel := context.WithCancel(ctx) + cancelTimer := time.AfterFunc(timeout, cancel) + + transport, err := createMCPTransport(mcpCtx, m, resolver) if err != nil { updateMCPState(name, MCPStateError, err, nil, 0) slog.Error("error creating mcp client", "error", err, "name", name) @@ -359,10 +363,6 @@ func createMCPSession(ctx context.Context, name string, m config.MCPConfig, reso }, ) - timeout := mcpTimeout(m) - mcpCtx, cancel := context.WithCancel(ctx) - cancelTimer := time.AfterFunc(timeout, cancel) - session, err := client.Connect(mcpCtx, transport, nil) if err != nil { updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0) @@ -384,7 +384,7 @@ func maybeTimeoutErr(err error, timeout time.Duration) error { return err } -func createMCPTransport(m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) { +func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) { switch m.Type { case config.MCPStdio: command, err := resolver.ResolveValue(m.Command) @@ -394,7 +394,7 @@ func createMCPTransport(m config.MCPConfig, resolver config.VariableResolver) (m if strings.TrimSpace(command) == "" { return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field") } - cmd := exec.Command(home.Long(command), m.Args...) + cmd := exec.CommandContext(ctx, home.Long(command), m.Args...) cmd.Env = m.ResolvedEnv() return &mcp.CommandTransport{ Command: cmd,