Clean up handling of serialized ssh connection ids (#36781)

Max Brunsfeld created

Small follow-up to #36714

Release Notes:

- N/A

Change summary

crates/remote/src/ssh_session.rs          |   5 
crates/workspace/src/persistence.rs       | 166 ++++++++++++------------
crates/workspace/src/persistence/model.rs |   7 
crates/workspace/src/workspace.rs         |  12 -
4 files changed, 93 insertions(+), 97 deletions(-)

Detailed changes

crates/remote/src/ssh_session.rs 🔗

@@ -52,11 +52,6 @@ use util::{
     paths::{PathStyle, RemotePathBuf},
 };
 
-#[derive(
-    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
-)]
-pub struct SshProjectId(pub u64);
-
 #[derive(Clone)]
 pub struct SshSocket {
     connection_options: SshConnectionOptions,

crates/workspace/src/persistence.rs 🔗

@@ -9,13 +9,13 @@ use std::{
 };
 
 use anyhow::{Context as _, Result, bail};
+use collections::HashMap;
 use db::{define_connection, query, sqlez::connection::Connection, sqlez_macros::sql};
 use gpui::{Axis, Bounds, Task, WindowBounds, WindowId, point, size};
 use project::debugger::breakpoint_store::{BreakpointState, SourceBreakpoint};
 
 use language::{LanguageName, Toolchain};
 use project::WorktreeId;
-use remote::ssh_session::SshProjectId;
 use sqlez::{
     bindable::{Bind, Column, StaticColumnCount},
     statement::{SqlType, Statement},
@@ -33,7 +33,7 @@ use crate::{
 
 use model::{
     GroupId, ItemId, PaneId, SerializedItem, SerializedPane, SerializedPaneGroup,
-    SerializedSshConnection, SerializedWorkspace,
+    SerializedSshConnection, SerializedWorkspace, SshConnectionId,
 };
 
 use self::model::{DockStructure, SerializedWorkspaceLocation};
@@ -615,7 +615,7 @@ impl WorkspaceDb {
     pub(crate) fn ssh_workspace_for_roots<P: AsRef<Path>>(
         &self,
         worktree_roots: &[P],
-        ssh_project_id: SshProjectId,
+        ssh_project_id: SshConnectionId,
     ) -> Option<SerializedWorkspace> {
         self.workspace_for_roots_internal(worktree_roots, Some(ssh_project_id))
     }
@@ -623,7 +623,7 @@ impl WorkspaceDb {
     pub(crate) fn workspace_for_roots_internal<P: AsRef<Path>>(
         &self,
         worktree_roots: &[P],
-        ssh_connection_id: Option<SshProjectId>,
+        ssh_connection_id: Option<SshConnectionId>,
     ) -> Option<SerializedWorkspace> {
         // paths are sorted before db interactions to ensure that the order of the paths
         // doesn't affect the workspace selection for existing workspaces
@@ -762,15 +762,21 @@ impl WorkspaceDb {
     /// that used this workspace previously
     pub(crate) async fn save_workspace(&self, workspace: SerializedWorkspace) {
         let paths = workspace.paths.serialize();
-        let ssh_connection_id = match &workspace.location {
-            SerializedWorkspaceLocation::Local => None,
-            SerializedWorkspaceLocation::Ssh(serialized_ssh_connection) => {
-                Some(serialized_ssh_connection.id.0)
-            }
-        };
         log::debug!("Saving workspace at location: {:?}", workspace.location);
         self.write(move |conn| {
             conn.with_savepoint("update_worktrees", || {
+                let ssh_connection_id = match &workspace.location {
+                    SerializedWorkspaceLocation::Local => None,
+                    SerializedWorkspaceLocation::Ssh(connection) => {
+                        Some(Self::get_or_create_ssh_connection_query(
+                            conn,
+                            connection.host.clone(),
+                            connection.port,
+                            connection.user.clone(),
+                        )?.0)
+                    }
+                };
+
                 // Clear out panes and pane_groups
                 conn.exec_bound(sql!(
                     DELETE FROM pane_groups WHERE workspace_id = ?1;
@@ -893,39 +899,34 @@ impl WorkspaceDb {
         host: String,
         port: Option<u16>,
         user: Option<String>,
-    ) -> Result<SshProjectId> {
-        if let Some(id) = self
-            .get_ssh_connection(host.clone(), port, user.clone())
-            .await?
+    ) -> Result<SshConnectionId> {
+        self.write(move |conn| Self::get_or_create_ssh_connection_query(conn, host, port, user))
+            .await
+    }
+
+    fn get_or_create_ssh_connection_query(
+        this: &Connection,
+        host: String,
+        port: Option<u16>,
+        user: Option<String>,
+    ) -> Result<SshConnectionId> {
+        if let Some(id) = this.select_row_bound(sql!(
+            SELECT id FROM ssh_connections WHERE host IS ? AND port IS ? AND user IS ? LIMIT 1
+        ))?((host.clone(), port, user.clone()))?
         {
-            Ok(SshProjectId(id))
+            Ok(SshConnectionId(id))
         } else {
             log::debug!("Inserting SSH project at host {host}");
-            let id = self
-                .insert_ssh_connection(host, port, user)
-                .await?
-                .context("failed to insert ssh project")?;
-            Ok(SshProjectId(id))
-        }
-    }
-
-    query! {
-        async fn get_ssh_connection(host: String, port: Option<u16>, user: Option<String>) -> Result<Option<u64>> {
-            SELECT id
-            FROM ssh_connections
-            WHERE host IS ? AND port IS ? AND user IS ?
-            LIMIT 1
-        }
-    }
-
-    query! {
-        async fn insert_ssh_connection(host: String, port: Option<u16>, user: Option<String>) -> Result<Option<u64>> {
-            INSERT INTO ssh_connections (
-                host,
-                port,
-                user
-            ) VALUES (?1, ?2, ?3)
-            RETURNING id
+            let id = this.select_row_bound(sql!(
+                INSERT INTO ssh_connections (
+                    host,
+                    port,
+                    user
+                ) VALUES (?1, ?2, ?3)
+                RETURNING id
+            ))?((host, port, user))?
+            .context("failed to insert ssh project")?;
+            Ok(SshConnectionId(id))
         }
     }
 
@@ -963,7 +964,7 @@ impl WorkspaceDb {
     fn session_workspaces(
         &self,
         session_id: String,
-    ) -> Result<Vec<(PathList, Option<u64>, Option<SshProjectId>)>> {
+    ) -> Result<Vec<(PathList, Option<u64>, Option<SshConnectionId>)>> {
         Ok(self
             .session_workspaces_query(session_id)?
             .into_iter()
@@ -971,7 +972,7 @@ impl WorkspaceDb {
                 (
                     PathList::deserialize(&SerializedPathList { paths, order }),
                     window_id,
-                    ssh_connection_id.map(SshProjectId),
+                    ssh_connection_id.map(SshConnectionId),
                 )
             })
             .collect())
@@ -1001,15 +1002,15 @@ impl WorkspaceDb {
         }
     }
 
-    fn ssh_connections(&self) -> Result<Vec<SerializedSshConnection>> {
+    fn ssh_connections(&self) -> Result<HashMap<SshConnectionId, SerializedSshConnection>> {
         Ok(self
             .ssh_connections_query()?
             .into_iter()
-            .map(|(id, host, port, user)| SerializedSshConnection {
-                id: SshProjectId(id),
-                host,
-                port,
-                user,
+            .map(|(id, host, port, user)| {
+                (
+                    SshConnectionId(id),
+                    SerializedSshConnection { host, port, user },
+                )
             })
             .collect())
     }
@@ -1021,19 +1022,18 @@ impl WorkspaceDb {
         }
     }
 
-    pub fn ssh_connection(&self, id: SshProjectId) -> Result<SerializedSshConnection> {
+    pub(crate) fn ssh_connection(&self, id: SshConnectionId) -> Result<SerializedSshConnection> {
         let row = self.ssh_connection_query(id.0)?;
         Ok(SerializedSshConnection {
-            id: SshProjectId(row.0),
-            host: row.1,
-            port: row.2,
-            user: row.3,
+            host: row.0,
+            port: row.1,
+            user: row.2,
         })
     }
 
     query! {
-        fn ssh_connection_query(id: u64) -> Result<(u64, String, Option<u16>, Option<String>)> {
-            SELECT id, host, port, user
+        fn ssh_connection_query(id: u64) -> Result<(String, Option<u16>, Option<String>)> {
+            SELECT host, port, user
             FROM ssh_connections
             WHERE id = ?
         }
@@ -1075,10 +1075,8 @@ impl WorkspaceDb {
         let ssh_connections = self.ssh_connections()?;
 
         for (id, paths, ssh_connection_id) in self.recent_workspaces()? {
-            if let Some(ssh_connection_id) = ssh_connection_id.map(SshProjectId) {
-                if let Some(ssh_connection) =
-                    ssh_connections.iter().find(|rp| rp.id == ssh_connection_id)
-                {
+            if let Some(ssh_connection_id) = ssh_connection_id.map(SshConnectionId) {
+                if let Some(ssh_connection) = ssh_connections.get(&ssh_connection_id) {
                     result.push((
                         id,
                         SerializedWorkspaceLocation::Ssh(ssh_connection.clone()),
@@ -2340,12 +2338,10 @@ mod tests {
         ]
         .into_iter()
         .map(|(host, user)| async {
-            let id = db
-                .get_or_create_ssh_connection(host.to_string(), None, Some(user.to_string()))
+            db.get_or_create_ssh_connection(host.to_string(), None, Some(user.to_string()))
                 .await
                 .unwrap();
             SerializedSshConnection {
-                id,
                 host: host.into(),
                 port: None,
                 user: Some(user.into()),
@@ -2501,26 +2497,34 @@ mod tests {
         let stored_projects = db.ssh_connections().unwrap();
         assert_eq!(
             stored_projects,
-            &[
-                SerializedSshConnection {
-                    id: ids[0],
-                    host: "example.com".into(),
-                    port: None,
-                    user: None,
-                },
-                SerializedSshConnection {
-                    id: ids[1],
-                    host: "anotherexample.com".into(),
-                    port: Some(123),
-                    user: Some("user2".into()),
-                },
-                SerializedSshConnection {
-                    id: ids[2],
-                    host: "yetanother.com".into(),
-                    port: Some(345),
-                    user: None,
-                },
+            [
+                (
+                    ids[0],
+                    SerializedSshConnection {
+                        host: "example.com".into(),
+                        port: None,
+                        user: None,
+                    }
+                ),
+                (
+                    ids[1],
+                    SerializedSshConnection {
+                        host: "anotherexample.com".into(),
+                        port: Some(123),
+                        user: Some("user2".into()),
+                    }
+                ),
+                (
+                    ids[2],
+                    SerializedSshConnection {
+                        host: "yetanother.com".into(),
+                        port: Some(345),
+                        user: None,
+                    }
+                ),
             ]
+            .into_iter()
+            .collect::<HashMap<_, _>>(),
         );
     }
 

crates/workspace/src/persistence/model.rs 🔗

@@ -12,7 +12,6 @@ use db::sqlez::{
 use gpui::{AsyncWindowContext, Entity, WeakEntity};
 
 use project::{Project, debugger::breakpoint_store::SourceBreakpoint};
-use remote::ssh_session::SshProjectId;
 use serde::{Deserialize, Serialize};
 use std::{
     collections::BTreeMap,
@@ -22,9 +21,13 @@ use std::{
 use util::ResultExt;
 use uuid::Uuid;
 
+#[derive(
+    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
+)]
+pub(crate) struct SshConnectionId(pub u64);
+
 #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
 pub struct SerializedSshConnection {
-    pub id: SshProjectId,
     pub host: String,
     pub port: Option<u16>,
     pub user: Option<String>,

crates/workspace/src/workspace.rs 🔗

@@ -74,10 +74,7 @@ use project::{
     DirectoryLister, Project, ProjectEntryId, ProjectPath, ResolvedPath, Worktree, WorktreeId,
     debugger::{breakpoint_store::BreakpointStoreEvent, session::ThreadStatus},
 };
-use remote::{
-    SshClientDelegate, SshConnectionOptions,
-    ssh_session::{ConnectionIdentifier, SshProjectId},
-};
+use remote::{SshClientDelegate, SshConnectionOptions, ssh_session::ConnectionIdentifier};
 use schemars::JsonSchema;
 use serde::Deserialize;
 use session::AppSession;
@@ -1128,7 +1125,6 @@ pub struct Workspace {
     terminal_provider: Option<Box<dyn TerminalProvider>>,
     debugger_provider: Option<Arc<dyn DebuggerProvider>>,
     serializable_items_tx: UnboundedSender<Box<dyn SerializableItemHandle>>,
-    serialized_ssh_connection_id: Option<SshProjectId>,
     _items_serializer: Task<Result<()>>,
     session_id: Option<String>,
     scheduled_tasks: Vec<Task<()>>,
@@ -1461,7 +1457,7 @@ impl Workspace {
             serializable_items_tx,
             _items_serializer,
             session_id: Some(session_id),
-            serialized_ssh_connection_id: None,
+
             scheduled_tasks: Vec::new(),
         }
     }
@@ -5288,11 +5284,9 @@ impl Workspace {
 
     fn serialize_workspace_location(&self, cx: &App) -> WorkspaceLocation {
         let paths = PathList::new(&self.root_paths(cx));
-        let connection = self.project.read(cx).ssh_connection_options(cx);
-        if let Some((id, connection)) = self.serialized_ssh_connection_id.zip(connection) {
+        if let Some(connection) = self.project.read(cx).ssh_connection_options(cx) {
             WorkspaceLocation::Location(
                 SerializedWorkspaceLocation::Ssh(SerializedSshConnection {
-                    id,
                     host: connection.host,
                     port: connection.port,
                     user: connection.username,