1use anyhow::Context as _;
2use collections::HashMap;
3use futures::{Stream, StreamExt as _, lock::Mutex};
4use gpui::BackgroundExecutor;
5use std::{pin::Pin, sync::Arc};
6
7use crate::{
8 transport::Transport,
9 types::{Implementation, InitializeResponse, ProtocolVersion, ServerCapabilities},
10};
11
12pub fn create_fake_transport(
13 name: impl Into<String>,
14 executor: BackgroundExecutor,
15) -> FakeTransport {
16 let name = name.into();
17 FakeTransport::new(executor).on_request::<crate::types::requests::Initialize>(move |_params| {
18 create_initialize_response(name.clone())
19 })
20}
21
22fn create_initialize_response(server_name: String) -> InitializeResponse {
23 InitializeResponse {
24 protocol_version: ProtocolVersion(crate::types::LATEST_PROTOCOL_VERSION.to_string()),
25 server_info: Implementation {
26 name: server_name,
27 version: "1.0.0".to_string(),
28 },
29 capabilities: ServerCapabilities::default(),
30 meta: None,
31 }
32}
33
34pub struct FakeTransport {
35 request_handlers:
36 HashMap<&'static str, Arc<dyn Fn(serde_json::Value) -> serde_json::Value + Send + Sync>>,
37 tx: futures::channel::mpsc::UnboundedSender<String>,
38 rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
39 executor: BackgroundExecutor,
40}
41
42impl FakeTransport {
43 pub fn new(executor: BackgroundExecutor) -> Self {
44 let (tx, rx) = futures::channel::mpsc::unbounded();
45 Self {
46 request_handlers: Default::default(),
47 tx,
48 rx: Arc::new(Mutex::new(rx)),
49 executor,
50 }
51 }
52
53 pub fn on_request<T: crate::types::Request>(
54 mut self,
55 handler: impl Fn(T::Params) -> T::Response + Send + Sync + 'static,
56 ) -> Self {
57 self.request_handlers.insert(
58 T::METHOD,
59 Arc::new(move |value| {
60 let params = value.get("params").expect("Missing parameters").clone();
61 let params: T::Params =
62 serde_json::from_value(params).expect("Invalid parameters received");
63 let response = handler(params);
64 serde_json::to_value(response).unwrap()
65 }),
66 );
67 self
68 }
69}
70
71#[async_trait::async_trait]
72impl Transport for FakeTransport {
73 async fn send(&self, message: String) -> anyhow::Result<()> {
74 if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
75 let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);
76
77 if let Some(method) = msg.get("method") {
78 let method = method.as_str().expect("Invalid method received");
79 if let Some(handler) = self.request_handlers.get(method) {
80 let payload = handler(msg);
81 let response = serde_json::json!({
82 "jsonrpc": "2.0",
83 "id": id,
84 "result": payload
85 });
86 self.tx
87 .unbounded_send(response.to_string())
88 .context("sending a message")?;
89 } else {
90 log::debug!("No handler registered for MCP request '{method}'");
91 }
92 }
93 }
94 Ok(())
95 }
96
97 fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
98 let rx = self.rx.clone();
99 let executor = self.executor.clone();
100 Box::pin(futures::stream::unfold(rx, move |rx| {
101 let executor = executor.clone();
102 async move {
103 let mut rx_guard = rx.lock().await;
104 executor.simulate_random_delay().await;
105 if let Some(message) = rx_guard.next().await {
106 drop(rx_guard);
107 Some((message, rx))
108 } else {
109 None
110 }
111 }
112 }))
113 }
114
115 fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
116 Box::pin(futures::stream::empty())
117 }
118}