mock.rs

  1//! Mock transport for testing remote connections.
  2//!
  3//! This module provides a mock implementation of the `RemoteConnection` trait
  4//! that allows testing remote editing functionality without actual SSH/WSL/Docker
  5//! connections.
  6//!
  7//! # Usage
  8//!
  9//! ```rust,ignore
 10//! use remote::{MockConnection, RemoteClient};
 11//!
 12//! #[gpui::test]
 13//! async fn test_remote_editing(cx: &mut TestAppContext, server_cx: &mut TestAppContext) {
 14//!     let (opts, server_session) = MockConnection::new(cx, server_cx);
 15//!
 16//!     // Create the headless project (server side)
 17//!     server_cx.update(HeadlessProject::init);
 18//!     let _headless = server_cx.new(|cx| {
 19//!         HeadlessProject::new(
 20//!             HeadlessAppState { session: server_session, /* ... */ },
 21//!             false,
 22//!             cx,
 23//!         )
 24//!     });
 25//!
 26//!     // Create the client using the helper
 27//!     let (client, server_client) = RemoteClient::new_mock(cx, server_cx).await;
 28//!     // ... test logic ...
 29//! }
 30//! ```
 31
 32use crate::remote_client::{
 33    ChannelClient, CommandTemplate, RemoteClientDelegate, RemoteConnection, RemoteConnectionOptions,
 34};
 35use anyhow::Result;
 36use async_trait::async_trait;
 37use collections::HashMap;
 38use futures::{
 39    FutureExt, SinkExt, StreamExt,
 40    channel::mpsc::{self, Sender},
 41    select_biased,
 42};
 43use gpui::{App, AppContext as _, AsyncApp, Global, Task, TestAppContext};
 44use rpc::{AnyProtoClient, proto::Envelope};
 45use std::{
 46    path::PathBuf,
 47    sync::{
 48        Arc,
 49        atomic::{AtomicU64, Ordering},
 50    },
 51};
 52use util::paths::{PathStyle, RemotePathBuf};
 53
 54/// Unique identifier for a mock connection.
 55#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 56pub struct MockConnectionOptions {
 57    pub id: u64,
 58}
 59
 60/// A mock implementation of `RemoteConnection` for testing.
 61pub struct MockRemoteConnection {
 62    options: MockConnectionOptions,
 63    server_channel: Arc<ChannelClient>,
 64    server_cx: SendableCx,
 65}
 66
 67/// Wrapper to pass `AsyncApp` across thread boundaries in tests.
 68///
 69/// # Safety
 70///
 71/// This is safe because in test mode, GPUI is always single-threaded and so
 72/// having access to one async app means being on the same main thread.
 73pub(crate) struct SendableCx(AsyncApp);
 74
 75impl SendableCx {
 76    pub(crate) fn new(cx: &TestAppContext) -> Self {
 77        Self(cx.to_async())
 78    }
 79
 80    pub(crate) fn get(&self, _: &AsyncApp) -> AsyncApp {
 81        self.0.clone()
 82    }
 83}
 84
 85// SAFETY: In test mode, GPUI is always single-threaded, and SendableCx
 86// is only accessed from the main thread via the get() method which
 87// requires a valid AsyncApp reference.
 88unsafe impl Send for SendableCx {}
 89unsafe impl Sync for SendableCx {}
 90
 91/// Global registry that holds pre-created mock connections.
 92///
 93/// When `ConnectionPool::connect` is called with `MockConnectionOptions`,
 94/// it retrieves the connection from this registry.
 95#[derive(Default)]
 96pub struct MockConnectionRegistry {
 97    pending: HashMap<MockConnectionOptions, Arc<MockRemoteConnection>>,
 98}
 99
