mock.rs

//! Mock transport for testing remote connections.
//!
//! This module provides a mock implementation of the `RemoteConnection` trait
//! that allows testing remote editing functionality without actual SSH/WSL/Docker
//! connections.
//!
//! # Usage
//!
//! ```rust,ignore
//! use remote::{MockConnection, RemoteClient};
//!
//! #[gpui::test]
//! async fn test_remote_editing(cx: &mut TestAppContext, server_cx: &mut TestAppContext) {
//!     // Keep the guard alive for as long as the connection should be held back;
//!     // the registered connection is released once the guard sends or is dropped.
//!     let (opts, server_session, _connect_guard) = MockConnection::new(cx, server_cx);
//!
//!     // Create the headless project (server side)
//!     server_cx.update(HeadlessProject::init);
//!     let _headless = server_cx.new(|cx| {
//!         HeadlessProject::new(
//!             HeadlessAppState { session: server_session, /* ... */ },
//!             false,
//!             cx,
//!         )
//!     });
//!
//!     // Create the client using the helper
//!     let (client, server_client) = RemoteClient::new_mock(cx, server_cx).await;
//!     // ... test logic ...
//! }
//! ```
 31
 32use crate::remote_client::{
 33    ChannelClient, CommandTemplate, Interactive, RemoteClientDelegate, RemoteConnection,
 34    RemoteConnectionOptions,
 35};
 36use anyhow::Result;
 37use async_trait::async_trait;
 38use collections::HashMap;
 39use futures::{
 40    FutureExt, SinkExt, StreamExt,
 41    channel::{
 42        mpsc::{self, Sender},
 43        oneshot,
 44    },
 45    select_biased,
 46};
 47use gpui::{App, AppContext as _, AsyncApp, Global, Task, TestAppContext};
 48use rpc::{AnyProtoClient, proto::Envelope};
 49use std::{
 50    path::PathBuf,
 51    sync::{
 52        Arc,
 53        atomic::{AtomicU64, Ordering},
 54    },
 55};
 56use util::paths::{PathStyle, RemotePathBuf};
 57
 58/// Unique identifier for a mock connection.
 59#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 60pub struct MockConnectionOptions {
 61    pub id: u64,
 62}
 63
 64/// A mock implementation of `RemoteConnection` for testing.
 65pub struct MockRemoteConnection {
 66    options: MockConnectionOptions,
 67    server_channel: Arc<ChannelClient>,
 68    server_cx: SendableCx,
 69}
 70
 71/// Wrapper to pass `AsyncApp` across thread boundaries in tests.
 72///
 73/// # Safety
 74///
 75/// This is safe because in test mode, GPUI is always single-threaded and so
 76/// having access to one async app means being on the same main thread.
 77pub(crate) struct SendableCx(AsyncApp);
 78
 79impl SendableCx {
 80    pub(crate) fn new(cx: &TestAppContext) -> Self {
 81        Self(cx.to_async())
 82    }
 83
 84    pub(crate) fn get(&self, _: &AsyncApp) -> AsyncApp {
 85        self.0.clone()
 86    }
 87}
 88
 89// SAFETY: In test mode, GPUI is always single-threaded, and SendableCx
 90// is only accessed from the main thread via the get() method which
 91// requires a valid AsyncApp reference.
 92unsafe impl Send for SendableCx {}
 93unsafe impl Sync for SendableCx {}
 94
 95/// Global registry that holds pre-created mock connections.
 96///
 97/// When `ConnectionPool::connect` is called with `MockConnectionOptions`,
 98/// it retrieves the connection from this registry.
 99#[derive(Default)]
