ssh remoting: Restore SSH projects when reopening Zed (#19188)

Thorsten Ball and Bennet created

Release Notes:

- N/A

---------

Co-authored-by: Bennet <bennet@zed.dev>

Change summary

crates/workspace/src/persistence.rs       | 163 ++++++++++++++++++++++--
crates/workspace/src/persistence/model.rs |  11 +
crates/workspace/src/workspace.rs         |   4 
crates/zed/src/main.rs                    |  51 +++++-
crates/zed/src/zed/open_listener.rs       |  91 ++++++++-----
5 files changed, 249 insertions(+), 71 deletions(-)

Detailed changes

crates/workspace/src/persistence.rs 🔗

@@ -732,9 +732,11 @@ impl WorkspaceDb {
                                 bottom_dock_visible,
                                 bottom_dock_active_panel,
                                 bottom_dock_zoom,
+                                session_id,
+                                window_id,
                                 timestamp
                             )
-                            VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, CURRENT_TIMESTAMP)
+                            VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, CURRENT_TIMESTAMP)
                             ON CONFLICT DO
                             UPDATE SET
                                 ssh_project_id = ?2,
@@ -747,11 +749,15 @@ impl WorkspaceDb {
                                 bottom_dock_visible = ?9,
                                 bottom_dock_active_panel = ?10,
                                 bottom_dock_zoom = ?11,
+                                session_id = ?12,
+                                window_id = ?13,
                                 timestamp = CURRENT_TIMESTAMP
                         ))?((
                             workspace.id,
                             ssh_project.id.0,
                             workspace.docks,
+                            workspace.session_id,
+                            workspace.window_id
                         ))
                         .context("Updating workspace")?;
                     }
