1use anyhow::Context as _;
2use collections::HashMap;
3use futures::{FutureExt, Stream, StreamExt as _, future::BoxFuture, lock::Mutex};
4use gpui::BackgroundExecutor;
5use std::{pin::Pin, sync::Arc};
6
7use crate::{
8 transport::Transport,
9 types::{Implementation, InitializeResponse, ProtocolVersion, ServerCapabilities},
10};
11
12pub fn create_fake_transport(
13 name: impl Into<String>,
14 executor: BackgroundExecutor,
15) -> FakeTransport {
16 let name = name.into();
17 FakeTransport::new(executor).on_request::<crate::types::requests::Initialize, _>(
18 move |_params| {
19 let name = name.clone();
20 async move { create_initialize_response(name.clone()) }
21 },
22 )
23}
24
25fn create_initialize_response(server_name: String) -> InitializeResponse {
26 InitializeResponse {
27 protocol_version: ProtocolVersion(crate::types::LATEST_PROTOCOL_VERSION.to_string()),
28 server_info: Implementation {
29 name: server_name,
30 version: "1.0.0".to_string(),
31 },
32 capabilities: ServerCapabilities::default(),
33 meta: None,
34 }
35}
36
37pub struct FakeTransport {
38 request_handlers: HashMap<
39 &'static str,
40 Arc<dyn Send + Sync + Fn(serde_json::Value) -> BoxFuture<'static, serde_json::Value>>,
41 >,
42 tx: futures::channel::mpsc::UnboundedSender<String>,
43 rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
44 executor: BackgroundExecutor,
45}
46
47impl FakeTransport {
48 pub fn new(executor: BackgroundExecutor) -> Self {
49 let (tx, rx) = futures::channel::mpsc::unbounded();
50 Self {
51 request_handlers: Default::default(),
52 tx,
53 rx: Arc::new(Mutex::new(rx)),
54 executor,
55 }
56 }
57
58 pub fn on_request<T, Fut>(
59 mut self,
60 handler: impl 'static + Send + Sync + Fn(T::Params) -> Fut,
61 ) -> Self
62 where
63 T: crate::types::Request,
64 Fut: 'static + Send + Future<Output = T::Response>,
65 {
66 self.request_handlers.insert(
67 T::METHOD,
68 Arc::new(move |value| {
69 let params = value
70 .get("params")
71 .cloned()
72 .unwrap_or(serde_json::Value::Null);
73 let params: T::Params =
74 serde_json::from_value(params).expect("Invalid parameters received");
75 let response = handler(params);
76 async move { serde_json::to_value(response.await).unwrap() }.boxed()
77 }),
78 );
79 self
80 }
81}
82
83#[async_trait::async_trait]
84impl Transport for FakeTransport {
85 async fn send(&self, message: String) -> anyhow::Result<()> {
86 if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
87 let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);
88
89 if let Some(method) = msg.get("method") {
90 let method = method.as_str().expect("Invalid method received");
91 if let Some(handler) = self.request_handlers.get(method) {
92 let payload = handler(msg).await;
93 let response = serde_json::json!({
94 "jsonrpc": "2.0",
95 "id": id,
96 "result": payload
97 });
98 self.tx
99 .unbounded_send(response.to_string())
100 .context("sending a message")?;
101 } else {
102 log::debug!("No handler registered for MCP request '{method}'");
103 }
104 }
105 }
106 Ok(())
107 }
108
109 fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
110 let rx = self.rx.clone();
111 let executor = self.executor.clone();
112 Box::pin(futures::stream::unfold(rx, move |rx| {
113 let executor = executor.clone();
114 async move {
115 let mut rx_guard = rx.lock().await;
116 executor.simulate_random_delay().await;
117 if let Some(message) = rx_guard.next().await {
118 drop(rx_guard);
119 Some((message, rx))
120 } else {
121 None
122 }
123 }
124 }))
125 }
126
127 fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
128 Box::pin(futures::stream::empty())
129 }
130}