mock.rs

  1//! Mock transport for testing remote connections.
  2//!
  3//! This module provides a mock implementation of the `RemoteConnection` trait
  4//! that allows testing remote editing functionality without actual SSH/WSL/Docker
  5//! connections.
  6//!
  7//! # Usage
  8//!
  9//! ```rust,ignore
 10//! use remote::{MockConnection, RemoteClient};
 11//!
 12//! #[gpui::test]
 13//! async fn test_remote_editing(cx: &mut TestAppContext, server_cx: &mut TestAppContext) {
//!     let (opts, server_session, connect_guard) = MockConnection::new(cx, server_cx);
 15//!
 16//!     // Create the headless project (server side)
 17//!     server_cx.update(HeadlessProject::init);
 18//!     let _headless = server_cx.new(|cx| {
 19//!         HeadlessProject::new(
 20//!             HeadlessAppState { session: server_session, /* ... */ },
 21//!             false,
 22//!             cx,
 23//!         )
 24//!     });
 25//!
 26//!     // Create the client using the helper
 27//!     let (client, server_client) = RemoteClient::new_mock(cx, server_cx).await;
 28//!     // ... test logic ...
 29//! }
 30//! ```
 31
 32use crate::remote_client::{
 33    ChannelClient, CommandTemplate, RemoteClientDelegate, RemoteConnection, RemoteConnectionOptions,
 34};
 35use anyhow::Result;
 36use async_trait::async_trait;
 37use collections::HashMap;
 38use futures::{
 39    FutureExt, SinkExt, StreamExt,
 40    channel::{
 41        mpsc::{self, Sender},
 42        oneshot,
 43    },
 44    select_biased,
 45};
 46use gpui::{App, AppContext as _, AsyncApp, Global, Task, TestAppContext};
 47use rpc::{AnyProtoClient, proto::Envelope};
 48use std::{
 49    path::PathBuf,
 50    sync::{
 51        Arc,
 52        atomic::{AtomicU64, Ordering},
 53    },
 54};
 55use util::paths::{PathStyle, RemotePathBuf};
 56
 57/// Unique identifier for a mock connection.
 58#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 59pub struct MockConnectionOptions {
 60    pub id: u64,
 61}
 62
 63/// A mock implementation of `RemoteConnection` for testing.
 64pub struct MockRemoteConnection {
 65    options: MockConnectionOptions,
 66    server_channel: Arc<ChannelClient>,
 67    server_cx: SendableCx,
 68}
 69
 70/// Wrapper to pass `AsyncApp` across thread boundaries in tests.
 71///
 72/// # Safety
 73///
 74/// This is safe because in test mode, GPUI is always single-threaded and so
 75/// having access to one async app means being on the same main thread.
 76pub(crate) struct SendableCx(AsyncApp);
 77
 78impl SendableCx {
 79    pub(crate) fn new(cx: &TestAppContext) -> Self {
 80        Self(cx.to_async())
 81    }
 82
 83    pub(crate) fn get(&self, _: &AsyncApp) -> AsyncApp {
 84        self.0.clone()
 85    }
 86}
 87
 88// SAFETY: In test mode, GPUI is always single-threaded, and SendableCx
 89// is only accessed from the main thread via the get() method which
 90// requires a valid AsyncApp reference.
 91unsafe impl Send for SendableCx {}
 92unsafe impl Sync for SendableCx {}
 93
 94/// Global registry that holds pre-created mock connections.
 95///
 96/// When `ConnectionPool::connect` is called with `MockConnectionOptions`,
 97/// it retrieves the connection from this registry.
 98#[derive(Default)]
 99pub struct MockConnectionRegistry {
100    pending: HashMap<MockConnectionOptions, (oneshot::Receiver<()>, Arc<MockRemoteConnection>)>,
101}
102
103impl Global for MockConnectionRegistry {}
104
105impl MockConnectionRegistry {
106    /// Called by `ConnectionPool::connect` to retrieve a pre-registered mock connection.
107    pub fn take(
108        &mut self,
109        opts: &MockConnectionOptions,
110    ) -> Option<impl Future<Output = Arc<MockRemoteConnection>> + use<>> {
111        let (guard, con) = self.pending.remove(opts)?;
112        Some(async move {
113            _ = guard.await;
114            con
115        })
116    }
117}
118
119/// Helper for creating mock connection pairs in tests.
120pub struct MockConnection;
121
122pub type ConnectGuard = oneshot::Sender<()>;
123
124impl MockConnection {
125    /// Creates a new mock connection pair for testing.
126    ///
127    /// This function:
128    /// 1. Creates a unique `MockConnectionOptions` identifier
129    /// 2. Sets up the server-side channel (returned as `AnyProtoClient`)
130    /// 3. Creates a `MockRemoteConnection` and registers it in the global registry
131    /// 4. The connection will be retrieved from the registry when `ConnectionPool::connect` is called
132    ///
133    /// Returns:
134    /// - `MockConnectionOptions` to pass to `remote::connect()` or `RemoteClient` creation
135    /// - `AnyProtoClient` to pass to `HeadlessProject::new()` as the session
136    ///
137    /// # Arguments
138    /// - `client_cx`: The test context for the client side
139    /// - `server_cx`: The test context for the server/headless side
140    pub(crate) fn new(
141        client_cx: &mut TestAppContext,
142        server_cx: &mut TestAppContext,
143    ) -> (MockConnectionOptions, AnyProtoClient, ConnectGuard) {
144        static NEXT_ID: AtomicU64 = AtomicU64::new(0);
145        let id = NEXT_ID.fetch_add(1, Ordering::SeqCst);
146        let opts = MockConnectionOptions { id };
147
148        let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
149        let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
150        let server_client = server_cx
151            .update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "mock-server", false));
152
153        let connection = Arc::new(MockRemoteConnection {
154            options: opts.clone(),
155            server_channel: server_client.clone(),
156            server_cx: SendableCx::new(server_cx),
157        });
158
159        let (tx, rx) = oneshot::channel();
160
161        client_cx.update(|cx| {
162            cx.default_global::<MockConnectionRegistry>()
163                .pending
164                .insert(opts.clone(), (rx, connection));
165        });
166
167        (opts, server_client.into(), tx)
168    }
169}
170
171#[async_trait(?Send)]
172impl RemoteConnection for MockRemoteConnection {
173    async fn kill(&self) -> Result<()> {
174        Ok(())
175    }
176
177    fn has_been_killed(&self) -> bool {
178        false
179    }
180
181    fn build_command(
182        &self,
183        program: Option<String>,
184        args: &[String],
185        env: &HashMap<String, String>,
186        _working_dir: Option<String>,
187        _port_forward: Option<(u16, String, u16)>,
188    ) -> Result<CommandTemplate> {
189        let shell_program = program.unwrap_or_else(|| "sh".to_string());
190        let mut shell_args = Vec::new();
191        shell_args.push(shell_program);
192        shell_args.extend(args.iter().cloned());
193        Ok(CommandTemplate {
194            program: "mock".into(),
195            args: shell_args,
196            env: env.clone(),
197        })
198    }
199
200    fn build_forward_ports_command(
201        &self,
202        forwards: Vec<(u16, String, u16)>,
203    ) -> Result<CommandTemplate> {
204        Ok(CommandTemplate {
205            program: "mock".into(),
206            args: std::iter::once("-N".to_owned())
207                .chain(forwards.into_iter().map(|(local_port, host, remote_port)| {
208                    format!("{local_port}:{host}:{remote_port}")
209                }))
210                .collect(),
211            env: Default::default(),
212        })
213    }
214
215    fn upload_directory(
216        &self,
217        _src_path: PathBuf,
218        _dest_path: RemotePathBuf,
219        _cx: &App,
220    ) -> Task<Result<()>> {
221        Task::ready(Ok(()))
222    }
223
224    fn connection_options(&self) -> RemoteConnectionOptions {
225        RemoteConnectionOptions::Mock(self.options.clone())
226    }
227
228    fn simulate_disconnect(&self, cx: &AsyncApp) {
229        let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
230        let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
231        self.server_channel
232            .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(cx));
233    }
234
235    fn start_proxy(
236        &self,
237        _unique_identifier: String,
238        _reconnect: bool,
239        mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
240        mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
241        mut connection_activity_tx: Sender<()>,
242        _delegate: Arc<dyn RemoteClientDelegate>,
243        cx: &mut AsyncApp,
244    ) -> Task<Result<i32>> {
245        let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
246        let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();
247
248        self.server_channel.reconnect(
249            server_incoming_rx,
250            server_outgoing_tx,
251            &self.server_cx.get(cx),
252        );
253
254        cx.background_spawn(async move {
255            loop {
256                select_biased! {
257                    server_to_client = server_outgoing_rx.next().fuse() => {
258                        let Some(server_to_client) = server_to_client else {
259                            return Ok(1)
260                        };
261                        connection_activity_tx.try_send(()).ok();
262                        client_incoming_tx.send(server_to_client).await.ok();
263                    }
264                    client_to_server = client_outgoing_rx.next().fuse() => {
265                        let Some(client_to_server) = client_to_server else {
266                            return Ok(1)
267                        };
268                        server_incoming_tx.send(client_to_server).await.ok();
269                    }
270                }
271            }
272        })
273    }
274
275    fn path_style(&self) -> PathStyle {
276        PathStyle::local()
277    }
278
279    fn shell(&self) -> String {
280        "sh".to_owned()
281    }
282
283    fn default_system_shell(&self) -> String {
284        "sh".to_owned()
285    }
286
287    fn has_wsl_interop(&self) -> bool {
288        false
289    }
290}
291
292/// Mock delegate for tests that don't need delegate functionality.
293pub struct MockDelegate;
294
295impl RemoteClientDelegate for MockDelegate {
296    fn ask_password(
297        &self,
298        _prompt: String,
299        _sender: futures::channel::oneshot::Sender<askpass::EncryptedPassword>,
300        _cx: &mut AsyncApp,
301    ) {
302        unreachable!("MockDelegate::ask_password should not be called in tests")
303    }
304
305    fn download_server_binary_locally(
306        &self,
307        _platform: crate::RemotePlatform,
308        _release_channel: release_channel::ReleaseChannel,
309        _version: Option<semver::Version>,
310        _cx: &mut AsyncApp,
311    ) -> Task<Result<PathBuf>> {
312        unreachable!("MockDelegate::download_server_binary_locally should not be called in tests")
313    }
314
315    fn get_download_url(
316        &self,
317        _platform: crate::RemotePlatform,
318        _release_channel: release_channel::ReleaseChannel,
319        _version: Option<semver::Version>,
320        _cx: &mut AsyncApp,
321    ) -> Task<Result<Option<String>>> {
322        unreachable!("MockDelegate::get_download_url should not be called in tests")
323    }
324
325    fn set_status(&self, _status: Option<&str>, _cx: &mut AsyncApp) {}
326}