test.rs

  1use anyhow::Context as _;
  2use collections::HashMap;
  3use futures::{Stream, StreamExt as _, lock::Mutex};
  4use gpui::BackgroundExecutor;
  5use std::{pin::Pin, sync::Arc};
  6
  7use crate::{
  8    transport::Transport,
  9    types::{Implementation, InitializeResponse, ProtocolVersion, ServerCapabilities},
 10};
 11
 12pub fn create_fake_transport(
 13    name: impl Into<String>,
 14    executor: BackgroundExecutor,
 15) -> FakeTransport {
 16    let name = name.into();
 17    FakeTransport::new(executor).on_request::<crate::types::requests::Initialize>(move |_params| {
 18        create_initialize_response(name.clone())
 19    })
 20}
 21
 22fn create_initialize_response(server_name: String) -> InitializeResponse {
 23    InitializeResponse {
 24        protocol_version: ProtocolVersion(crate::types::LATEST_PROTOCOL_VERSION.to_string()),
 25        server_info: Implementation {
 26            name: server_name,
 27            version: "1.0.0".to_string(),
 28        },
 29        capabilities: ServerCapabilities::default(),
 30        meta: None,
 31    }
 32}
 33
 34pub struct FakeTransport {
 35    request_handlers:
 36        HashMap<&'static str, Arc<dyn Fn(serde_json::Value) -> serde_json::Value + Send + Sync>>,
 37    tx: futures::channel::mpsc::UnboundedSender<String>,
 38    rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
 39    executor: BackgroundExecutor,
 40}
 41
 42impl FakeTransport {
 43    pub fn new(executor: BackgroundExecutor) -> Self {
 44        let (tx, rx) = futures::channel::mpsc::unbounded();
 45        Self {
 46            request_handlers: Default::default(),
 47            tx,
 48            rx: Arc::new(Mutex::new(rx)),
 49            executor,
 50        }
 51    }
 52
 53    pub fn on_request<T: crate::types::Request>(
 54        mut self,
 55        handler: impl Fn(T::Params) -> T::Response + Send + Sync + 'static,
 56    ) -> Self {
 57        self.request_handlers.insert(
 58            T::METHOD,
 59            Arc::new(move |value| {
 60                let params = value.get("params").expect("Missing parameters").clone();
 61                let params: T::Params =
 62                    serde_json::from_value(params).expect("Invalid parameters received");
 63                let response = handler(params);
 64                serde_json::to_value(response).unwrap()
 65            }),
 66        );
 67        self
 68    }
 69}
 70
 71#[async_trait::async_trait]
 72impl Transport for FakeTransport {
 73    async fn send(&self, message: String) -> anyhow::Result<()> {
 74        if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
 75            let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);
 76
 77            if let Some(method) = msg.get("method") {
 78                let method = method.as_str().expect("Invalid method received");
 79                if let Some(handler) = self.request_handlers.get(method) {
 80                    let payload = handler(msg);
 81                    let response = serde_json::json!({
 82                        "jsonrpc": "2.0",
 83                        "id": id,
 84                        "result": payload
 85                    });
 86                    self.tx
 87                        .unbounded_send(response.to_string())
 88                        .context("sending a message")?;
 89                } else {
 90                    log::debug!("No handler registered for MCP request '{method}'");
 91                }
 92            }
 93        }
 94        Ok(())
 95    }
 96
 97    fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
 98        let rx = self.rx.clone();
 99        let executor = self.executor.clone();
100        Box::pin(futures::stream::unfold(rx, move |rx| {
101            let executor = executor.clone();
102            async move {
103                let mut rx_guard = rx.lock().await;
104                executor.simulate_random_delay().await;
105                if let Some(message) = rx_guard.next().await {
106                    drop(rx_guard);
107                    Some((message, rx))
108                } else {
109                    None
110                }
111            }
112        }))
113    }
114
115    fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
116        Box::pin(futures::stream::empty())
117    }
118}