100impl Global for MockConnectionRegistry {}
101
102impl MockConnectionRegistry {
103    /// Called by `ConnectionPool::connect` to retrieve a pre-registered mock connection.
104    pub fn take(&mut self, opts: &MockConnectionOptions) -> Option<Arc<MockRemoteConnection>> {
105        self.pending.remove(opts)
106    }
107}
108
109/// Helper for creating mock connection pairs in tests.
110pub struct MockConnection;
111
112impl MockConnection {
113    /// Creates a new mock connection pair for testing.
114    ///
115    /// This function:
116    /// 1. Creates a unique `MockConnectionOptions` identifier
117    /// 2. Sets up the server-side channel (returned as `AnyProtoClient`)
118    /// 3. Creates a `MockRemoteConnection` and registers it in the global registry
119    /// 4. The connection will be retrieved from the registry when `ConnectionPool::connect` is called
120    ///
121    /// Returns:
122    /// - `MockConnectionOptions` to pass to `remote::connect()` or `RemoteClient` creation
123    /// - `AnyProtoClient` to pass to `HeadlessProject::new()` as the session
124    ///
125    /// # Arguments
126    /// - `client_cx`: The test context for the client side
127    /// - `server_cx`: The test context for the server/headless side
128    pub fn new(
129        client_cx: &mut TestAppContext,
130        server_cx: &mut TestAppContext,
131    ) -> (MockConnectionOptions, AnyProtoClient) {
132        static NEXT_ID: AtomicU64 = AtomicU64::new(0);
133        let id = NEXT_ID.fetch_add(1, Ordering::SeqCst);
134        let opts = MockConnectionOptions { id };
135
136        let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
137        let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
138        let server_client = server_cx
139            .update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "mock-server", false));
140
141        let connection = Arc::new(MockRemoteConnection {
142            options: opts.clone(),
143            server_channel: server_client.clone(),
144            server_cx: SendableCx::new(server_cx),
145        });
146
147        client_cx.update(|cx| {
148            cx.default_global::<MockConnectionRegistry>()
149                .pending
150                .insert(opts.clone(), connection);
151        });
152
153        (opts, server_client.into())
154    }
155}
156
157#[async_trait(?Send)]
158impl RemoteConnection for MockRemoteConnection {
159    async fn kill(&self) -> Result<()> {
160        Ok(())
161    }
162
163    fn has_been_killed(&self) -> bool {
164        false
165    }
166
167    fn build_command(
168        &self,
169        program: Option<String>,
170        args: &[String],
171        env: &HashMap<String, String>,
172        _working_dir: Option<String>,
173        _port_forward: Option<(u16, String, u16)>,
174    ) -> Result<CommandTemplate> {
175        let shell_program = program.unwrap_or_else(|| "sh".to_string());
176        let mut shell_args = Vec::new();
177        shell_args.push(shell_program);
178        shell_args.extend(args.iter().cloned());
179        Ok(CommandTemplate {
180            program: "mock".into(),
181            args: shell_args,
182            env: env.clone(),
183        })
184    }
185
186    fn build_forward_ports_command(
187        &self,
188        forwards: Vec<(u16, String, u16)>,
189    ) -> Result<CommandTemplate> {
190        Ok(CommandTemplate {
191            program: "mock".into(),
192            args: std::iter::once("-N".to_owned())
193                .chain(forwards.into_iter().map(|(local_port, host, remote_port)| {
194                    format!("{local_port}:{host}:{remote_port}")
195                }))
196                .collect(),
197            env: Default::default(),
198        })
199    }
200
201    fn upload_directory(
202        &self,
203        _src_path: PathBuf,
204        _dest_path: RemotePathBuf,
205        _cx: &App,
206    ) -> Task<Result<()>> {
207        Task::ready(Ok(()))
208    }
209
210    fn connection_options(&self) -> RemoteConnectionOptions {
211        RemoteConnectionOptions::Mock(self.options.clone())
212    }
213
214    fn simulate_disconnect(&self, cx: &AsyncApp) {
215        let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
216        let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
217        self.server_channel
218            .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(cx));
219    }
220
221    fn start_proxy(
222        &self,
223        _unique_identifier: String,
224        _reconnect: bool,
225        mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
226        mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
227        mut connection_activity_tx: Sender<()>,
228        _delegate: Arc<dyn RemoteClientDelegate>,
229        cx: &mut AsyncApp,
230    ) -> Task<Result<i32>> {
231        let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
232        let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();
233
234        self.server_channel.reconnect(
235            server_incoming_rx,
236            server_outgoing_tx,
237            &self.server_cx.get(cx),
238        );
239
240        cx.background_spawn(async move {
241            loop {
242                select_biased! {
243                    server_to_client = server_outgoing_rx.next().fuse() => {
244                        let Some(server_to_client) = server_to_client else {
245                            return Ok(1)
246                        };
247                        connection_activity_tx.try_send(()).ok();
248                        client_incoming_tx.send(server_to_client).await.ok();
249                    }
250                    client_to_server = client_outgoing_rx.next().fuse() => {
251                        let Some(client_to_server) = client_to_server else {
252                            return Ok(1)
253                        };
254                        server_incoming_tx.send(client_to_server).await.ok();
255                    }
256                }
257            }
258        })
259    }
260
261    fn path_style(&self) -> PathStyle {
262        PathStyle::local()
263    }
264
265    fn shell(&self) -> String {
266        "sh".to_owned()
267    }
268
269    fn default_system_shell(&self) -> String {
270        "sh".to_owned()
271    }
272
273    fn has_wsl_interop(&self) -> bool {
274        false
275    }
276}
277
278/// Mock delegate for tests that don't need delegate functionality.
279pub struct MockDelegate;
280
281impl RemoteClientDelegate for MockDelegate {
282    fn ask_password(
283        &self,
284        _prompt: String,
285        _sender: futures::channel::oneshot::Sender<askpass::EncryptedPassword>,
286        _cx: &mut AsyncApp,
287    ) {
288        unreachable!("MockDelegate::ask_password should not be called in tests")
289    }
290
291    fn download_server_binary_locally(
292        &self,
293        _platform: crate::RemotePlatform,
294        _release_channel: release_channel::ReleaseChannel,
295        _version: Option<semver::Version>,
296        _cx: &mut AsyncApp,
297    ) -> Task<Result<PathBuf>> {
298        unreachable!("MockDelegate::download_server_binary_locally should not be called in tests")
299    }
300
301    fn get_download_url(
302        &self,
303        _platform: crate::RemotePlatform,
304        _release_channel: release_channel::ReleaseChannel,
305        _version: Option<semver::Version>,
306        _cx: &mut AsyncApp,
307    ) -> Task<Result<Option<String>>> {
308        unreachable!("MockDelegate::get_download_url should not be called in tests")
309    }
310
311    fn set_status(&self, _status: Option<&str>, _cx: &mut AsyncApp) {}
312}