@@ -827,8 +833,8 @@ impl WorkspaceDb {
     }
 
     query! {
-        fn session_workspaces(session_id: String) -> Result<Vec<(LocalPaths, Option<u64>)>> {
-            SELECT local_paths, window_id
+        fn session_workspaces(session_id: String) -> Result<Vec<(LocalPaths, Option<u64>, Option<u64>)>> {
+            SELECT local_paths, window_id, ssh_project_id
             FROM workspaces
             WHERE session_id = ?1 AND dev_server_project_id IS NULL
             ORDER BY timestamp DESC
@@ -849,6 +855,14 @@ impl WorkspaceDb {
         }
     }
 
+    query! {
+        fn ssh_project(id: u64) -> Result<SerializedSshProject> {
+            SELECT id, host, port, paths, user
+            FROM ssh_projects
+            WHERE id = ?
+        }
+    }
+
     pub(crate) fn last_window(
         &self,
     ) -> anyhow::Result<(Option<Uuid>, Option<SerializedWindowBounds>)> {
@@ -937,18 +951,13 @@ impl WorkspaceDb {
         Ok(result)
     }
 
-    pub async fn last_workspace(&self) -> Result<Option<LocalPaths>> {
+    pub async fn last_workspace(&self) -> Result<Option<SerializedWorkspaceLocation>> {
         Ok(self
             .recent_workspaces_on_disk()
             .await?
             .into_iter()
-            .filter_map(|(_, location)| match location {
-                SerializedWorkspaceLocation::Local(local_paths, _) => Some(local_paths),
-                // Do not automatically reopen Dev Server and SSH workspaces
-                SerializedWorkspaceLocation::DevServer(_) => None,
-                SerializedWorkspaceLocation::Ssh(_) => None,
-            })
-            .next())
+            .next()
+            .map(|(_, location)| location))
     }
 
     // Returns the locations of the workspaces that were still opened when the last
@@ -959,13 +968,20 @@ impl WorkspaceDb {
         &self,
         last_session_id: &str,
         last_session_window_stack: Option<Vec<WindowId>>,
-    ) -> Result<Vec<LocalPaths>> {
+    ) -> Result<Vec<SerializedWorkspaceLocation>> {
         let mut workspaces = Vec::new();
 
-        for (location, window_id) in self.session_workspaces(last_session_id.to_owned())? {
-            if location.paths().iter().all(|path| path.exists())
+        for (location, window_id, ssh_project_id) in
+            self.session_workspaces(last_session_id.to_owned())?
+        {
+            if let Some(ssh_project_id) = ssh_project_id {
+                let location = SerializedWorkspaceLocation::Ssh(self.ssh_project(ssh_project_id)?);
+                workspaces.push((location, window_id.map(WindowId::from)));
+            } else if location.paths().iter().all(|path| path.exists())
                 && location.paths().iter().any(|path| path.is_dir())
             {
+                let location =
+                    SerializedWorkspaceLocation::from_local_paths(location.paths().iter());
                 workspaces.push((location, window_id.map(WindowId::from)));
             }
         }
@@ -1570,10 +1586,28 @@ mod tests {
             window_id: None,
         };
 
+        let ssh_project = db
+            .get_or_create_ssh_project("my-host".to_string(), Some(1234), vec![], None)
+            .await
+            .unwrap();
+
+        let workspace_5 = SerializedWorkspace {
+            id: WorkspaceId(5),
+            location: SerializedWorkspaceLocation::Ssh(ssh_project.clone()),
+            center_group: Default::default(),
+            window_bounds: Default::default(),
+            display: Default::default(),
+            docks: Default::default(),
+            centered_layout: false,
+            session_id: Some("session-id-2".to_owned()),
+            window_id: Some(50),
+        };
+
         db.save_workspace(workspace_1.clone()).await;
         db.save_workspace(workspace_2.clone()).await;
         db.save_workspace(workspace_3.clone()).await;
         db.save_workspace(workspace_4.clone()).await;
+        db.save_workspace(workspace_5.clone()).await;
 
         let locations = db.session_workspaces("session-id-1".to_owned()).unwrap();
         assert_eq!(locations.len(), 2);
@@ -1583,9 +1617,13 @@ mod tests {
         assert_eq!(locations[1].1, Some(20));
 
         let locations = db.session_workspaces("session-id-2".to_owned()).unwrap();
-        assert_eq!(locations.len(), 1);
+        assert_eq!(locations.len(), 2);
         assert_eq!(locations[0].0, LocalPaths::new(["/tmp3"]));
         assert_eq!(locations[0].1, Some(30));
+        let empty_paths: Vec<&str> = Vec::new();
+        assert_eq!(locations[1].0, LocalPaths::new(empty_paths.iter()));
+        assert_eq!(locations[1].1, Some(50));
+        assert_eq!(locations[1].2, Some(ssh_project.id.0));
     }
 
     fn default_workspace<P: AsRef<Path>>(
@@ -1650,10 +1688,97 @@ mod tests {
             .last_session_workspace_locations("one-session", stack)
             .unwrap();
         assert_eq!(have.len(), 4);
-        assert_eq!(have[0], LocalPaths::new([dir4.path().to_str().unwrap()]));
-        assert_eq!(have[1], LocalPaths::new([dir3.path().to_str().unwrap()]));
-        assert_eq!(have[2], LocalPaths::new([dir2.path().to_str().unwrap()]));
-        assert_eq!(have[3], LocalPaths::new([dir1.path().to_str().unwrap()]));
+        assert_eq!(
+            have[0],
+            SerializedWorkspaceLocation::from_local_paths(&[dir4.path().to_str().unwrap()])
+        );
+        assert_eq!(
+            have[1],
+            SerializedWorkspaceLocation::from_local_paths([dir3.path().to_str().unwrap()])
+        );
+        assert_eq!(
+            have[2],
+            SerializedWorkspaceLocation::from_local_paths([dir2.path().to_str().unwrap()])
+        );
+        assert_eq!(
+            have[3],
+            SerializedWorkspaceLocation::from_local_paths([dir1.path().to_str().unwrap()])
+        );
+    }
+
+    #[gpui::test]
+    async fn test_last_session_workspace_locations_ssh_projects() {
+        let db = WorkspaceDb(
+            open_test_db("test_serializing_workspaces_last_session_workspaces_ssh_projects").await,
+        );
+
+        let ssh_projects = [
+            ("host-1", "my-user-1"),
+            ("host-2", "my-user-2"),
+            ("host-3", "my-user-3"),
+            ("host-4", "my-user-4"),
+        ]
+        .into_iter()
+        .map(|(host, user)| async {
+            db.get_or_create_ssh_project(host.to_string(), None, vec![], Some(user.to_string()))
+                .await
+                .unwrap()
+        })
+        .collect::<Vec<_>>();
+
+        let ssh_projects = futures::future::join_all(ssh_projects).await;
+
+        let workspaces = [
+            (1, ssh_projects[0].clone(), 9),
+            (2, ssh_projects[1].clone(), 5),
+            (3, ssh_projects[2].clone(), 8),
+            (4, ssh_projects[3].clone(), 2),
+        ]
+        .into_iter()
+        .map(|(id, ssh_project, window_id)| SerializedWorkspace {
+            id: WorkspaceId(id),
+            location: SerializedWorkspaceLocation::Ssh(ssh_project),
+            center_group: Default::default(),
+            window_bounds: Default::default(),
+            display: Default::default(),
+            docks: Default::default(),
+            centered_layout: false,
+            session_id: Some("one-session".to_owned()),
+            window_id: Some(window_id),
+        })
+        .collect::<Vec<_>>();
+
+        for workspace in workspaces.iter() {
+            db.save_workspace(workspace.clone()).await;
+        }
+
+        let stack = Some(Vec::from([
+            WindowId::from(2), // Top
+            WindowId::from(8),
+            WindowId::from(5),
+            WindowId::from(9), // Bottom
+        ]));
+
+        let have = db
+            .last_session_workspace_locations("one-session", stack)
+            .unwrap();
+        assert_eq!(have.len(), 4);
+        assert_eq!(
+            have[0],
+            SerializedWorkspaceLocation::Ssh(ssh_projects[3].clone())
+        );
+        assert_eq!(
+            have[1],
+            SerializedWorkspaceLocation::Ssh(ssh_projects[2].clone())
+        );
+        assert_eq!(
+            have[2],
+            SerializedWorkspaceLocation::Ssh(ssh_projects[1].clone())
+        );
+        assert_eq!(
+            have[3],
+            SerializedWorkspaceLocation::Ssh(ssh_projects[0].clone())
+        );
     }
 
     #[gpui::test]

crates/workspace/src/persistence/model.rs 🔗

@@ -11,7 +11,7 @@ use db::sqlez::{
 };
 use gpui::{AsyncWindowContext, Model, View, WeakView};
 use project::Project;
-use remote::ssh_session::SshProjectId;
+use remote::{ssh_session::SshProjectId, SshConnectionOptions};
 use serde::{Deserialize, Serialize};
 use std::{
     path::{Path, PathBuf},
@@ -50,6 +50,15 @@ impl SerializedSshProject {
             })
             .collect()
     }
+
+    pub fn connection_options(&self) -> SshConnectionOptions {
+        SshConnectionOptions {
+            host: self.host.clone(),
+            username: self.user.clone(),
+            port: self.port,
+            password: None,
+        }
+    }
 }
 
 impl StaticColumnCount for SerializedSshProject {

crates/workspace/src/workspace.rs 🔗

@@ -5046,14 +5046,14 @@ pub fn activate_workspace_for_project(
     None
 }
 
-pub async fn last_opened_workspace_paths() -> Option<LocalPaths> {
+pub async fn last_opened_workspace_location() -> Option<SerializedWorkspaceLocation> {
     DB.last_workspace().await.log_err().flatten()
 }
 
 pub fn last_session_workspace_locations(
     last_session_id: &str,
     last_session_window_stack: Option<Vec<WindowId>>,
-) -> Option<Vec<LocalPaths>> {
+) -> Option<Vec<SerializedWorkspaceLocation>> {
     DB.last_session_workspace_locations(last_session_id, last_session_window_stack)
         .log_err()
 }

crates/zed/src/main.rs 🔗

@@ -26,6 +26,7 @@ use gpui::{
 use http_client::{read_proxy_from_env, Uri};
 use language::LanguageRegistry;
 use log::LevelFilter;
+use remote::SshConnectionOptions;
 use reqwest_client::ReqwestClient;
 
 use assets::Assets;
@@ -55,7 +56,7 @@ use uuid::Uuid;
 use welcome::{show_welcome_view, BaseKeymap, FIRST_OPEN};
 use workspace::{
     notifications::{simple_message_notification::MessageNotification, NotificationId},
-    AppState, WorkspaceSettings, WorkspaceStore,
+    AppState, SerializedWorkspaceLocation, WorkspaceSettings, WorkspaceStore,
 };
 use zed::{
     app_menus, build_window_options, derive_paths_with_position, handle_cli_connection,
@@ -868,15 +869,41 @@ async fn restore_or_create_workspace(
 ) -> Result<()> {
     if let Some(locations) = restorable_workspace_locations(cx, &app_state).await {
         for location in locations {
-            cx.update(|cx| {
-                workspace::open_paths(
-                    location.paths().as_ref(),
-                    app_state.clone(),
-                    workspace::OpenOptions::default(),
-                    cx,
-                )
-            })?
-            .await?;
+            match location {
+                SerializedWorkspaceLocation::Local(location, _) => {
+                    let task = cx.update(|cx| {
+                        workspace::open_paths(
+                            location.paths().as_ref(),
+                            app_state.clone(),
+                            workspace::OpenOptions::default(),
+                            cx,
+                        )
+                    })?;
+                    task.await?;
+                }
+                SerializedWorkspaceLocation::Ssh(ssh_project) => {
+                    let connection_options = SshConnectionOptions {
+                        host: ssh_project.host.clone(),
+                        username: ssh_project.user.clone(),
+                        port: ssh_project.port,
+                        password: None,
+                    };
+                    let app_state = app_state.clone();
+                    cx.spawn(move |mut cx| async move {
+                        recent_projects::open_ssh_project(
+                            connection_options,
+                            ssh_project.paths.into_iter().map(PathBuf::from).collect(),
+                            app_state,
+                            workspace::OpenOptions::default(),
+                            &mut cx,
+                        )
+                        .await
+                        .log_err();
+                    })
+                    .detach();
+                }
+                SerializedWorkspaceLocation::DevServer(_) => {}
+            }
         }
     } else if matches!(KEY_VALUE_STORE.read_kvp(FIRST_OPEN), Ok(None)) {
         cx.update(|cx| show_welcome_view(app_state, cx))?.await?;
@@ -895,7 +922,7 @@ async fn restore_or_create_workspace(
 pub(crate) async fn restorable_workspace_locations(
     cx: &mut AsyncAppContext,
     app_state: &Arc<AppState>,
-) -> Option<Vec<workspace::LocalPaths>> {
+) -> Option<Vec<SerializedWorkspaceLocation>> {
     let mut restore_behavior = cx
         .update(|cx| WorkspaceSettings::get(None, cx).restore_on_startup)
         .ok()?;
@@ -923,7 +950,7 @@ pub(crate) async fn restorable_workspace_locations(
 
     match restore_behavior {
         workspace::RestoreOnStartupBehavior::LastWorkspace => {
-            workspace::last_opened_workspace_paths()
+            workspace::last_opened_workspace_location()
                 .await
                 .map(|location| vec![location])
         }

crates/zed/src/zed/open_listener.rs 🔗

@@ -16,8 +16,9 @@ use futures::future::join_all;
 use futures::{FutureExt, SinkExt, StreamExt};
 use gpui::{AppContext, AsyncAppContext, Global, WindowHandle};
 use language::{Bias, Point};
+use recent_projects::open_ssh_project;
 use remote::SshConnectionOptions;
-use std::path::Path;
+use std::path::{Path, PathBuf};
 use std::sync::Arc;
 use std::time::Duration;
 use std::{process, thread};
@@ -25,7 +26,7 @@ use util::paths::PathWithPosition;
 use util::ResultExt;
 use welcome::{show_welcome_view, FIRST_OPEN};
 use workspace::item::ItemHandle;
-use workspace::{AppState, OpenOptions, Workspace};
+use workspace::{AppState, OpenOptions, SerializedWorkspaceLocation, Workspace};
 
 #[derive(Default, Debug)]
 pub struct OpenRequest {
@@ -356,33 +357,21 @@ async fn open_workspaces(
     env: Option<collections::HashMap<String, String>>,
     cx: &mut AsyncAppContext,
 ) -> Result<()> {
-    let grouped_paths = if paths.is_empty() {
+    let grouped_locations = if paths.is_empty() {
         // If no paths are provided, restore from previous workspaces unless a new workspace is requested with -n
         if open_new_workspace == Some(true) {
             Vec::new()
         } else {
             let locations = restorable_workspace_locations(cx, &app_state).await;
-            locations
-                .into_iter()
-                .flat_map(|locations| {
-                    locations
-                        .into_iter()
-                        .map(|location| {
-                            location
-                                .paths()
-                                .iter()
-                                .map(|path| path.to_string_lossy().to_string())
-                                .collect()
-                        })
-                        .collect::<Vec<_>>()
-                })
-                .collect()
+            locations.unwrap_or_default()
         }
     } else {
-        vec![paths]
+        vec![SerializedWorkspaceLocation::from_local_paths(
+            paths.into_iter().map(PathBuf::from),
+        )]
     };
 
-    if grouped_paths.is_empty() {
+    if grouped_locations.is_empty() {
         // If we have no paths to open, show the welcome screen if this is the first launch
         if matches!(KEY_VALUE_STORE.read_kvp(FIRST_OPEN), Ok(None)) {
             cx.update(|cx| show_welcome_view(app_state, cx).detach())
@@ -406,20 +395,48 @@ async fn open_workspaces(
         // If there are paths to open, open a workspace for each grouping of paths
         let mut errored = false;
 
-        for workspace_paths in grouped_paths {
-            let workspace_failed_to_open = open_workspace(
-                workspace_paths,
-                open_new_workspace,
-                wait,
-                responses,
-                env.as_ref(),
-                &app_state,
-                cx,
-            )
-            .await;
-
-            if workspace_failed_to_open {
-                errored = true
+        for location in grouped_locations {
+            match location {
+                SerializedWorkspaceLocation::Local(workspace_paths, _) => {
+                    let workspace_paths = workspace_paths
+                        .paths()
+                        .iter()
+                        .map(|path| path.to_string_lossy().to_string())
+                        .collect();
+
+                    let workspace_failed_to_open = open_local_workspace(
+                        workspace_paths,
+                        open_new_workspace,
+                        wait,
+                        responses,
+                        env.as_ref(),
+                        &app_state,
+                        cx,
+                    )
+                    .await;
+
+                    if workspace_failed_to_open {
+                        errored = true
+                    }
+                }
+                SerializedWorkspaceLocation::Ssh(ssh_project) => {
+                    let app_state = app_state.clone();
+                    cx.spawn(|mut cx| async move {
+                        open_ssh_project(
+                            ssh_project.connection_options(),
+                            ssh_project.paths.into_iter().map(PathBuf::from).collect(),
+                            app_state,
+                            OpenOptions::default(),
+                            &mut cx,
+                        )
+                        .await
+                        .log_err();
+                    })
+                    .detach();
+                    // We don't set `errored` here, because for ssh projects, the
+                    // error is displayed in the window.
+                }
+                SerializedWorkspaceLocation::DevServer(_) => {}
             }
         }
 
@@ -431,7 +448,7 @@ async fn open_workspaces(
     Ok(())
 }
 
-async fn open_workspace(
+async fn open_local_workspace(
     workspace_paths: Vec<String>,
     open_new_workspace: Option<bool>,
     wait: bool,
@@ -563,7 +580,7 @@ mod tests {
     use serde_json::json;
     use workspace::{AppState, Workspace};
 
-    use crate::zed::{open_listener::open_workspace, tests::init_test};
+    use crate::zed::{open_listener::open_local_workspace, tests::init_test};
 
     #[gpui::test]
     async fn test_open_workspace_with_directory(cx: &mut TestAppContext) {
@@ -678,7 +695,7 @@ mod tests {
 
         let errored = cx
             .spawn(|mut cx| async move {
-                open_workspace(
+                open_local_workspace(
                     workspace_paths,
                     open_new_workspace,
                     false,