test.rs

  1use anyhow::Context as _;
  2use collections::HashMap;
  3use futures::{FutureExt, Stream, StreamExt as _, future::BoxFuture, lock::Mutex};
  4use gpui::BackgroundExecutor;
  5use std::{pin::Pin, sync::Arc};
  6
  7use crate::{
  8    transport::Transport,
  9    types::{Implementation, InitializeResponse, ProtocolVersion, ServerCapabilities},
 10};
 11
 12pub fn create_fake_transport(
 13    name: impl Into<String>,
 14    executor: BackgroundExecutor,
 15) -> FakeTransport {
 16    let name = name.into();
 17    FakeTransport::new(executor).on_request::<crate::types::requests::Initialize, _>(
 18        move |_params| {
 19            let name = name.clone();
 20            async move { create_initialize_response(name.clone()) }
 21        },
 22    )
 23}
 24
 25fn create_initialize_response(server_name: String) -> InitializeResponse {
 26    InitializeResponse {
 27        protocol_version: ProtocolVersion(crate::types::LATEST_PROTOCOL_VERSION.to_string()),
 28        server_info: Implementation {
 29            name: server_name,
 30            version: "1.0.0".to_string(),
 31        },
 32        capabilities: ServerCapabilities::default(),
 33        meta: None,
 34    }
 35}
 36
 37pub struct FakeTransport {
 38    request_handlers: HashMap<
 39        &'static str,
 40        Arc<dyn Send + Sync + Fn(serde_json::Value) -> BoxFuture<'static, serde_json::Value>>,
 41    >,
 42    tx: futures::channel::mpsc::UnboundedSender<String>,
 43    rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
 44    executor: BackgroundExecutor,
 45}
 46
 47impl FakeTransport {
 48    pub fn new(executor: BackgroundExecutor) -> Self {
 49        let (tx, rx) = futures::channel::mpsc::unbounded();
 50        Self {
 51            request_handlers: Default::default(),
 52            tx,
 53            rx: Arc::new(Mutex::new(rx)),
 54            executor,
 55        }
 56    }
 57
 58    pub fn on_request<T, Fut>(
 59        mut self,
 60        handler: impl 'static + Send + Sync + Fn(T::Params) -> Fut,
 61    ) -> Self
 62    where
 63        T: crate::types::Request,
 64        Fut: 'static + Send + Future<Output = T::Response>,
 65    {
 66        self.request_handlers.insert(
 67            T::METHOD,
 68            Arc::new(move |value| {
 69                let params = value
 70                    .get("params")
 71                    .cloned()
 72                    .unwrap_or(serde_json::Value::Null);
 73                let params: T::Params =
 74                    serde_json::from_value(params).expect("Invalid parameters received");
 75                let response = handler(params);
 76                async move { serde_json::to_value(response.await).unwrap() }.boxed()
 77            }),
 78        );
 79        self
 80    }
 81}
 82
 83#[async_trait::async_trait]
 84impl Transport for FakeTransport {
 85    async fn send(&self, message: String) -> anyhow::Result<()> {
 86        if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
 87            let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);
 88
 89            if let Some(method) = msg.get("method") {
 90                let method = method.as_str().expect("Invalid method received");
 91                if let Some(handler) = self.request_handlers.get(method) {
 92                    let payload = handler(msg).await;
 93                    let response = serde_json::json!({
 94                        "jsonrpc": "2.0",
 95                        "id": id,
 96                        "result": payload
 97                    });
 98                    self.tx
 99                        .unbounded_send(response.to_string())
100                        .context("sending a message")?;
101                } else {
102                    log::debug!("No handler registered for MCP request '{method}'");
103                }
104            }
105        }
106        Ok(())
107    }
108
109    fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
110        let rx = self.rx.clone();
111        let executor = self.executor.clone();
112        Box::pin(futures::stream::unfold(rx, move |rx| {
113            let executor = executor.clone();
114            async move {
115                let mut rx_guard = rx.lock().await;
116                executor.simulate_random_delay().await;
117                if let Some(message) = rx_guard.next().await {
118                    drop(rx_guard);
119                    Some((message, rx))
120                } else {
121                    None
122                }
123            }
124        }))
125    }
126
127    fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
128        Box::pin(futures::stream::empty())
129    }
130}