http.rs

  1use anyhow::{Result, anyhow};
  2use async_trait::async_trait;
  3use collections::HashMap;
  4use futures::{Stream, StreamExt};
  5use gpui::BackgroundExecutor;
  6use http_client::{AsyncBody, HttpClient, Request, Response, http::Method};
  7use parking_lot::Mutex as SyncMutex;
  8use smol::channel;
  9use std::{pin::Pin, sync::Arc};
 10
 11use crate::oauth::{self, OAuthTokenProvider, WwwAuthenticate};
 12use crate::transport::Transport;
 13
 14/// Typed errors returned by the HTTP transport that callers can downcast from
 15/// `anyhow::Error` to handle specific failure modes.
 16#[derive(Debug)]
 17pub enum TransportError {
 18    /// The server returned 401 and token refresh either wasn't possible or
 19    /// failed. The caller should initiate the OAuth authorization flow.
 20    AuthRequired { www_authenticate: WwwAuthenticate },
 21}
 22
 23impl std::fmt::Display for TransportError {
 24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 25        match self {
 26            TransportError::AuthRequired { .. } => {
 27                write!(f, "OAuth authorization required")
 28            }
 29        }
 30    }
 31}
 32
 33impl std::error::Error for TransportError {}
 34
 35// Constants from MCP spec
 36const HEADER_SESSION_ID: &str = "Mcp-Session-Id";
 37const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream";
 38const JSON_MIME_TYPE: &str = "application/json";
 39
 40/// HTTP Transport with session management and SSE support
 41pub struct HttpTransport {
 42    http_client: Arc<dyn HttpClient>,
 43    endpoint: String,
 44    session_id: Arc<SyncMutex<Option<String>>>,
 45    executor: BackgroundExecutor,
 46    response_tx: channel::Sender<String>,
 47    response_rx: channel::Receiver<String>,
 48    error_tx: channel::Sender<String>,
 49    error_rx: channel::Receiver<String>,
 50    /// Static headers to include in every request (e.g. from server config).
 51    headers: HashMap<String, String>,
 52    /// When set, the transport attaches `Authorization: Bearer` headers and
 53    /// handles 401 responses with token refresh + retry.
 54    token_provider: Option<Arc<dyn OAuthTokenProvider>>,
 55}
 56
 57impl HttpTransport {
 58    pub fn new(
 59        http_client: Arc<dyn HttpClient>,
 60        endpoint: String,
 61        headers: HashMap<String, String>,
 62        executor: BackgroundExecutor,
 63    ) -> Self {
 64        Self::new_with_token_provider(http_client, endpoint, headers, executor, None)
 65    }
 66
 67    pub fn new_with_token_provider(
 68        http_client: Arc<dyn HttpClient>,
 69        endpoint: String,
 70        headers: HashMap<String, String>,
 71        executor: BackgroundExecutor,
 72        token_provider: Option<Arc<dyn OAuthTokenProvider>>,
 73    ) -> Self {
 74        let (response_tx, response_rx) = channel::unbounded();
 75        let (error_tx, error_rx) = channel::unbounded();
 76
 77        Self {
 78            http_client,
 79            executor,
 80            endpoint,
 81            session_id: Arc::new(SyncMutex::new(None)),
 82            response_tx,
 83            response_rx,
 84            error_tx,
 85            error_rx,
 86            headers,
 87            token_provider,
 88        }
 89    }
 90
 91    /// Build a POST request for the given message body, attaching all standard
 92    /// headers (content-type, accept, session ID, static headers, and bearer
 93    /// token if available).
 94    fn build_request(&self, message: &[u8]) -> Result<http_client::Request<AsyncBody>> {
 95        let mut request_builder = Request::builder()
 96            .method(Method::POST)
 97            .uri(&self.endpoint)
 98            .header("Content-Type", JSON_MIME_TYPE)
 99            .header(
100                "Accept",
101                format!("{}, {}", JSON_MIME_TYPE, EVENT_STREAM_MIME_TYPE),
102            );
103
104        for (key, value) in &self.headers {
105            request_builder = request_builder.header(key.as_str(), value.as_str());
106        }
107
108        // Attach bearer token when a token provider is present.
109        if let Some(token) = self.token_provider.as_ref().and_then(|p| p.access_token()) {
110            request_builder = request_builder.header("Authorization", format!("Bearer {}", token));
111        }
112
113        // Add session ID if we have one (except for initialize).
114        if let Some(ref session_id) = *self.session_id.lock() {
115            request_builder = request_builder.header(HEADER_SESSION_ID, session_id.as_str());
116        }
117
118        Ok(request_builder.body(AsyncBody::from(message.to_vec()))?)
119    }
120
121    /// Send a message and handle the response based on content type.
122    async fn send_message(&self, message: String) -> Result<()> {
123        let is_notification =
124            !message.contains("\"id\":") || message.contains("notifications/initialized");
125
126        // If we currently have no access token, try refreshing before sending
127        // the request so restored but expired sessions do not need an initial
128        // 401 round-trip before they can recover.
129        if let Some(ref provider) = self.token_provider {
130            if provider.access_token().is_none() {
131                provider.try_refresh().await.unwrap_or(false);
132            }
133        }
134
135        let request = self.build_request(message.as_bytes())?;
136        let mut response = self.http_client.send(request).await?;
137
138        // On 401, try refreshing the token and retry once.
139        if response.status().as_u16() == 401 {
140            let www_auth_header = response
141                .headers()
142                .get("www-authenticate")
143                .and_then(|v| v.to_str().ok())
144                .unwrap_or("Bearer");
145
146            let www_authenticate =
147                oauth::parse_www_authenticate(www_auth_header).unwrap_or(WwwAuthenticate {
148                    resource_metadata: None,
149                    scope: None,
150                    error: None,
151                    error_description: None,
152                });
153
154            if let Some(ref provider) = self.token_provider {
155                if provider.try_refresh().await.unwrap_or(false) {
156                    // Retry with the refreshed token.
157                    let retry_request = self.build_request(message.as_bytes())?;
158                    response = self.http_client.send(retry_request).await?;
159
160                    // If still 401 after refresh, give up.
161                    if response.status().as_u16() == 401 {
162                        return Err(TransportError::AuthRequired { www_authenticate }.into());
163                    }
164                } else {
165                    return Err(TransportError::AuthRequired { www_authenticate }.into());
166                }
167            } else {
168                return Err(TransportError::AuthRequired { www_authenticate }.into());
169            }
170        }
171
172        // Handle different response types based on status and content-type.
173        match response.status() {
174            status if status.is_success() => {
175                // Check content type
176                let content_type = response
177                    .headers()
178                    .get("content-type")
179                    .and_then(|v| v.to_str().ok());
180
181                // Extract session ID from response headers if present
182                if let Some(session_id) = response
183                    .headers()
184                    .get(HEADER_SESSION_ID)
185                    .and_then(|v| v.to_str().ok())
186                {
187                    *self.session_id.lock() = Some(session_id.to_string());
188                    log::debug!("Session ID set: {}", session_id);
189                }
190
191                match content_type {
192                    Some(ct) if ct.starts_with(JSON_MIME_TYPE) => {
193                        // JSON response - read and forward immediately
194                        let mut body = String::new();
195                        futures::AsyncReadExt::read_to_string(response.body_mut(), &mut body)
196                            .await?;
197
198                        // Only send non-empty responses
199                        if !body.is_empty() {
200                            self.response_tx
201                                .send(body)
202                                .await
203                                .map_err(|_| anyhow!("Failed to send JSON response"))?;
204                        }
205                    }
206                    Some(ct) if ct.starts_with(EVENT_STREAM_MIME_TYPE) => {
207                        // SSE stream - set up streaming
208                        self.setup_sse_stream(response).await?;
209                    }
210                    _ => {
211                        // For notifications, 202 Accepted with no content type is ok
212                        if is_notification && status.as_u16() == 202 {
213                            log::debug!("Notification accepted");
214                        } else {
215                            return Err(anyhow!("Unexpected content type: {:?}", content_type));
216                        }
217                    }
218                }
219            }
220            status if status.as_u16() == 202 => {
221                // Accepted - notification acknowledged, no response needed
222                log::debug!("Notification accepted");
223            }
224            _ => {
225                let mut error_body = String::new();
226                futures::AsyncReadExt::read_to_string(response.body_mut(), &mut error_body).await?;
227
228                self.error_tx
229                    .send(format!("HTTP {}: {}", response.status(), error_body))
230                    .await
231                    .map_err(|_| anyhow!("Failed to send error"))?;
232            }
233        }
234
235        Ok(())
236    }
237
238    /// Set up SSE streaming from the response
239    async fn setup_sse_stream(&self, mut response: Response<AsyncBody>) -> Result<()> {
240        let response_tx = self.response_tx.clone();
241        let error_tx = self.error_tx.clone();
242
243        // Spawn a task to handle the SSE stream
244        smol::spawn(async move {
245            let reader = futures::io::BufReader::new(response.body_mut());
246            let mut lines = futures::AsyncBufReadExt::lines(reader);
247
248            let mut data_buffer = Vec::new();
249            let mut in_message = false;
250
251            while let Some(line_result) = lines.next().await {
252                match line_result {
253                    Ok(line) => {
254                        if line.is_empty() {
255                            // Empty line signals end of event
256                            if !data_buffer.is_empty() {
257                                let message = data_buffer.join("\n");
258
259                                // Filter out ping messages and empty data
260                                if !message.trim().is_empty() && message != "ping" {
261                                    if let Err(e) = response_tx.send(message).await {
262                                        log::error!("Failed to send SSE message: {}", e);
263                                        break;
264                                    }
265                                }
266                                data_buffer.clear();
267                            }
268                            in_message = false;
269                        } else if let Some(data) = line.strip_prefix("data: ") {
270                            // Handle data lines
271                            let data = data.trim();
272                            if !data.is_empty() {
273                                // Check if this is a ping message
274                                if data == "ping" {
275                                    log::trace!("Received SSE ping");
276                                    continue;
277                                }
278                                data_buffer.push(data.to_string());
279                                in_message = true;
280                            }
281                        } else if line.starts_with("event:")
282                            || line.starts_with("id:")
283                            || line.starts_with("retry:")
284                        {
285                            // Ignore other SSE fields
286                            continue;
287                        } else if in_message {
288                            // Continuation of data
289                            data_buffer.push(line);
290                        }
291                    }
292                    Err(e) => {
293                        let _ = error_tx.send(format!("SSE stream error: {}", e)).await;
294                        break;
295                    }
296                }
297            }
298        })
299        .detach();
300
301        Ok(())
302    }
303}
304
305#[async_trait]
306impl Transport for HttpTransport {
307    async fn send(&self, message: String) -> Result<()> {
308        self.send_message(message).await
309    }
310
311    fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
312        Box::pin(self.response_rx.clone())
313    }
314
315    fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
316        Box::pin(self.error_rx.clone())
317    }
318}
319
320impl Drop for HttpTransport {
321    fn drop(&mut self) {
322        // Try to cleanup session on drop
323        let http_client = self.http_client.clone();
324        let endpoint = self.endpoint.clone();
325        let session_id = self.session_id.lock().clone();
326        let headers = self.headers.clone();
327        let access_token = self.token_provider.as_ref().and_then(|p| p.access_token());
328
329        if let Some(session_id) = session_id {
330            self.executor
331                .spawn(async move {
332                    let mut request_builder = Request::builder()
333                        .method(Method::DELETE)
334                        .uri(&endpoint)
335                        .header(HEADER_SESSION_ID, &session_id);
336
337                    // Add static authentication headers.
338                    for (key, value) in headers {
339                        request_builder = request_builder.header(key.as_str(), value.as_str());
340                    }
341
342                    // Attach bearer token if available.
343                    if let Some(token) = access_token {
344                        request_builder =
345                            request_builder.header("Authorization", format!("Bearer {}", token));
346                    }
347
348                    let request = request_builder.body(AsyncBody::empty());
349
350                    if let Ok(request) = request {
351                        let _ = http_client.send(request).await;
352                    }
353                })
354                .detach();
355        }
356    }
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use gpui::TestAppContext;
    use parking_lot::Mutex as SyncMutex;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    /// A mock token provider that serves a configurable access token and counts
    /// refresh attempts.
    ///
    /// A successful refresh swaps in `refreshed_token` when one was supplied;
    /// otherwise the current token is left unchanged.
    struct FakeTokenProvider {
        token: SyncMutex<Option<String>>,
        refreshed_token: SyncMutex<Option<String>>,
        refresh_succeeds: AtomicBool,
        refresh_count: AtomicUsize,
    }

    impl FakeTokenProvider {
        /// Provider whose refresh never changes the current token.
        fn new(token: Option<&str>, refresh_succeeds: bool) -> Arc<Self> {
            Self::with_refreshed_token(token, None, refresh_succeeds)
        }

        /// Provider whose successful refresh replaces the token with `refreshed_token`.
        fn with_refreshed_token(
            token: Option<&str>,
            refreshed_token: Option<&str>,
            refresh_succeeds: bool,
        ) -> Arc<Self> {
            Arc::new(Self {
                token: SyncMutex::new(token.map(String::from)),
                refreshed_token: SyncMutex::new(refreshed_token.map(String::from)),
                refresh_succeeds: AtomicBool::new(refresh_succeeds),
                refresh_count: AtomicUsize::new(0),
            })
        }

        /// Number of `try_refresh` calls made so far.
        fn refresh_count(&self) -> usize {
            self.refresh_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl OAuthTokenProvider for FakeTokenProvider {
        fn access_token(&self) -> Option<String> {
            self.token.lock().clone()
        }

        async fn try_refresh(&self) -> Result<bool> {
            self.refresh_count.fetch_add(1, Ordering::SeqCst);

            let refresh_succeeds = self.refresh_succeeds.load(Ordering::SeqCst);
            if refresh_succeeds {
                if let Some(token) = self.refreshed_token.lock().clone() {
                    *self.token.lock() = Some(token);
                }
            }

            Ok(refresh_succeeds)
        }
    }

    /// Wraps `handler` in a fake HTTP client that answers every request with it.
    fn make_fake_http_client(
        handler: impl Fn(
            http_client::Request<AsyncBody>,
        ) -> std::pin::Pin<
            Box<dyn std::future::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send>,
        > + Send
        + Sync
        + 'static,
    ) -> Arc<dyn HttpClient> {
        http_client::FakeHttpClient::create(handler) as Arc<dyn HttpClient>
    }

    /// A response with the given status and an `application/json` body.
    fn json_response(status: u16, body: &str) -> anyhow::Result<Response<AsyncBody>> {
        Ok(Response::builder()
            .status(status)
            .header("Content-Type", "application/json")
            .body(AsyncBody::from(body.as_bytes().to_vec()))
            .unwrap())
    }

    #[gpui::test]
    async fn test_bearer_token_attached_to_requests(cx: &mut TestAppContext) {
        // Capture the Authorization header from the request.
        let captured_auth = Arc::new(SyncMutex::new(None::<String>));
        let captured_auth_clone = captured_auth.clone();

        let client = make_fake_http_client(move |req| {
            let auth = req
                .headers()
                .get("Authorization")
                .map(|v| v.to_str().unwrap().to_string());
            *captured_auth_clone.lock() = auth;
            Box::pin(async { json_response(200, r#"{"jsonrpc":"2.0","id":1,"result":{}}"#) })
        });

        let provider = FakeTokenProvider::new(Some("test-access-token"), false);
        let transport = HttpTransport::new_with_token_provider(
            client,
            "http://mcp.example.com/mcp".to_string(),
            HashMap::default(),
            cx.background_executor.clone(),
            Some(provider),
        );

        transport
            .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
            .await
            .expect("send should succeed");

        assert_eq!(
            captured_auth.lock().as_deref(),
            Some("Bearer test-access-token"),
        );
    }

    #[gpui::test]
    async fn test_no_bearer_token_without_provider(cx: &mut TestAppContext) {
        let captured_auth = Arc::new(SyncMutex::new(None::<String>));
        let captured_auth_clone = captured_auth.clone();

        let client = make_fake_http_client(move |req| {
            let auth = req
                .headers()
                .get("Authorization")
                .map(|v| v.to_str().unwrap().to_string());
            *captured_auth_clone.lock() = auth;
            Box::pin(async { json_response(200, r#"{"jsonrpc":"2.0","id":1,"result":{}}"#) })
        });

        let transport = HttpTransport::new(
            client,
            "http://mcp.example.com/mcp".to_string(),
            HashMap::default(),
            cx.background_executor.clone(),
        );

        transport
            .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
            .await
            .expect("send should succeed");

        assert!(captured_auth.lock().is_none());
    }

    #[gpui::test]
    async fn test_missing_token_triggers_refresh_before_first_request(cx: &mut TestAppContext) {
        let captured_auth = Arc::new(SyncMutex::new(None::<String>));
        let captured_auth_clone = captured_auth.clone();

        let client = make_fake_http_client(move |req| {
            let auth = req
                .headers()
                .get("Authorization")
                .map(|v| v.to_str().unwrap().to_string());
            *captured_auth_clone.lock() = auth;
            Box::pin(async { json_response(200, r#"{"jsonrpc":"2.0","id":1,"result":{}}"#) })
        });

        let provider = FakeTokenProvider::with_refreshed_token(None, Some("refreshed-token"), true);
        let transport = HttpTransport::new_with_token_provider(
            client,
            "http://mcp.example.com/mcp".to_string(),
            HashMap::default(),
            cx.background_executor.clone(),
            Some(provider.clone()),
        );

        transport
            .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
            .await
            .expect("send should succeed after proactive refresh");

        assert_eq!(provider.refresh_count(), 1);
        assert_eq!(
            captured_auth.lock().as_deref(),
            Some("Bearer refreshed-token"),
        );
    }

    #[gpui::test]
    async fn test_invalid_token_still_triggers_refresh_and_retry(cx: &mut TestAppContext) {
        let request_count = Arc::new(AtomicUsize::new(0));
        let request_count_clone = request_count.clone();

        let client = make_fake_http_client(move |_req| {
            let count = request_count_clone.fetch_add(1, Ordering::SeqCst);
            Box::pin(async move {
                if count == 0 {
                    Ok(Response::builder()
                        .status(401)
                        .header(
                            "WWW-Authenticate",
                            r#"Bearer error="invalid_token", resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#,
                        )
                        .body(AsyncBody::from(b"Unauthorized".to_vec()))
                        .unwrap())
                } else {
                    json_response(200, r#"{"jsonrpc":"2.0","id":1,"result":{}}"#)
                }
            })
        });

        let provider = FakeTokenProvider::with_refreshed_token(
            Some("old-token"),
            Some("refreshed-token"),
            true,
        );
        let transport = HttpTransport::new_with_token_provider(
            client,
            "http://mcp.example.com/mcp".to_string(),
            HashMap::default(),
            cx.background_executor.clone(),
            Some(provider.clone()),
        );

        transport
            .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
            .await
            .expect("send should succeed after refresh");

        assert_eq!(provider.refresh_count(), 1);
        assert_eq!(request_count.load(Ordering::SeqCst), 2);
    }

    #[gpui::test]
    async fn test_401_triggers_refresh_and_retry(cx: &mut TestAppContext) {
        // Record the Authorization header of every request, so we can check
        // that the retry carries the refreshed token, not the rejected one.
        let captured_auth = Arc::new(SyncMutex::new(Vec::<Option<String>>::new()));
        let captured_auth_clone = captured_auth.clone();

        let client = make_fake_http_client(move |req| {
            let auth = req
                .headers()
                .get("Authorization")
                .map(|v| v.to_str().unwrap().to_string());
            let request_number = {
                let mut captured = captured_auth_clone.lock();
                captured.push(auth);
                captured.len()
            };
            Box::pin(async move {
                if request_number == 1 {
                    // First request: 401.
                    Ok(Response::builder()
                        .status(401)
                        .header(
                            "WWW-Authenticate",
                            r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#,
                        )
                        .body(AsyncBody::from(b"Unauthorized".to_vec()))
                        .unwrap())
                } else {
                    // Retry after refresh: 200.
                    json_response(200, r#"{"jsonrpc":"2.0","id":1,"result":{}}"#)
                }
            })
        });

        // The refresh swaps "old-token" for "refreshed-token".
        let provider = FakeTokenProvider::with_refreshed_token(
            Some("old-token"),
            Some("refreshed-token"),
            true,
        );
        let transport = HttpTransport::new_with_token_provider(
            client,
            "http://mcp.example.com/mcp".to_string(),
            HashMap::default(),
            cx.background_executor.clone(),
            Some(provider.clone()),
        );

        transport
            .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
            .await
            .expect("send should succeed after refresh");

        assert_eq!(provider.refresh_count(), 1);
        assert_eq!(
            *captured_auth.lock(),
            vec![
                Some("Bearer old-token".to_string()),
                Some("Bearer refreshed-token".to_string()),
            ],
        );
    }

    #[gpui::test]
    async fn test_401_returns_auth_required_when_refresh_fails(cx: &mut TestAppContext) {
        let client = make_fake_http_client(|_req| {
            Box::pin(async {
                Ok(Response::builder()
                    .status(401)
                    .header(
                        "WWW-Authenticate",
                        r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", scope="read write""#,
                    )
                    .body(AsyncBody::from(b"Unauthorized".to_vec()))
                    .unwrap())
            })
        });

        // Refresh returns false — no new token available.
        let provider = FakeTokenProvider::new(Some("stale-token"), false);
        let transport = HttpTransport::new_with_token_provider(
            client,
            "http://mcp.example.com/mcp".to_string(),
            HashMap::default(),
            cx.background_executor.clone(),
            Some(provider.clone()),
        );

        let err = transport
            .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
            .await
            .unwrap_err();

        let transport_err = err
            .downcast_ref::<TransportError>()
            .expect("error should be TransportError");
        match transport_err {
            TransportError::AuthRequired { www_authenticate } => {
                assert_eq!(
                    www_authenticate
                        .resource_metadata
                        .as_ref()
                        .map(|u| u.as_str()),
                    Some("https://mcp.example.com/.well-known/oauth-protected-resource"),
                );
                assert_eq!(
                    www_authenticate.scope,
                    Some(vec!["read".to_string(), "write".to_string()]),
                );
            }
        }
        assert_eq!(provider.refresh_count(), 1);
    }

    #[gpui::test]
    async fn test_401_returns_auth_required_without_provider(cx: &mut TestAppContext) {
        let client = make_fake_http_client(|_req| {
            Box::pin(async {
                Ok(Response::builder()
                    .status(401)
                    .header("WWW-Authenticate", "Bearer")
                    .body(AsyncBody::from(b"Unauthorized".to_vec()))
                    .unwrap())
            })
        });

        // No token provider at all.
        let transport = HttpTransport::new(
            client,
            "http://mcp.example.com/mcp".to_string(),
            HashMap::default(),
            cx.background_executor.clone(),
        );

        let err = transport
            .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
            .await
            .unwrap_err();

        let transport_err = err
            .downcast_ref::<TransportError>()
            .expect("error should be TransportError");
        match transport_err {
            TransportError::AuthRequired { www_authenticate } => {
                assert!(www_authenticate.resource_metadata.is_none());
                assert!(www_authenticate.scope.is_none());
            }
        }
    }

    #[gpui::test]
    async fn test_401_after_successful_refresh_still_returns_auth_required(
        cx: &mut TestAppContext,
    ) {
        // Both requests return 401 — the server rejects the refreshed token too.
        let client = make_fake_http_client(|_req| {
            Box::pin(async {
                Ok(Response::builder()
                    .status(401)
                    .header("WWW-Authenticate", "Bearer")
                    .body(AsyncBody::from(b"Unauthorized".to_vec()))
                    .unwrap())
            })
        });

        let provider = FakeTokenProvider::new(Some("token"), true);
        let transport = HttpTransport::new_with_token_provider(
            client,
            "http://mcp.example.com/mcp".to_string(),
            HashMap::default(),
            cx.background_executor.clone(),
            Some(provider.clone()),
        );

        let err = transport
            .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
            .await
            .unwrap_err();

        err.downcast_ref::<TransportError>()
            .expect("error should be TransportError");
        // Refresh was attempted exactly once.
        assert_eq!(provider.refresh_count(), 1);
    }
}