http.rs

  1use anyhow::{Result, anyhow};
  2use async_trait::async_trait;
  3use collections::HashMap;
  4use futures::{Stream, StreamExt};
  5use gpui::BackgroundExecutor;
  6use http_client::{AsyncBody, HttpClient, Request, Response, http::Method};
  7use parking_lot::Mutex as SyncMutex;
  8use smol::channel;
  9use std::{pin::Pin, sync::Arc};
 10
 11use crate::transport::Transport;
 12
// Constants from the MCP spec (HTTP transport).

/// Response/request header carrying the server-assigned session ID.
/// Once a response provides it, it is sent with every later request and with the
/// session-terminating `DELETE` on drop.
const HEADER_SESSION_ID: &str = "Mcp-Session-Id";
/// Content type of Server-Sent Events response streams.
const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream";
/// Content type of request bodies and of single (non-streamed) JSON responses.
const JSON_MIME_TYPE: &str = "application/json";
 17
/// HTTP Transport with session management and SSE support.
///
/// Messages are POSTed to a single endpoint. The server may answer with a JSON body or
/// with a Server-Sent Events stream; either way the payloads are delivered through the
/// response channel, and transport-level failures through the error channel.
pub struct HttpTransport {
    /// Client used for every outgoing request.
    http_client: Arc<dyn HttpClient>,
    /// URL that messages are POSTed to (and that the session `DELETE` targets on drop).
    endpoint: String,
    /// Session ID taken from the server's `Mcp-Session-Id` response header, if any.
    session_id: Arc<SyncMutex<Option<String>>>,
    /// Executor used to spawn detached background tasks.
    executor: BackgroundExecutor,
    /// Sending half of the channel carrying server messages.
    response_tx: channel::Sender<String>,
    /// Receiving half of the server-message channel, handed out by `receive`.
    response_rx: channel::Receiver<String>,
    /// Sending half of the channel carrying error descriptions.
    error_tx: channel::Sender<String>,
    /// Receiving half of the error channel, handed out by `receive_err`.
    error_rx: channel::Receiver<String>,
    /// Extra headers (e.g. authentication) to include in every request.
    headers: HashMap<String, String>,
}
 31
 32impl HttpTransport {
 33    pub fn new(
 34        http_client: Arc<dyn HttpClient>,
 35        endpoint: String,
 36        headers: HashMap<String, String>,
 37        executor: BackgroundExecutor,
 38    ) -> Self {
 39        let (response_tx, response_rx) = channel::unbounded();
 40        let (error_tx, error_rx) = channel::unbounded();
 41
 42        Self {
 43            http_client,
 44            executor,
 45            endpoint,
 46            session_id: Arc::new(SyncMutex::new(None)),
 47            response_tx,
 48            response_rx,
 49            error_tx,
 50            error_rx,
 51            headers,
 52        }
 53    }
 54
 55    /// Send a message and handle the response based on content type
 56    async fn send_message(&self, message: String) -> Result<()> {
 57        let is_notification =
 58            !message.contains("\"id\":") || message.contains("notifications/initialized");
 59
 60        let mut request_builder = Request::builder()
 61            .method(Method::POST)
 62            .uri(&self.endpoint)
 63            .header("Content-Type", JSON_MIME_TYPE)
 64            .header(
 65                "Accept",
 66                format!("{}, {}", JSON_MIME_TYPE, EVENT_STREAM_MIME_TYPE),
 67            );
 68
 69        for (key, value) in &self.headers {
 70            request_builder = request_builder.header(key.as_str(), value.as_str());
 71        }
 72
 73        // Add session ID if we have one (except for initialize)
 74        if let Some(ref session_id) = *self.session_id.lock() {
 75            request_builder = request_builder.header(HEADER_SESSION_ID, session_id.as_str());
 76        }
 77
 78        let request = request_builder.body(AsyncBody::from(message.into_bytes()))?;
 79        let mut response = self.http_client.send(request).await?;
 80
 81        // Handle different response types based on status and content-type
 82        match response.status() {
 83            status if status.is_success() => {
 84                // Check content type
 85                let content_type = response
 86                    .headers()
 87                    .get("content-type")
 88                    .and_then(|v| v.to_str().ok());
 89
 90                // Extract session ID from response headers if present
 91                if let Some(session_id) = response
 92                    .headers()
 93                    .get(HEADER_SESSION_ID)
 94                    .and_then(|v| v.to_str().ok())
 95                {
 96                    *self.session_id.lock() = Some(session_id.to_string());
 97                    log::debug!("Session ID set: {}", session_id);
 98                }
 99
100                match content_type {
101                    Some(ct) if ct.starts_with(JSON_MIME_TYPE) => {
102                        // JSON response - read and forward immediately
103                        let mut body = String::new();
104                        futures::AsyncReadExt::read_to_string(response.body_mut(), &mut body)
105                            .await?;
106
107                        // Only send non-empty responses
108                        if !body.is_empty() {
109                            self.response_tx
110                                .send(body)
111                                .await
112                                .map_err(|_| anyhow!("Failed to send JSON response"))?;
113                        }
114                    }
115                    Some(ct) if ct.starts_with(EVENT_STREAM_MIME_TYPE) => {
116                        // SSE stream - set up streaming
117                        self.setup_sse_stream(response).await?;
118                    }
119                    _ => {
120                        // For notifications, 202 Accepted with no content type is ok
121                        if is_notification && status.as_u16() == 202 {
122                            log::debug!("Notification accepted");
123                        } else {
124                            return Err(anyhow!("Unexpected content type: {:?}", content_type));
125                        }
126                    }
127                }
128            }
129            status if status.as_u16() == 202 => {
130                // Accepted - notification acknowledged, no response needed
131                log::debug!("Notification accepted");
132            }
133            _ => {
134                let mut error_body = String::new();
135                futures::AsyncReadExt::read_to_string(response.body_mut(), &mut error_body).await?;
136
137                self.error_tx
138                    .send(format!("HTTP {}: {}", response.status(), error_body))
139                    .await
140                    .map_err(|_| anyhow!("Failed to send error"))?;
141            }
142        }
143
144        Ok(())
145    }
146
147    /// Set up SSE streaming from the response
148    async fn setup_sse_stream(&self, mut response: Response<AsyncBody>) -> Result<()> {
149        let response_tx = self.response_tx.clone();
150        let error_tx = self.error_tx.clone();
151
152        // Spawn a task to handle the SSE stream
153        smol::spawn(async move {
154            let reader = futures::io::BufReader::new(response.body_mut());
155            let mut lines = futures::AsyncBufReadExt::lines(reader);
156
157            let mut data_buffer = Vec::new();
158            let mut in_message = false;
159
160            while let Some(line_result) = lines.next().await {
161                match line_result {
162                    Ok(line) => {
163                        if line.is_empty() {
164                            // Empty line signals end of event
165                            if !data_buffer.is_empty() {
166                                let message = data_buffer.join("\n");
167
168                                // Filter out ping messages and empty data
169                                if !message.trim().is_empty() && message != "ping" {
170                                    if let Err(e) = response_tx.send(message).await {
171                                        log::error!("Failed to send SSE message: {}", e);
172                                        break;
173                                    }
174                                }
175                                data_buffer.clear();
176                            }
177                            in_message = false;
178                        } else if let Some(data) = line.strip_prefix("data: ") {
179                            // Handle data lines
180                            let data = data.trim();
181                            if !data.is_empty() {
182                                // Check if this is a ping message
183                                if data == "ping" {
184                                    log::trace!("Received SSE ping");
185                                    continue;
186                                }
187                                data_buffer.push(data.to_string());
188                                in_message = true;
189                            }
190                        } else if line.starts_with("event:")
191                            || line.starts_with("id:")
192                            || line.starts_with("retry:")
193                        {
194                            // Ignore other SSE fields
195                            continue;
196                        } else if in_message {
197                            // Continuation of data
198                            data_buffer.push(line);
199                        }
200                    }
201                    Err(e) => {
202                        let _ = error_tx.send(format!("SSE stream error: {}", e)).await;
203                        break;
204                    }
205                }
206            }
207        })
208        .detach();
209
210        Ok(())
211    }
212}
213
214#[async_trait]
215impl Transport for HttpTransport {
216    async fn send(&self, message: String) -> Result<()> {
217        self.send_message(message).await
218    }
219
220    fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
221        Box::pin(self.response_rx.clone())
222    }
223
224    fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
225        Box::pin(self.error_rx.clone())
226    }
227}
228
229impl Drop for HttpTransport {
230    fn drop(&mut self) {
231        // Try to cleanup session on drop
232        let http_client = self.http_client.clone();
233        let endpoint = self.endpoint.clone();
234        let session_id = self.session_id.lock().clone();
235        let headers = self.headers.clone();
236
237        if let Some(session_id) = session_id {
238            self.executor
239                .spawn(async move {
240                    let mut request_builder = Request::builder()
241                        .method(Method::DELETE)
242                        .uri(&endpoint)
243                        .header(HEADER_SESSION_ID, &session_id);
244
245                    // Add authentication headers if present
246                    for (key, value) in headers {
247                        request_builder = request_builder.header(key.as_str(), value.as_str());
248                    }
249
250                    let request = request_builder.body(AsyncBody::empty());
251
252                    if let Ok(request) = request {
253                        let _ = http_client.send(request).await;
254                    }
255                })
256                .detach();
257        }
258    }
259}