@@ -9,13 +9,13 @@ use std::{
};
use anyhow::{Context as _, Result, bail};
+use collections::HashMap;
use db::{define_connection, query, sqlez::connection::Connection, sqlez_macros::sql};
use gpui::{Axis, Bounds, Task, WindowBounds, WindowId, point, size};
use project::debugger::breakpoint_store::{BreakpointState, SourceBreakpoint};
use language::{LanguageName, Toolchain};
use project::WorktreeId;
-use remote::ssh_session::SshProjectId;
use sqlez::{
bindable::{Bind, Column, StaticColumnCount},
statement::{SqlType, Statement},
@@ -33,7 +33,7 @@ use crate::{
use model::{
GroupId, ItemId, PaneId, SerializedItem, SerializedPane, SerializedPaneGroup,
- SerializedSshConnection, SerializedWorkspace,
+ SerializedSshConnection, SerializedWorkspace, SshConnectionId,
};
use self::model::{DockStructure, SerializedWorkspaceLocation};
@@ -615,7 +615,7 @@ impl WorkspaceDb {
pub(crate) fn ssh_workspace_for_roots<P: AsRef<Path>>(
&self,
worktree_roots: &[P],
- ssh_project_id: SshProjectId,
+ ssh_project_id: SshConnectionId,
) -> Option<SerializedWorkspace> {
self.workspace_for_roots_internal(worktree_roots, Some(ssh_project_id))
}
@@ -623,7 +623,7 @@ impl WorkspaceDb {
pub(crate) fn workspace_for_roots_internal<P: AsRef<Path>>(
&self,
worktree_roots: &[P],
- ssh_connection_id: Option<SshProjectId>,
+ ssh_connection_id: Option<SshConnectionId>,
) -> Option<SerializedWorkspace> {
// paths are sorted before db interactions to ensure that the order of the paths
// doesn't affect the workspace selection for existing workspaces
@@ -762,15 +762,21 @@ impl WorkspaceDb {
/// that used this workspace previously
pub(crate) async fn save_workspace(&self, workspace: SerializedWorkspace) {
let paths = workspace.paths.serialize();
- let ssh_connection_id = match &workspace.location {
- SerializedWorkspaceLocation::Local => None,
- SerializedWorkspaceLocation::Ssh(serialized_ssh_connection) => {
- Some(serialized_ssh_connection.id.0)
- }
- };
log::debug!("Saving workspace at location: {:?}", workspace.location);
self.write(move |conn| {
conn.with_savepoint("update_worktrees", || {
+ let ssh_connection_id = match &workspace.location {
+ SerializedWorkspaceLocation::Local => None,
+ SerializedWorkspaceLocation::Ssh(connection) => {
+ Some(Self::get_or_create_ssh_connection_query(
+ conn,
+ connection.host.clone(),
+ connection.port,
+ connection.user.clone(),
+ )?.0)
+ }
+ };
+
// Clear out panes and pane_groups
conn.exec_bound(sql!(
DELETE FROM pane_groups WHERE workspace_id = ?1;
@@ -893,39 +899,34 @@ impl WorkspaceDb {
host: String,
port: Option<u16>,
user: Option<String>,
- ) -> Result<SshProjectId> {
- if let Some(id) = self
- .get_ssh_connection(host.clone(), port, user.clone())
- .await?
+ ) -> Result<SshConnectionId> {
+ self.write(move |conn| Self::get_or_create_ssh_connection_query(conn, host, port, user))
+ .await
+ }
+
+ fn get_or_create_ssh_connection_query(
+ this: &Connection,
+ host: String,
+ port: Option<u16>,
+ user: Option<String>,
+ ) -> Result<SshConnectionId> {
+ if let Some(id) = this.select_row_bound(sql!(
+ SELECT id FROM ssh_connections WHERE host IS ? AND port IS ? AND user IS ? LIMIT 1
+ ))?((host.clone(), port, user.clone()))?
{
- Ok(SshProjectId(id))
+ Ok(SshConnectionId(id))
} else {
log::debug!("Inserting SSH project at host {host}");
- let id = self
- .insert_ssh_connection(host, port, user)
- .await?
- .context("failed to insert ssh project")?;
- Ok(SshProjectId(id))
- }
- }
-
- query! {
- async fn get_ssh_connection(host: String, port: Option<u16>, user: Option<String>) -> Result<Option<u64>> {
- SELECT id
- FROM ssh_connections
- WHERE host IS ? AND port IS ? AND user IS ?
- LIMIT 1
- }
- }
-
- query! {
- async fn insert_ssh_connection(host: String, port: Option<u16>, user: Option<String>) -> Result<Option<u64>> {
- INSERT INTO ssh_connections (
- host,
- port,
- user
- ) VALUES (?1, ?2, ?3)
- RETURNING id
+ let id = this.select_row_bound(sql!(
+ INSERT INTO ssh_connections (
+ host,
+ port,
+ user
+ ) VALUES (?1, ?2, ?3)
+ RETURNING id
+ ))?((host, port, user))?
+ .context("failed to insert ssh project")?;
+ Ok(SshConnectionId(id))
}
}
@@ -963,7 +964,7 @@ impl WorkspaceDb {
fn session_workspaces(
&self,
session_id: String,
- ) -> Result<Vec<(PathList, Option<u64>, Option<SshProjectId>)>> {
+ ) -> Result<Vec<(PathList, Option<u64>, Option<SshConnectionId>)>> {
Ok(self
.session_workspaces_query(session_id)?
.into_iter()
@@ -971,7 +972,7 @@ impl WorkspaceDb {
(
PathList::deserialize(&SerializedPathList { paths, order }),
window_id,
- ssh_connection_id.map(SshProjectId),
+ ssh_connection_id.map(SshConnectionId),
)
})
.collect())
@@ -1001,15 +1002,15 @@ impl WorkspaceDb {
}
}
- fn ssh_connections(&self) -> Result<Vec<SerializedSshConnection>> {
+ fn ssh_connections(&self) -> Result<HashMap<SshConnectionId, SerializedSshConnection>> {
Ok(self
.ssh_connections_query()?
.into_iter()
- .map(|(id, host, port, user)| SerializedSshConnection {
- id: SshProjectId(id),
- host,
- port,
- user,
+ .map(|(id, host, port, user)| {
+ (
+ SshConnectionId(id),
+ SerializedSshConnection { host, port, user },
+ )
})
.collect())
}
@@ -1021,19 +1022,18 @@ impl WorkspaceDb {
}
}
- pub fn ssh_connection(&self, id: SshProjectId) -> Result<SerializedSshConnection> {
+ pub(crate) fn ssh_connection(&self, id: SshConnectionId) -> Result<SerializedSshConnection> {
let row = self.ssh_connection_query(id.0)?;
Ok(SerializedSshConnection {
- id: SshProjectId(row.0),
- host: row.1,
- port: row.2,
- user: row.3,
+ host: row.0,
+ port: row.1,
+ user: row.2,
})
}
query! {
- fn ssh_connection_query(id: u64) -> Result<(u64, String, Option<u16>, Option<String>)> {
- SELECT id, host, port, user
+ fn ssh_connection_query(id: u64) -> Result<(String, Option<u16>, Option<String>)> {
+ SELECT host, port, user
FROM ssh_connections
WHERE id = ?
}
@@ -1075,10 +1075,8 @@ impl WorkspaceDb {
let ssh_connections = self.ssh_connections()?;
for (id, paths, ssh_connection_id) in self.recent_workspaces()? {
- if let Some(ssh_connection_id) = ssh_connection_id.map(SshProjectId) {
- if let Some(ssh_connection) =
- ssh_connections.iter().find(|rp| rp.id == ssh_connection_id)
- {
+ if let Some(ssh_connection_id) = ssh_connection_id.map(SshConnectionId) {
+ if let Some(ssh_connection) = ssh_connections.get(&ssh_connection_id) {
result.push((
id,
SerializedWorkspaceLocation::Ssh(ssh_connection.clone()),
@@ -2340,12 +2338,10 @@ mod tests {
]
.into_iter()
.map(|(host, user)| async {
- let id = db
- .get_or_create_ssh_connection(host.to_string(), None, Some(user.to_string()))
+ db.get_or_create_ssh_connection(host.to_string(), None, Some(user.to_string()))
.await
.unwrap();
SerializedSshConnection {
- id,
host: host.into(),
port: None,
user: Some(user.into()),
@@ -2501,26 +2497,34 @@ mod tests {
let stored_projects = db.ssh_connections().unwrap();
assert_eq!(
stored_projects,
- &[
- SerializedSshConnection {
- id: ids[0],
- host: "example.com".into(),
- port: None,
- user: None,
- },
- SerializedSshConnection {
- id: ids[1],
- host: "anotherexample.com".into(),
- port: Some(123),
- user: Some("user2".into()),
- },
- SerializedSshConnection {
- id: ids[2],
- host: "yetanother.com".into(),
- port: Some(345),
- user: None,
- },
+ [
+ (
+ ids[0],
+ SerializedSshConnection {
+ host: "example.com".into(),
+ port: None,
+ user: None,
+ }
+ ),
+ (
+ ids[1],
+ SerializedSshConnection {
+ host: "anotherexample.com".into(),
+ port: Some(123),
+ user: Some("user2".into()),
+ }
+ ),
+ (
+ ids[2],
+ SerializedSshConnection {
+ host: "yetanother.com".into(),
+ port: Some(345),
+ user: None,
+ }
+ ),
]
+ .into_iter()
+ .collect::<HashMap<_, _>>(),
);
}
@@ -74,10 +74,7 @@ use project::{
DirectoryLister, Project, ProjectEntryId, ProjectPath, ResolvedPath, Worktree, WorktreeId,
debugger::{breakpoint_store::BreakpointStoreEvent, session::ThreadStatus},
};
-use remote::{
- SshClientDelegate, SshConnectionOptions,
- ssh_session::{ConnectionIdentifier, SshProjectId},
-};
+use remote::{SshClientDelegate, SshConnectionOptions, ssh_session::ConnectionIdentifier};
use schemars::JsonSchema;
use serde::Deserialize;
use session::AppSession;
@@ -1128,7 +1125,6 @@ pub struct Workspace {
terminal_provider: Option<Box<dyn TerminalProvider>>,
debugger_provider: Option<Arc<dyn DebuggerProvider>>,
serializable_items_tx: UnboundedSender<Box<dyn SerializableItemHandle>>,
- serialized_ssh_connection_id: Option<SshProjectId>,
_items_serializer: Task<Result<()>>,
session_id: Option<String>,
scheduled_tasks: Vec<Task<()>>,
@@ -1461,7 +1457,7 @@ impl Workspace {
serializable_items_tx,
_items_serializer,
session_id: Some(session_id),
- serialized_ssh_connection_id: None,
+
scheduled_tasks: Vec::new(),
}
}
@@ -5288,11 +5284,9 @@ impl Workspace {
fn serialize_workspace_location(&self, cx: &App) -> WorkspaceLocation {
let paths = PathList::new(&self.root_paths(cx));
- let connection = self.project.read(cx).ssh_connection_options(cx);
- if let Some((id, connection)) = self.serialized_ssh_connection_id.zip(connection) {
+ if let Some(connection) = self.project.read(cx).ssh_connection_options(cx) {
WorkspaceLocation::Location(
SerializedWorkspaceLocation::Ssh(SerializedSshConnection {
- id,
host: connection.host,
port: connection.port,
user: connection.username,