mock.rs

  1//! Mock transport for testing remote connections.
  2//!
  3//! This module provides a mock implementation of the `RemoteConnection` trait
  4//! that allows testing remote editing functionality without actual SSH/WSL/Docker
  5//! connections.
  6//!
  7//! # Usage
  8//!
  9//! ```rust,ignore
 10//! use remote::{MockConnection, RemoteClient};
 11//!
 12//! #[gpui::test]
 13//! async fn test_remote_editing(cx: &mut TestAppContext, server_cx: &mut TestAppContext) {
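//!     // The connection registered for `opts` is only handed out by the registry
//!     // once `connect_guard` has been sent on or dropped.
//!     let (opts, server_session, connect_guard) = MockConnection::new(cx, server_cx);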
 15//!
 16//!     // Create the headless project (server side)
 17//!     server_cx.update(HeadlessProject::init);
 18//!     let _headless = server_cx.new(|cx| {
 19//!         HeadlessProject::new(
 20//!             HeadlessAppState { session: server_session, /* ... */ },
 21//!             false,
 22//!             cx,
 23//!         )
 24//!     });
 25//!
 26//!     // Create the client using the helper
 27//!     let (client, server_client) = RemoteClient::new_mock(cx, server_cx).await;
 28//!     // ... test logic ...
 29//! }
 30//! ```
 31
 32use crate::remote_client::{
 33    ChannelClient, CommandTemplate, Interactive, RemoteClientDelegate, RemoteConnection,
 34    RemoteConnectionOptions,
 35};
 36use anyhow::Result;
 37use async_trait::async_trait;
 38use collections::HashMap;
 39use futures::{
 40    FutureExt, SinkExt, StreamExt,
 41    channel::{
 42        mpsc::{self, Sender},
 43        oneshot,
 44    },
 45    select_biased,
 46};
 47use gpui::{App, AppContext as _, AsyncApp, Global, Task, TestAppContext};
 48use rpc::{AnyProtoClient, proto::Envelope};
 49use std::{
 50    path::PathBuf,
 51    sync::{
 52        Arc,
 53        atomic::{AtomicU64, Ordering},
 54    },
 55};
 56use util::paths::{PathStyle, RemotePathBuf};
 57
 58/// Unique identifier for a mock connection.
 59#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 60pub struct MockConnectionOptions {
 61    pub id: u64,
 62}
 63
 64/// A mock implementation of `RemoteConnection` for testing.
 65pub struct MockRemoteConnection {
 66    options: MockConnectionOptions,
 67    server_channel: Arc<ChannelClient>,
 68    server_cx: SendableCx,
 69}
 70
 71/// Wrapper to pass `AsyncApp` across thread boundaries in tests.
 72///
 73/// # Safety
 74///
 75/// This is safe because in test mode, GPUI is always single-threaded and so
 76/// having access to one async app means being on the same main thread.
 77pub(crate) struct SendableCx(AsyncApp);
 78
 79impl SendableCx {
 80    pub(crate) fn new(cx: &TestAppContext) -> Self {
 81        Self(cx.to_async())
 82    }
 83
 84    pub(crate) fn get(&self, _: &AsyncApp) -> AsyncApp {
 85        self.0.clone()
 86    }
 87}
 88
 89// SAFETY: In test mode, GPUI is always single-threaded, and SendableCx
 90// is only accessed from the main thread via the get() method which
 91// requires a valid AsyncApp reference.
 92unsafe impl Send for SendableCx {}
 93unsafe impl Sync for SendableCx {}
 94
 95/// Global registry that holds pre-created mock connections.
 96///
 97/// When `ConnectionPool::connect` is called with `MockConnectionOptions`,
 98/// it retrieves the connection from this registry.
 99#[derive(Default)]
100pub struct MockConnectionRegistry {
101    pending: HashMap<u64, (oneshot::Receiver<()>, Arc<MockRemoteConnection>)>,
102}
103
104impl Global for MockConnectionRegistry {}
105
106impl MockConnectionRegistry {
107    /// Called by `ConnectionPool::connect` to retrieve a pre-registered mock connection.
108    pub fn take(
109        &mut self,
110        opts: &MockConnectionOptions,
111    ) -> Option<impl Future<Output = Arc<MockRemoteConnection>> + use<>> {
112        let (guard, con) = self.pending.remove(&opts.id)?;
113        Some(async move {
114            _ = guard.await;
115            con
116        })
117    }
118}
119
120/// Helper for creating mock connection pairs in tests.
121pub struct MockConnection;
122
123pub type ConnectGuard = oneshot::Sender<()>;
124
125impl MockConnection {
126    /// Creates a new mock connection pair for testing.
127    ///
128    /// This function:
129    /// 1. Creates a unique `MockConnectionOptions` identifier
130    /// 2. Sets up the server-side channel (returned as `AnyProtoClient`)
131    /// 3. Creates a `MockRemoteConnection` and registers it in the global registry
132    /// 4. The connection will be retrieved from the registry when `ConnectionPool::connect` is called
133    ///
134    /// Returns:
135    /// - `MockConnectionOptions` to pass to `remote::connect()` or `RemoteClient` creation
136    /// - `AnyProtoClient` to pass to `HeadlessProject::new()` as the session
137    ///
138    /// # Arguments
139    /// - `client_cx`: The test context for the client side
140    /// - `server_cx`: The test context for the server/headless side
141    pub(crate) fn new(
142        client_cx: &mut TestAppContext,
143        server_cx: &mut TestAppContext,
144    ) -> (MockConnectionOptions, AnyProtoClient, ConnectGuard) {
145        static NEXT_ID: AtomicU64 = AtomicU64::new(0);
146        let id = NEXT_ID.fetch_add(1, Ordering::SeqCst);
147        let opts = MockConnectionOptions { id };
148
149        let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
150        let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
151        let server_client = server_cx
152            .update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "mock-server", false));
153
154        let connection = Arc::new(MockRemoteConnection {
155            options: opts.clone(),
156            server_channel: server_client.clone(),
157            server_cx: SendableCx::new(server_cx),
158        });
159
160        let (tx, rx) = oneshot::channel();
161
162        client_cx.update(|cx| {
163            cx.default_global::<MockConnectionRegistry>()
164                .pending
165                .insert(opts.id, (rx, connection));
166        });
167
168        (opts, server_client.into(), tx)
169    }
170}
171
172#[async_trait(?Send)]
173impl RemoteConnection for MockRemoteConnection {
174    async fn kill(&self) -> Result<()> {
175        Ok(())
176    }
177
178    fn has_been_killed(&self) -> bool {
179        false
180    }
181
182    fn build_command(
183        &self,
184        program: Option<String>,
185        args: &[String],
186        env: &HashMap<String, String>,
187        _working_dir: Option<String>,
188        _port_forward: Option<(u16, String, u16)>,
189        _interactive: Interactive,
190    ) -> Result<CommandTemplate> {
191        let shell_program = program.unwrap_or_else(|| "sh".to_string());
192        let mut shell_args = Vec::new();
193        shell_args.push(shell_program);
194        shell_args.extend(args.iter().cloned());
195        Ok(CommandTemplate {
196            program: "mock".into(),
197            args: shell_args,
198            env: env.clone(),
199        })
200    }
201
202    fn build_forward_ports_command(
203        &self,
204        forwards: Vec<(u16, String, u16)>,
205    ) -> Result<CommandTemplate> {
206        Ok(CommandTemplate {
207            program: "mock".into(),
208            args: std::iter::once("-N".to_owned())
209                .chain(forwards.into_iter().map(|(local_port, host, remote_port)| {
210                    format!("{local_port}:{host}:{remote_port}")
211                }))
212                .collect(),
213            env: Default::default(),
214        })
215    }
216
217    fn upload_directory(
218        &self,
219        _src_path: PathBuf,
220        _dest_path: RemotePathBuf,
221        _cx: &App,
222    ) -> Task<Result<()>> {
223        Task::ready(Ok(()))
224    }
225
226    fn connection_options(&self) -> RemoteConnectionOptions {
227        RemoteConnectionOptions::Mock(self.options.clone())
228    }
229
230    fn simulate_disconnect(&self, cx: &AsyncApp) {
231        let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
232        let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
233        self.server_channel
234            .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(cx));
235    }
236
237    fn start_proxy(
238        &self,
239        _unique_identifier: String,
240        _reconnect: bool,
241        mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
242        mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
243        mut connection_activity_tx: Sender<()>,
244        _delegate: Arc<dyn RemoteClientDelegate>,
245        cx: &mut AsyncApp,
246    ) -> Task<Result<i32>> {
247        let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
248        let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();
249
250        self.server_channel.reconnect(
251            server_incoming_rx,
252            server_outgoing_tx,
253            &self.server_cx.get(cx),
254        );
255
256        cx.background_spawn(async move {
257            loop {
258                select_biased! {
259                    server_to_client = server_outgoing_rx.next().fuse() => {
260                        let Some(server_to_client) = server_to_client else {
261                            return Ok(1)
262                        };
263                        connection_activity_tx.try_send(()).ok();
264                        client_incoming_tx.send(server_to_client).await.ok();
265                    }
266                    client_to_server = client_outgoing_rx.next().fuse() => {
267                        let Some(client_to_server) = client_to_server else {
268                            return Ok(1)
269                        };
270                        server_incoming_tx.send(client_to_server).await.ok();
271                    }
272                }
273            }
274        })
275    }
276
277    fn path_style(&self) -> PathStyle {
278        PathStyle::local()
279    }
280
281    fn shell(&self) -> String {
282        "sh".to_owned()
283    }
284
285    fn default_system_shell(&self) -> String {
286        "sh".to_owned()
287    }
288
289    fn has_wsl_interop(&self) -> bool {
290        false
291    }
292}
293
294/// Mock delegate for tests that don't need delegate functionality.
295pub struct MockDelegate;
296
297impl RemoteClientDelegate for MockDelegate {
298    fn ask_password(
299        &self,
300        _prompt: String,
301        _sender: futures::channel::oneshot::Sender<askpass::EncryptedPassword>,
302        _cx: &mut AsyncApp,
303    ) {
304        unreachable!("MockDelegate::ask_password should not be called in tests")
305    }
306
307    fn download_server_binary_locally(
308        &self,
309        _platform: crate::RemotePlatform,
310        _release_channel: release_channel::ReleaseChannel,
311        _version: Option<semver::Version>,
312        _cx: &mut AsyncApp,
313    ) -> Task<Result<PathBuf>> {
314        unreachable!("MockDelegate::download_server_binary_locally should not be called in tests")
315    }
316
317    fn get_download_url(
318        &self,
319        _platform: crate::RemotePlatform,
320        _release_channel: release_channel::ReleaseChannel,
321        _version: Option<semver::Version>,
322        _cx: &mut AsyncApp,
323    ) -> Task<Result<Option<String>>> {
324        unreachable!("MockDelegate::get_download_url should not be called in tests")
325    }
326
327    fn set_status(&self, _status: Option<&str>, _cx: &mut AsyncApp) {}
328}