100pub struct MockConnectionRegistry {
101    pending: HashMap<u64, (oneshot::Receiver<()>, Arc<MockRemoteConnection>)>,
102}
103
104impl Global for MockConnectionRegistry {}
105
106impl MockConnectionRegistry {
107    /// Called by `ConnectionPool::connect` to retrieve a pre-registered mock connection.
108    pub fn take(
109        &mut self,
110        opts: &MockConnectionOptions,
111    ) -> Option<impl Future<Output = Arc<MockRemoteConnection>> + use<>> {
112        let (guard, con) = self.pending.remove(&opts.id)?;
113        Some(async move {
114            _ = guard.await;
115            con
116        })
117    }
118}
119
120/// Helper for creating mock connection pairs in tests.
121pub struct MockConnection;
122
123pub type ConnectGuard = oneshot::Sender<()>;
124
125impl MockConnection {
126    /// Creates a new mock connection pair for testing.
127    ///
128    /// This function:
129    /// 1. Creates a unique `MockConnectionOptions` identifier
130    /// 2. Sets up the server-side channel (returned as `AnyProtoClient`)
131    /// 3. Creates a `MockRemoteConnection` and registers it in the global registry
132    /// 4. The connection will be retrieved from the registry when `ConnectionPool::connect` is called
133    ///
134    /// Returns:
135    /// - `MockConnectionOptions` to pass to `remote::connect()` or `RemoteClient` creation
136    /// - `AnyProtoClient` to pass to `HeadlessProject::new()` as the session
137    ///
138    /// # Arguments
139    /// - `client_cx`: The test context for the client side
140    /// - `server_cx`: The test context for the server/headless side
141    pub(crate) fn new(
142        client_cx: &mut TestAppContext,
143        server_cx: &mut TestAppContext,
144    ) -> (MockConnectionOptions, AnyProtoClient, ConnectGuard) {
145        static NEXT_ID: AtomicU64 = AtomicU64::new(0);
146        let id = NEXT_ID.fetch_add(1, Ordering::SeqCst);
147        let opts = MockConnectionOptions { id };
148        let (server_client, connect_guard) =
149            Self::new_with_opts(opts.clone(), client_cx, server_cx);
150        (opts, server_client, connect_guard)
151    }
152
153    /// Creates a mock connection pair for existing `MockConnectionOptions`.
154    ///
155    /// This is useful when simulating reconnection: after a connection is torn
156    /// down, register a new mock server under the same options so the next
157    /// `ConnectionPool::connect` call finds it.
158    pub(crate) fn new_with_opts(
159        opts: MockConnectionOptions,
160        client_cx: &mut TestAppContext,
161        server_cx: &mut TestAppContext,
162    ) -> (AnyProtoClient, ConnectGuard) {
163        let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
164        let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
165        let server_client = server_cx
166            .update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "mock-server", false));
167
168        let connection = Arc::new(MockRemoteConnection {
169            options: opts.clone(),
170            server_channel: server_client.clone(),
171            server_cx: SendableCx::new(server_cx),
172        });
173
174        let (tx, rx) = oneshot::channel();
175
176        client_cx.update(|cx| {
177            cx.default_global::<MockConnectionRegistry>()
178                .pending
179                .insert(opts.id, (rx, connection));
180        });
181
182        (server_client.into(), tx)
183    }
184}
185
186#[async_trait(?Send)]
187impl RemoteConnection for MockRemoteConnection {
188    async fn kill(&self) -> Result<()> {
189        Ok(())
190    }
191
192    fn has_been_killed(&self) -> bool {
193        false
194    }
195
196    fn build_command(
197        &self,
198        program: Option<String>,
199        args: &[String],
200        env: &HashMap<String, String>,
201        _working_dir: Option<String>,
202        _port_forward: Option<(u16, String, u16)>,
203        _interactive: Interactive,
204    ) -> Result<CommandTemplate> {
205        let shell_program = program.unwrap_or_else(|| "sh".to_string());
206        let mut shell_args = Vec::new();
207        shell_args.push(shell_program);
208        shell_args.extend(args.iter().cloned());
209        Ok(CommandTemplate {
210            program: "mock".into(),
211            args: shell_args,
212            env: env.clone(),
213        })
214    }
215
216    fn build_forward_ports_command(
217        &self,
218        forwards: Vec<(u16, String, u16)>,
219    ) -> Result<CommandTemplate> {
220        Ok(CommandTemplate {
221            program: "mock".into(),
222            args: std::iter::once("-N".to_owned())
223                .chain(forwards.into_iter().map(|(local_port, host, remote_port)| {
224                    format!("{local_port}:{host}:{remote_port}")
225                }))
226                .collect(),
227            env: Default::default(),
228        })
229    }
230
231    fn upload_directory(
232        &self,
233        _src_path: PathBuf,
234        _dest_path: RemotePathBuf,
235        _cx: &App,
236    ) -> Task<Result<()>> {
237        Task::ready(Ok(()))
238    }
239
240    fn connection_options(&self) -> RemoteConnectionOptions {
241        RemoteConnectionOptions::Mock(self.options.clone())
242    }
243
244    fn simulate_disconnect(&self, cx: &AsyncApp) {
245        let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
246        let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
247        self.server_channel
248            .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(cx));
249    }
250
251    fn start_proxy(
252        &self,
253        _unique_identifier: String,
254        _reconnect: bool,
255        mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
256        mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
257        mut connection_activity_tx: Sender<()>,
258        _delegate: Arc<dyn RemoteClientDelegate>,
259        cx: &mut AsyncApp,
260    ) -> Task<Result<i32>> {
261        let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
262        let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();
263
264        self.server_channel.reconnect(
265            server_incoming_rx,
266            server_outgoing_tx,
267            &self.server_cx.get(cx),
268        );
269
270        cx.background_spawn(async move {
271            loop {
272                select_biased! {
273                    server_to_client = server_outgoing_rx.next().fuse() => {
274                        let Some(server_to_client) = server_to_client else {
275                            return Ok(1)
276                        };
277                        connection_activity_tx.try_send(()).ok();
278                        client_incoming_tx.send(server_to_client).await.ok();
279                    }
280                    client_to_server = client_outgoing_rx.next().fuse() => {
281                        let Some(client_to_server) = client_to_server else {
282                            return Ok(1)
283                        };
284                        server_incoming_tx.send(client_to_server).await.ok();
285                    }
286                }
287            }
288        })
289    }
290
291    fn path_style(&self) -> PathStyle {
292        PathStyle::local()
293    }
294
295    fn shell(&self) -> String {
296        "sh".to_owned()
297    }
298
299    fn default_system_shell(&self) -> String {
300        "sh".to_owned()
301    }
302
303    fn has_wsl_interop(&self) -> bool {
304        false
305    }
306}
307
308/// Mock delegate for tests that don't need delegate functionality.
309pub struct MockDelegate;
310
311impl RemoteClientDelegate for MockDelegate {
312    fn ask_password(
313        &self,
314        _prompt: String,
315        _sender: futures::channel::oneshot::Sender<askpass::EncryptedPassword>,
316        _cx: &mut AsyncApp,
317    ) {
318        unreachable!("MockDelegate::ask_password should not be called in tests")
319    }
320
321    fn download_server_binary_locally(
322        &self,
323        _platform: crate::RemotePlatform,
324        _release_channel: release_channel::ReleaseChannel,
325        _version: Option<semver::Version>,
326        _cx: &mut AsyncApp,
327    ) -> Task<Result<PathBuf>> {
328        unreachable!("MockDelegate::download_server_binary_locally should not be called in tests")
329    }
330
331    fn get_download_url(
332        &self,
333        _platform: crate::RemotePlatform,
334        _release_channel: release_channel::ReleaseChannel,
335        _version: Option<semver::Version>,
336        _cx: &mut AsyncApp,
337    ) -> Task<Result<Option<String>>> {
338        unreachable!("MockDelegate::get_download_url should not be called in tests")
339    }
340
341    fn set_status(&self, _status: Option<&str>, _cx: &mut AsyncApp) {}
342}