diff --git a/Cargo.lock b/Cargo.lock index 7973316c9e337532685bfc546f22dae9d8a7bee1..483cf758880de306df5a599a61353769ced2f43c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -867,6 +867,9 @@ name = "cc" version = "1.0.67" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3c69b077ad434294d3ce9f1f6143a2a4b89a8a2d54ef813d85003a4fd1137fd" +dependencies = [ + "jobserver", +] [[package]] name = "cexpr" @@ -2614,6 +2617,15 @@ version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd25036021b0de88a0aff6b850051563c6516d0bf53f8638938edbb9de732736" +[[package]] +name = "jobserver" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af25a77299a7f711a01975c35a6a424eb6862092cc2d6c72c4ed6cbc56dfc1fa" +dependencies = [ + "libc", +] + [[package]] name = "jpeg-decoder" version = "0.1.22" @@ -6037,4 +6049,34 @@ dependencies = [ "serde 1.0.125", "smol", "tempdir", + "zstd", +] + +[[package]] +name = "zstd" +version = "0.9.0+zstd.1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07749a5dc2cb6b36661290245e350f15ec3bbb304e493db54a1d354480522ccd" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "4.1.1+zstd.1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c91c90f2c593b003603e5e0493c837088df4469da25aafff8bce42ba48caf079" +dependencies = [ + "libc", + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "1.6.1+zstd.1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "615120c7a2431d16cf1cf979e7fc31ba7a5b5e5707b29c8a99e5dbf8a8392a33" +dependencies = [ + "cc", + "libc", ] diff --git a/gpui/src/text_layout.rs b/gpui/src/text_layout.rs index c74949b9981a90ba0554b883dcdd6a254c21adaf..a7b976d72c746fe010ccd3a3ad0ebccdd8884917 100644 --- a/gpui/src/text_layout.rs +++ b/gpui/src/text_layout.rs @@ -669,8 +669,9 @@ mod tests { ); } - #[crate::test(self)] + #[crate::test(self, retries = 5)] fn test_wrap_shaped_line(cx: &mut crate::MutableAppContext) { + // This is failing intermittently on CI and we don't have time to figure it out let font_cache = cx.font_cache().clone(); let font_system = cx.platform().fonts(); let text_layout_cache = TextLayoutCache::new(font_system.clone()); diff --git a/server/src/auth.rs b/server/src/auth.rs index e60802285ec602a058fb0c74c2b4038603d6312a..4a06c642eb2554128b06e477c1e934a1fde83b36 100644 --- a/server/src/auth.rs +++ b/server/src/auth.rs @@ -18,7 +18,7 @@ use scrypt::{ use serde::{Deserialize, Serialize}; use std::{borrow::Cow, convert::TryFrom, sync::Arc}; use surf::{StatusCode, Url}; -use tide::{log, Server}; +use tide::{log, Error, Server}; use zrpc::auth as zed_auth; static CURRENT_GITHUB_USER: &'static str = "current_github_user"; @@ -33,51 +33,48 @@ pub struct User { pub is_admin: bool, } -pub struct VerifyToken; - -#[async_trait] -impl tide::Middleware> for VerifyToken { - async fn handle( - &self, - mut request: Request, - next: tide::Next<'_, Arc>, - ) -> tide::Result { - let mut auth_header = request - .header("Authorization") - .ok_or_else(|| anyhow!("no authorization header"))? - .last() - .as_str() - .split_whitespace(); - - let user_id = UserId( - auth_header - .next() - .ok_or_else(|| anyhow!("missing user id in authorization header"))? - .parse()?, - ); - let access_token = auth_header - .next() - .ok_or_else(|| anyhow!("missing access token in authorization header"))?; - - let state = request.state().clone(); - - let mut credentials_valid = false; - for password_hash in state.db.get_access_token_hashes(user_id).await? { - if verify_access_token(&access_token, &password_hash)? { - credentials_valid = true; - break; - } +pub async fn process_auth_header(request: &Request) -> tide::Result { + let mut auth_header = request + .header("Authorization") + .ok_or_else(|| { + Error::new( + StatusCode::BadRequest, + anyhow!("missing authorization header"), + ) + })? + .last() + .as_str() + .split_whitespace(); + let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| { + Error::new( + StatusCode::BadRequest, + anyhow!("missing user id in authorization header"), + ) + })?); + let access_token = auth_header.next().ok_or_else(|| { + Error::new( + StatusCode::BadRequest, + anyhow!("missing access token in authorization header"), + ) + })?; + + let state = request.state().clone(); + let mut credentials_valid = false; + for password_hash in state.db.get_access_token_hashes(user_id).await? { + if verify_access_token(&access_token, &password_hash)? { + credentials_valid = true; + break; } + } - if credentials_valid { - request.set_ext(user_id); - Ok(next.run(request).await) - } else { - let mut response = tide::Response::new(StatusCode::Unauthorized); - response.set_body("invalid credentials"); - Ok(response) - } + if !credentials_valid { + Err(Error::new( + StatusCode::Unauthorized, + anyhow!("invalid credentials"), + ))?; } + + Ok(user_id) } #[async_trait] @@ -263,11 +260,13 @@ async fn post_sign_out(mut request: Request) -> tide::Result { Ok(tide::Redirect::new("/").into()) } +const MAX_ACCESS_TOKENS_TO_STORE: usize = 8; + pub async fn create_access_token(db: &db::Db, user_id: UserId) -> tide::Result { let access_token = zed_auth::random_token(); let access_token_hash = hash_access_token(&access_token).context("failed to hash access token")?; - db.create_access_token_hash(user_id, access_token_hash) + db.create_access_token_hash(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE) .await?; Ok(access_token) } diff --git a/server/src/db.rs b/server/src/db.rs index 94bfc0013457b96038bbbee27d8018f81290fe7e..ebc861be03d5ddaaa6a85165812a60a9f4494ac5 100644 --- a/server/src/db.rs +++ b/server/src/db.rs @@ -175,25 +175,48 @@ impl Db { pub async fn create_access_token_hash( &self, user_id: UserId, - access_token_hash: String, + access_token_hash: &str, + max_access_token_count: usize, ) -> Result<()> { test_support!(self, { - let query = " - INSERT INTO access_tokens (user_id, hash) - VALUES ($1, $2) - "; - sqlx::query(query) + let insert_query = " + INSERT INTO access_tokens (user_id, hash) + VALUES ($1, $2); + "; + let cleanup_query = " + DELETE FROM access_tokens + WHERE id IN ( + SELECT id from access_tokens + WHERE user_id = $1 + ORDER BY id DESC + OFFSET $3 + ) + "; + + let mut tx = self.pool.begin().await?; + sqlx::query(insert_query) .bind(user_id.0) .bind(access_token_hash) - .execute(&self.pool) - .await - .map(drop) + .execute(&mut tx) + .await?; + sqlx::query(cleanup_query) + .bind(user_id.0) + .bind(access_token_hash) + .bind(max_access_token_count as u32) + .execute(&mut tx) + .await?; + tx.commit().await }) } pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result> { test_support!(self, { - let query = "SELECT hash FROM access_tokens WHERE user_id = $1"; + let query = " + SELECT hash + FROM access_tokens + WHERE user_id = $1 + ORDER BY id DESC + "; sqlx::query_scalar(query) .bind(user_id.0) .fetch_all(&self.pool) @@ -652,4 +675,36 @@ pub mod tests { assert_eq!(msg1_id, msg3_id); assert_eq!(msg2_id, msg4_id); } -} \ No newline at end of file + + #[gpui::test] + async fn test_create_access_tokens() { + let test_db = TestDb::new(); + let db = test_db.db(); + let user = db.create_user("the-user", false).await.unwrap(); + + db.create_access_token_hash(user, "h1", 3).await.unwrap(); + db.create_access_token_hash(user, "h2", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h2".to_string(), "h1".to_string()] + ); + + db.create_access_token_hash(user, "h3", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h3".to_string(), "h2".to_string(), "h1".to_string(),] + ); + + db.create_access_token_hash(user, "h4", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h4".to_string(), "h3".to_string(), "h2".to_string(),] + ); + + db.create_access_token_hash(user, "h5", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h5".to_string(), "h4".to_string(), "h3".to_string()] + ); + } +} diff --git a/server/src/rpc.rs b/server/src/rpc.rs index fec6182fcc1ff02cf696fed5cfcba32f41564af4..a9ffdad8997d8e2338223b0afdfdd05e3981fd1b 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -1,7 +1,7 @@ mod store; use super::{ - auth, + auth::process_auth_header, db::{ChannelId, MessageId, UserId}, AppState, }; @@ -885,8 +885,7 @@ where pub fn add_routes(app: &mut tide::Server>, rpc: &Arc) { let server = Server::new(app.state().clone(), rpc.clone(), None); - app.at("/rpc").with(auth::VerifyToken).get(move |request: Request>| { - let user_id = request.ext::().copied(); + app.at("/rpc").get(move |request: Request>| { let server = server.clone(); async move { const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; @@ -894,8 +893,11 @@ pub fn add_routes(app: &mut tide::Server>, rpc: &Arc) { let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade"); let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket"); let upgrade_requested = connection_upgrade && upgrade_to_websocket; + let client_protocol_version: Option = request + .header("X-Zed-Protocol-Version") + .and_then(|v| v.as_str().parse().ok()); - if !upgrade_requested { + if !upgrade_requested || client_protocol_version != Some(zrpc::PROTOCOL_VERSION) { return Ok(Response::new(StatusCode::UpgradeRequired)); } @@ -904,6 +906,8 @@ pub fn add_routes(app: &mut tide::Server>, rpc: &Arc) { None => return Err(anyhow!("expected sec-websocket-key"))?, }; + let user_id = process_auth_header(&request).await?; + let mut response = Response::new(StatusCode::SwitchingProtocols); response.insert_header(UPGRADE, "websocket"); response.insert_header(CONNECTION, "Upgrade"); @@ -914,10 +918,17 @@ pub fn add_routes(app: &mut tide::Server>, rpc: &Arc) { let http_res: &mut tide::http::Response = response.as_mut(); let upgrade_receiver = http_res.recv_upgrade().await; let addr = request.remote().unwrap_or("unknown").to_string(); - let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?; task::spawn(async move { if let Some(stream) = upgrade_receiver.await { - server.handle_connection(Connection::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await; + server + .handle_connection( + Connection::new( + WebSocketStream::from_raw_socket(stream, Role::Server, None).await, + ), + addr, + user_id, + ) + .await; } }); diff --git a/zed/assets/themes/_base.toml b/zed/assets/themes/_base.toml index 5938032c2c8c099f2840460f3f5a640dd64196d6..3e326ab62c6168673be6888778b3cd11317829e5 100644 --- a/zed/assets/themes/_base.toml +++ b/zed/assets/themes/_base.toml @@ -11,6 +11,7 @@ title = "$text.0" avatar_width = 20 icon_color = "$text.2.color" avatar = { corner_radius = 10, border = { width = 1, color = "#00000088" } } +outdated_warning = { extends = "$text.2", size = 13 } [workspace.titlebar.offline_icon] padding = { right = 4 } diff --git a/zed/src/editor.rs b/zed/src/editor.rs index a06b7ee013787591152894b2d0f911adf803f388..4d0ead1f1c54651081797262ce1188dc1858a05a 100644 --- a/zed/src/editor.rs +++ b/zed/src/editor.rs @@ -29,7 +29,6 @@ use smol::Timer; use std::{ cell::RefCell, cmp::{self, Ordering}, - iter::FromIterator, mem, ops::{Range, RangeInclusive}, path::Path, @@ -299,7 +298,7 @@ pub struct Editor { pending_selection: Option, next_selection_id: usize, add_selections_state: Option, - select_larger_syntax_node_stack: Vec>, + select_larger_syntax_node_stack: Vec>, scroll_position: Vector2F, scroll_top_anchor: Anchor, autoscroll_requested: bool, @@ -511,15 +510,14 @@ impl Editor { return false; } - let first_cursor_top = self - .selections(cx) + let selections = self.selections(cx); + let first_cursor_top = selections .first() .unwrap() .head() .to_display_point(&display_map, Bias::Left) .row() as f32; - let last_cursor_bottom = self - .selections(cx) + let last_cursor_bottom = selections .last() .unwrap() .head() @@ -561,12 +559,13 @@ impl Editor { scroll_width: f32, max_glyph_width: f32, layouts: &[text_layout::Line], - cx: &mut MutableAppContext, + cx: &mut ViewContext, ) -> bool { + let selections = self.selections(cx); let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); let mut target_left = std::f32::INFINITY; let mut target_right = 0.0_f32; - for selection in self.selections(cx) { + for selection in selections.iter() { let head = selection.head().to_display_point(&display_map, Bias::Left); let start_column = head.column().saturating_sub(3); let end_column = cmp::min(display_map.line_len(head.row()), head.column() + 3); @@ -655,12 +654,10 @@ impl Editor { fn end_selection(&mut self, cx: &mut ViewContext) { if let Some(selection) = self.pending_selection.take() { - let mut selections = self.selections(cx.as_ref()).to_vec(); + let mut selections = self.selections(cx).to_vec(); let ix = self.selection_insertion_index(&selections, &selection.start, cx.as_ref()); selections.insert(ix, selection); self.update_selections(selections, false, cx); - } else { - log::error!("end_selection dispatched with no pending selection"); } } @@ -669,12 +666,13 @@ impl Editor { } pub fn cancel(&mut self, _: &Cancel, cx: &mut ViewContext) { - let selections = self.selections(cx.as_ref()); if let Some(pending_selection) = self.pending_selection.take() { + let selections = self.selections(cx); if selections.is_empty() { self.update_selections(vec![pending_selection], true, cx); } } else { + let selections = self.selections(cx); let mut oldest_selection = selections.iter().min_by_key(|s| s.id).unwrap().clone(); if selections.len() == 1 { oldest_selection.start = oldest_selection.head().clone(); @@ -743,8 +741,9 @@ impl Editor { pub fn insert(&mut self, action: &Insert, cx: &mut ViewContext) { let mut old_selections = SmallVec::<[_; 32]>::new(); { + let selections = self.selections(cx); let buffer = self.buffer.read(cx); - for selection in self.selections(cx.as_ref()) { + for selection in selections.iter() { let start = selection.start.to_offset(buffer); let end = selection.end.to_offset(buffer); old_selections.push((selection.id, start..end)); @@ -790,7 +789,7 @@ impl Editor { pub fn backspace(&mut self, _: &Backspace, cx: &mut ViewContext) { self.start_transaction(cx); - let mut selections = self.selections(cx.as_ref()).to_vec(); + let mut selections = self.selections(cx).to_vec(); let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); { let buffer = self.buffer.read(cx); @@ -814,7 +813,7 @@ impl Editor { pub fn delete(&mut self, _: &Delete, cx: &mut ViewContext) { self.start_transaction(cx); let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); - let mut selections = self.selections(cx.as_ref()).to_vec(); + let mut selections = self.selections(cx).to_vec(); { let buffer = self.buffer.read(cx); for selection in &mut selections { @@ -837,14 +836,14 @@ impl Editor { pub fn delete_line(&mut self, _: &DeleteLine, cx: &mut ViewContext) { self.start_transaction(cx); + let selections = self.selections(cx); let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); - let app = cx.as_ref(); - let buffer = self.buffer.read(app); + let buffer = self.buffer.read(cx); let mut new_cursors = Vec::new(); let mut edit_ranges = Vec::new(); - let mut selections = self.selections(app).iter().peekable(); + let mut selections = selections.iter().peekable(); while let Some(selection) = selections.next() { let mut rows = selection.spanned_rows(false, &display_map).buffer_rows; let goal_display_column = selection @@ -914,7 +913,7 @@ impl Editor { pub fn duplicate_line(&mut self, _: &DuplicateLine, cx: &mut ViewContext) { self.start_transaction(cx); - let mut selections = self.selections(cx.as_ref()).to_vec(); + let mut selections = self.selections(cx).to_vec(); { // Temporarily bias selections right to allow newly duplicate lines to push them down // when the selections are at the beginning of a line. @@ -974,8 +973,8 @@ impl Editor { pub fn move_line_up(&mut self, _: &MoveLineUp, cx: &mut ViewContext) { self.start_transaction(cx); + let selections = self.selections(cx); let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); - let app = cx.as_ref(); let buffer = self.buffer.read(cx); let mut edits = Vec::new(); @@ -983,7 +982,7 @@ impl Editor { let mut old_folds = Vec::new(); let mut new_folds = Vec::new(); - let mut selections = self.selections(app).iter().peekable(); + let mut selections = selections.iter().peekable(); let mut contiguous_selections = Vec::new(); while let Some(selection) = selections.next() { // Accumulate contiguous regions of rows that we want to move. @@ -1064,8 +1063,8 @@ impl Editor { pub fn move_line_down(&mut self, _: &MoveLineDown, cx: &mut ViewContext) { self.start_transaction(cx); + let selections = self.selections(cx); let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); - let app = cx.as_ref(); let buffer = self.buffer.read(cx); let mut edits = Vec::new(); @@ -1073,7 +1072,7 @@ impl Editor { let mut old_folds = Vec::new(); let mut new_folds = Vec::new(); - let mut selections = self.selections(app).iter().peekable(); + let mut selections = selections.iter().peekable(); let mut contiguous_selections = Vec::new(); while let Some(selection) = selections.next() { // Accumulate contiguous regions of rows that we want to move. @@ -1151,7 +1150,7 @@ impl Editor { pub fn cut(&mut self, _: &Cut, cx: &mut ViewContext) { self.start_transaction(cx); let mut text = String::new(); - let mut selections = self.selections(cx.as_ref()).to_vec(); + let mut selections = self.selections(cx).to_vec(); let mut clipboard_selections = Vec::with_capacity(selections.len()); { let buffer = self.buffer.read(cx); @@ -1186,12 +1185,12 @@ impl Editor { } pub fn copy(&mut self, _: &Copy, cx: &mut ViewContext) { + let selections = self.selections(cx); let buffer = self.buffer.read(cx); let max_point = buffer.max_point(); let mut text = String::new(); - let selections = self.selections(cx.as_ref()); let mut clipboard_selections = Vec::with_capacity(selections.len()); - for selection in selections { + for selection in selections.iter() { let mut start = selection.start.to_point(buffer); let mut end = selection.end.to_point(buffer); let is_entire_line = start == end; @@ -1218,24 +1217,28 @@ impl Editor { if let Some(item) = cx.as_mut().read_from_clipboard() { let clipboard_text = item.text(); if let Some(mut clipboard_selections) = item.metadata::>() { - let selections = self.selections(cx.as_ref()).to_vec(); + let selections = self.selections(cx); + let all_selections_were_entire_line = + clipboard_selections.iter().all(|s| s.is_entire_line); if clipboard_selections.len() != selections.len() { - let merged_selection = ClipboardSelection { - len: clipboard_selections.iter().map(|s| s.len).sum(), - is_entire_line: clipboard_selections.iter().all(|s| s.is_entire_line), - }; clipboard_selections.clear(); - clipboard_selections.push(merged_selection); } self.start_transaction(cx); + let mut start_offset = 0; let mut new_selections = Vec::with_capacity(selections.len()); - let mut clipboard_chars = clipboard_text.chars().cycle(); - for (selection, clipboard_selection) in - selections.iter().zip(clipboard_selections.iter().cycle()) - { - let to_insert = - String::from_iter(clipboard_chars.by_ref().take(clipboard_selection.len)); + for (i, selection) in selections.iter().enumerate() { + let to_insert; + let entire_line; + if let Some(clipboard_selection) = clipboard_selections.get(i) { + let end_offset = start_offset + clipboard_selection.len; + to_insert = &clipboard_text[start_offset..end_offset]; + entire_line = clipboard_selection.is_entire_line; + start_offset = end_offset + } else { + to_insert = clipboard_text.as_str(); + entire_line = all_selections_were_entire_line; + } self.buffer.update(cx, |buffer, cx| { let selection_start = selection.start.to_point(&*buffer); @@ -1246,7 +1249,7 @@ impl Editor { // selection was copied. If this selection is also currently empty, // then paste the line before the current line of the buffer. let new_selection_start = selection.end.bias_right(buffer); - if selection_start == selection_end && clipboard_selection.is_entire_line { + if selection_start == selection_end && entire_line { let line_start = Point::new(selection_start.row, 0); buffer.edit(Some(line_start..line_start), to_insert, cx); } else { @@ -1281,8 +1284,7 @@ impl Editor { pub fn move_left(&mut self, _: &MoveLeft, cx: &mut ViewContext) { let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); - let app = cx.as_ref(); - let mut selections = self.selections(app).to_vec(); + let mut selections = self.selections(cx).to_vec(); { for selection in &mut selections { let start = selection.start.to_display_point(&display_map, Bias::Left); @@ -1305,7 +1307,7 @@ impl Editor { pub fn select_left(&mut self, _: &SelectLeft, cx: &mut ViewContext) { let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); - let mut selections = self.selections(cx.as_ref()).to_vec(); + let mut selections = self.selections(cx).to_vec(); { let buffer = self.buffer.read(cx); for selection in &mut selections { @@ -1321,7 +1323,7 @@ impl Editor { pub fn move_right(&mut self, _: &MoveRight, cx: &mut ViewContext) { let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); - let mut selections = self.selections(cx.as_ref()).to_vec(); + let mut selections = self.selections(cx).to_vec(); { for selection in &mut selections { let start = selection.start.to_display_point(&display_map, Bias::Left); @@ -1344,7 +1346,7 @@ impl Editor { pub fn select_right(&mut self, _: &SelectRight, cx: &mut ViewContext) { let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); - let mut selections = self.selections(cx.as_ref()).to_vec(); + let mut selections = self.selections(cx).to_vec(); { let app = cx.as_ref(); let buffer = self.buffer.read(app); @@ -1364,7 +1366,7 @@ impl Editor { if matches!(self.mode, EditorMode::SingleLine) { cx.propagate_action(); } else { - let mut selections = self.selections(cx.as_ref()).to_vec(); + let mut selections = self.selections(cx).to_vec(); { for selection in &mut selections { let start = selection.start.to_display_point(&display_map, Bias::Left); @@ -1387,7 +1389,7 @@ impl Editor { pub fn select_up(&mut self, _: &SelectUp, cx: &mut ViewContext) { let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); - let mut selections = self.selections(cx.as_ref()).to_vec(); + let mut selections = self.selections(cx).to_vec(); { let app = cx.as_ref(); let buffer = self.buffer.read(app); @@ -1406,7 +1408,7 @@ impl Editor { cx.propagate_action(); } else { let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); - let mut selections = self.selections(cx.as_ref()).to_vec(); + let mut selections = self.selections(cx).to_vec(); { for selection in &mut selections { let start = selection.start.to_display_point(&display_map, Bias::Left); @@ -1490,8 +1492,26 @@ impl Editor { cx: &mut ViewContext, ) { self.start_transaction(cx); - self.select_to_previous_word_boundary(&SelectToPreviousWordBoundary, cx); - self.backspace(&Backspace, cx); + let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); + let mut selections = self.selections(cx).to_vec(); + { + let buffer = self.buffer.read(cx); + for selection in &mut selections { + let range = selection.point_range(buffer); + if range.start == range.end { + let head = selection.head().to_display_point(&display_map, Bias::Left); + let cursor = display_map.anchor_before( + movement::prev_word_boundary(&display_map, head).unwrap(), + Bias::Right, + ); + selection.set_head(&buffer, cursor); + selection.goal = SelectionGoal::None; + } + } + } + + self.update_selections(selections, true, cx); + self.insert(&Insert(String::new()), cx); self.end_transaction(cx); } @@ -1542,8 +1562,26 @@ impl Editor { cx: &mut ViewContext, ) { self.start_transaction(cx); - self.select_to_next_word_boundary(&SelectToNextWordBoundary, cx); - self.delete(&Delete, cx); + let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); + let mut selections = self.selections(cx).to_vec(); + { + let buffer = self.buffer.read(cx); + for selection in &mut selections { + let range = selection.point_range(buffer); + if range.start == range.end { + let head = selection.head().to_display_point(&display_map, Bias::Left); + let cursor = display_map.anchor_before( + movement::next_word_boundary(&display_map, head).unwrap(), + Bias::Right, + ); + selection.set_head(&buffer, cursor); + selection.goal = SelectionGoal::None; + } + } + } + + self.update_selections(selections, true, cx); + self.insert(&Insert(String::new()), cx); self.end_transaction(cx); } @@ -1661,7 +1699,7 @@ impl Editor { } pub fn select_to_beginning(&mut self, _: &SelectToBeginning, cx: &mut ViewContext) { - let mut selection = self.selections(cx.as_ref()).last().unwrap().clone(); + let mut selection = self.selections(cx).last().unwrap().clone(); selection.set_head(self.buffer.read(cx), Anchor::min()); self.update_selections(vec![selection], true, cx); } @@ -1680,7 +1718,7 @@ impl Editor { } pub fn select_to_end(&mut self, _: &SelectToEnd, cx: &mut ViewContext) { - let mut selection = self.selections(cx.as_ref()).last().unwrap().clone(); + let mut selection = self.selections(cx).last().unwrap().clone(); selection.set_head(self.buffer.read(cx), Anchor::max()); self.update_selections(vec![selection], true, cx); } @@ -1698,8 +1736,8 @@ impl Editor { pub fn select_line(&mut self, _: &SelectLine, cx: &mut ViewContext) { let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); - let buffer = self.buffer.read(cx); let mut selections = self.selections(cx).to_vec(); + let buffer = self.buffer.read(cx); let max_point = buffer.max_point(); for selection in &mut selections { let rows = selection.spanned_rows(true, &display_map).buffer_rows; @@ -1715,12 +1753,12 @@ impl Editor { _: &SplitSelectionIntoLines, cx: &mut ViewContext, ) { - let app = cx.as_ref(); - let buffer = self.buffer.read(app); + let selections = self.selections(cx); + let buffer = self.buffer.read(cx); let mut to_unfold = Vec::new(); let mut new_selections = Vec::new(); - for selection in self.selections(app) { + for selection in selections.iter() { let range = selection.point_range(buffer).sorted(); if range.start.row != range.end.row { new_selections.push(Selection { @@ -1860,14 +1898,14 @@ impl Editor { _: &SelectLargerSyntaxNode, cx: &mut ViewContext, ) { + let old_selections = self.selections(cx); let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); let buffer = self.buffer.read(cx); let mut stack = mem::take(&mut self.select_larger_syntax_node_stack); let mut selected_larger_node = false; - let old_selections = self.selections(cx).to_vec(); let mut new_selection_ranges = Vec::new(); - for selection in &old_selections { + for selection in old_selections.iter() { let old_range = selection.start.to_offset(buffer)..selection.end.to_offset(buffer); let mut new_range = old_range.clone(); while let Some(containing_range) = buffer.range_for_syntax_ancestor(new_range.clone()) { @@ -1908,7 +1946,7 @@ impl Editor { ) { let mut stack = mem::take(&mut self.select_larger_syntax_node_stack); if let Some(selections) = stack.pop() { - self.update_selections(selections, true, cx); + self.update_selections(selections.to_vec(), true, cx); } self.select_larger_syntax_node_stack = stack; } @@ -1918,8 +1956,8 @@ impl Editor { _: &MoveToEnclosingBracket, cx: &mut ViewContext, ) { + let mut selections = self.selections(cx).to_vec(); let buffer = self.buffer.read(cx.as_ref()); - let mut selections = self.selections(cx.as_ref()).to_vec(); for selection in &mut selections { let selection_range = selection.offset_range(buffer); if let Some((open_range, close_range)) = @@ -2033,12 +2071,14 @@ impl Editor { } } - fn selections<'a>(&self, cx: &'a AppContext) -> &'a [Selection] { + fn selections(&mut self, cx: &mut ViewContext) -> Arc<[Selection]> { + self.end_selection(cx); let buffer = self.buffer.read(cx); - &buffer + buffer .selection_set(self.selection_set_id) .unwrap() .selections + .clone() } fn update_selections( @@ -2112,8 +2152,9 @@ impl Editor { pub fn fold(&mut self, _: &Fold, cx: &mut ViewContext) { let mut fold_ranges = Vec::new(); + let selections = self.selections(cx); let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); - for selection in self.selections(cx) { + for selection in selections.iter() { let range = selection.display_range(&display_map).sorted(); let buffer_start_row = range.start.to_buffer_point(&display_map, Bias::Left).row; @@ -2134,10 +2175,10 @@ impl Editor { } pub fn unfold(&mut self, _: &Unfold, cx: &mut ViewContext) { + let selections = self.selections(cx); let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); let buffer = self.buffer.read(cx); - let ranges = self - .selections(cx) + let ranges = selections .iter() .map(|s| { let range = s.display_range(&display_map).sorted(); @@ -2195,9 +2236,9 @@ impl Editor { } pub fn fold_selected_ranges(&mut self, _: &FoldSelectedRanges, cx: &mut ViewContext) { + let selections = self.selections(cx); let buffer = self.buffer.read(cx); - let ranges = self - .selections(cx.as_ref()) + let ranges = selections .iter() .map(|s| s.point_range(buffer).sorted()) .collect(); @@ -3223,24 +3264,13 @@ mod tests { ); }); - view.update(cx, |view, cx| { - view.move_to_previous_word_boundary(&MoveToPreviousWordBoundary, cx); - assert_eq!( - view.selection_ranges(cx), - &[ - DisplayPoint::new(0, 3)..DisplayPoint::new(0, 3), - DisplayPoint::new(1, 0)..DisplayPoint::new(1, 0), - ] - ); - }); - view.update(cx, |view, cx| { view.move_to_previous_word_boundary(&MoveToPreviousWordBoundary, cx); assert_eq!( view.selection_ranges(cx), &[ DisplayPoint::new(0, 0)..DisplayPoint::new(0, 0), - DisplayPoint::new(0, 24)..DisplayPoint::new(0, 24), + DisplayPoint::new(1, 0)..DisplayPoint::new(1, 0), ] ); }); @@ -3267,24 +3297,13 @@ mod tests { ); }); - view.update(cx, |view, cx| { - view.move_to_next_word_boundary(&MoveToNextWordBoundary, cx); - assert_eq!( - view.selection_ranges(cx), - &[ - DisplayPoint::new(0, 4)..DisplayPoint::new(0, 4), - DisplayPoint::new(1, 0)..DisplayPoint::new(1, 0), - ] - ); - }); - view.update(cx, |view, cx| { view.move_to_next_word_boundary(&MoveToNextWordBoundary, cx); assert_eq!( view.selection_ranges(cx), &[ DisplayPoint::new(0, 7)..DisplayPoint::new(0, 7), - DisplayPoint::new(2, 0)..DisplayPoint::new(2, 0), + DisplayPoint::new(1, 0)..DisplayPoint::new(1, 0), ] ); }); @@ -3295,7 +3314,7 @@ mod tests { view.selection_ranges(cx), &[ DisplayPoint::new(0, 9)..DisplayPoint::new(0, 9), - DisplayPoint::new(2, 2)..DisplayPoint::new(2, 2), + DisplayPoint::new(2, 3)..DisplayPoint::new(2, 3), ] ); }); @@ -3307,7 +3326,7 @@ mod tests { view.selection_ranges(cx), &[ DisplayPoint::new(0, 10)..DisplayPoint::new(0, 9), - DisplayPoint::new(2, 3)..DisplayPoint::new(2, 2), + DisplayPoint::new(2, 4)..DisplayPoint::new(2, 3), ] ); }); @@ -3318,7 +3337,7 @@ mod tests { view.selection_ranges(cx), &[ DisplayPoint::new(0, 10)..DisplayPoint::new(0, 7), - DisplayPoint::new(2, 3)..DisplayPoint::new(2, 0), + DisplayPoint::new(2, 4)..DisplayPoint::new(2, 2), ] ); }); @@ -3329,37 +3348,7 @@ mod tests { view.selection_ranges(cx), &[ DisplayPoint::new(0, 10)..DisplayPoint::new(0, 9), - DisplayPoint::new(2, 3)..DisplayPoint::new(2, 2), - ] - ); - }); - - view.update(cx, |view, cx| { - view.delete_to_next_word_boundary(&DeleteToNextWordBoundary, cx); - assert_eq!( - view.display_text(cx), - "use std::s::{foo, bar}\n\n {az.qux()}" - ); - assert_eq!( - view.selection_ranges(cx), - &[ - DisplayPoint::new(0, 10)..DisplayPoint::new(0, 10), - DisplayPoint::new(2, 3)..DisplayPoint::new(2, 3), - ] - ); - }); - - view.update(cx, |view, cx| { - view.delete_to_previous_word_boundary(&DeleteToPreviousWordBoundary, cx); - assert_eq!( - view.display_text(cx), - "use std::::{foo, bar}\n\n az.qux()}" - ); - assert_eq!( - view.selection_ranges(cx), - &[ - DisplayPoint::new(0, 9)..DisplayPoint::new(0, 9), - DisplayPoint::new(2, 2)..DisplayPoint::new(2, 2), + DisplayPoint::new(2, 4)..DisplayPoint::new(2, 3), ] ); }); @@ -3415,11 +3404,52 @@ mod tests { view.move_to_previous_word_boundary(&MoveToPreviousWordBoundary, cx); assert_eq!( view.selection_ranges(cx), - &[DisplayPoint::new(1, 15)..DisplayPoint::new(1, 15)] + &[DisplayPoint::new(1, 14)..DisplayPoint::new(1, 14)] ); }); } + #[gpui::test] + fn test_delete_to_word_boundary(cx: &mut gpui::MutableAppContext) { + let buffer = cx.add_model(|cx| Buffer::new(0, "one two three four", cx)); + let settings = settings::test(&cx).1; + let (_, view) = cx.add_window(Default::default(), |cx| { + build_editor(buffer.clone(), settings, cx) + }); + + view.update(cx, |view, cx| { + view.select_display_ranges( + &[ + // an empty selection - the preceding word fragment is deleted + DisplayPoint::new(0, 2)..DisplayPoint::new(0, 2), + // characters selected - they are deleted + DisplayPoint::new(0, 9)..DisplayPoint::new(0, 12), + ], + cx, + ) + .unwrap(); + view.delete_to_previous_word_boundary(&DeleteToPreviousWordBoundary, cx); + }); + + assert_eq!(buffer.read(cx).text(), "e two te four"); + + view.update(cx, |view, cx| { + view.select_display_ranges( + &[ + // an empty selection - the following word fragment is deleted + DisplayPoint::new(0, 3)..DisplayPoint::new(0, 3), + // characters selected - they are deleted + DisplayPoint::new(0, 9)..DisplayPoint::new(0, 10), + ], + cx, + ) + .unwrap(); + view.delete_to_next_word_boundary(&DeleteToNextWordBoundary, cx); + }); + + assert_eq!(buffer.read(cx).text(), "e t te our"); + } + #[gpui::test] fn test_backspace(cx: &mut gpui::MutableAppContext) { let buffer = cx.add_model(|cx| { @@ -3685,7 +3715,7 @@ mod tests { #[gpui::test] fn test_clipboard(cx: &mut gpui::MutableAppContext) { - let buffer = cx.add_model(|cx| Buffer::new(0, "one two three four five six ", cx)); + let buffer = cx.add_model(|cx| Buffer::new(0, "one✅ two three four five six ", cx)); let settings = settings::test(&cx).1; let view = cx .add_window(Default::default(), |cx| { @@ -3695,7 +3725,7 @@ mod tests { // Cut with three selections. Clipboard text is divided into three slices. view.update(cx, |view, cx| { - view.select_ranges(vec![0..4, 8..14, 19..24], false, cx); + view.select_ranges(vec![0..7, 11..17, 22..27], false, cx); view.cut(&Cut, cx); assert_eq!(view.display_text(cx), "two four six "); }); @@ -3704,13 +3734,13 @@ mod tests { view.update(cx, |view, cx| { view.select_ranges(vec![4..4, 9..9, 13..13], false, cx); view.paste(&Paste, cx); - assert_eq!(view.display_text(cx), "two one four three six five "); + assert_eq!(view.display_text(cx), "two one✅ four three six five "); assert_eq!( view.selection_ranges(cx), &[ - DisplayPoint::new(0, 8)..DisplayPoint::new(0, 8), - DisplayPoint::new(0, 19)..DisplayPoint::new(0, 19), - DisplayPoint::new(0, 28)..DisplayPoint::new(0, 28) + DisplayPoint::new(0, 11)..DisplayPoint::new(0, 11), + DisplayPoint::new(0, 22)..DisplayPoint::new(0, 22), + DisplayPoint::new(0, 31)..DisplayPoint::new(0, 31) ] ); }); @@ -3719,13 +3749,13 @@ mod tests { // match the number of slices in the clipboard, the entire clipboard text // is pasted at each cursor. view.update(cx, |view, cx| { - view.select_ranges(vec![0..0, 28..28], false, cx); + view.select_ranges(vec![0..0, 31..31], false, cx); view.insert(&Insert("( ".into()), cx); view.paste(&Paste, cx); view.insert(&Insert(") ".into()), cx); assert_eq!( view.display_text(cx), - "( one three five ) two one four three six five ( one three five ) " + "( one✅ three five ) two one✅ four three six five ( one✅ three five ) " ); }); @@ -3734,7 +3764,7 @@ mod tests { view.insert(&Insert("123\n4567\n89\n".into()), cx); assert_eq!( view.display_text(cx), - "123\n4567\n89\n( one three five ) two one four three six five ( one three five ) " + "123\n4567\n89\n( one✅ three five ) two one✅ four three six five ( one✅ three five ) " ); }); @@ -3752,7 +3782,7 @@ mod tests { view.cut(&Cut, cx); assert_eq!( view.display_text(cx), - "13\n9\n( one three five ) two one four three six five ( one three five ) " + "13\n9\n( one✅ three five ) two one✅ four three six five ( one✅ three five ) " ); }); @@ -3771,7 +3801,7 @@ mod tests { view.paste(&Paste, cx); assert_eq!( view.display_text(cx), - "123\n4567\n9\n( 8ne three five ) two one four three six five ( one three five ) " + "123\n4567\n9\n( 8ne✅ three five ) two one✅ four three six five ( one✅ three five ) " ); assert_eq!( view.selection_ranges(cx), @@ -3805,7 +3835,7 @@ mod tests { view.paste(&Paste, cx); assert_eq!( view.display_text(cx), - "123\n123\n123\n67\n123\n9\n( 8ne three five ) two one four three six five ( one three five ) " + "123\n123\n123\n67\n123\n9\n( 8ne✅ three five ) two one✅ four three six five ( one✅ three five ) " ); assert_eq!( view.selection_ranges(cx), diff --git a/zed/src/editor/movement.rs b/zed/src/editor/movement.rs index 8f5bc6f20a814536366f9077f1509d611dc1cf29..d86aa9ca53b9332ca552083c1bce889e8605963a 100644 --- a/zed/src/editor/movement.rs +++ b/zed/src/editor/movement.rs @@ -101,7 +101,10 @@ pub fn line_end(map: &DisplayMapSnapshot, point: DisplayPoint) -> Result Result { +pub fn prev_word_boundary( + map: &DisplayMapSnapshot, + mut point: DisplayPoint, +) -> Result { let mut line_start = 0; if point.row() > 0 { if let Some(indent) = map.soft_wrap_indent(point.row() - 1) { @@ -111,39 +114,52 @@ pub fn prev_word_boundary(map: &DisplayMapSnapshot, point: DisplayPoint) -> Resu if point.column() == line_start { if point.row() == 0 { - Ok(DisplayPoint::new(0, 0)) + return Ok(DisplayPoint::new(0, 0)); } else { let row = point.row() - 1; - Ok(map.clip_point(DisplayPoint::new(row, map.line_len(row)), Bias::Left)) + point = map.clip_point(DisplayPoint::new(row, map.line_len(row)), Bias::Left); } - } else { - let mut boundary = DisplayPoint::new(point.row(), 0); - let mut column = 0; - let mut prev_c = None; - for c in map.chars_at(DisplayPoint::new(point.row(), 0)) { - if column >= point.column() { - break; - } + } - if prev_c.is_none() || char_kind(prev_c.unwrap()) != char_kind(c) { - *boundary.column_mut() = column; - } + let mut boundary = DisplayPoint::new(point.row(), 0); + let mut column = 0; + let mut prev_char_kind = CharKind::Newline; + for c in map.chars_at(DisplayPoint::new(point.row(), 0)) { + if column >= point.column() { + break; + } - prev_c = Some(c); - column += c.len_utf8() as u32; + let char_kind = char_kind(c); + if char_kind != prev_char_kind + && char_kind != CharKind::Whitespace + && char_kind != CharKind::Newline + { + *boundary.column_mut() = column; } - Ok(boundary) + + prev_char_kind = char_kind; + column += c.len_utf8() as u32; } + Ok(boundary) } pub fn next_word_boundary( map: &DisplayMapSnapshot, mut point: DisplayPoint, ) -> Result { - let mut prev_c = None; + let mut prev_char_kind = None; for c in map.chars_at(point) { - if prev_c.is_some() && (c == '\n' || char_kind(prev_c.unwrap()) != char_kind(c)) { - break; + let char_kind = char_kind(c); + if let Some(prev_char_kind) = prev_char_kind { + if c == '\n' { + break; + } + if prev_char_kind != char_kind + && prev_char_kind != CharKind::Whitespace + && prev_char_kind != CharKind::Newline + { + break; + } } if c == '\n' { @@ -152,7 +168,7 @@ pub fn next_word_boundary( } else { *point.column_mut() += c.len_utf8() as u32; } - prev_c = Some(c); + prev_char_kind = Some(char_kind); } Ok(point) } @@ -192,7 +208,7 @@ mod tests { .unwrap(); let font_size = 14.0; - let buffer = cx.add_model(|cx| Buffer::new(0, "a bcΔ defγ", cx)); + let buffer = cx.add_model(|cx| Buffer::new(0, "a bcΔ defγ hi—jk", cx)); let display_map = cx.add_model(|cx| DisplayMap::new(buffer, tab_size, font_id, font_size, None, cx)); let snapshot = display_map.update(cx, |map, cx| map.snapshot(cx)); @@ -202,7 +218,7 @@ mod tests { ); assert_eq!( prev_word_boundary(&snapshot, DisplayPoint::new(0, 7)).unwrap(), - DisplayPoint::new(0, 6) + DisplayPoint::new(0, 2) ); assert_eq!( prev_word_boundary(&snapshot, DisplayPoint::new(0, 6)).unwrap(), @@ -210,7 +226,7 @@ mod tests { ); assert_eq!( prev_word_boundary(&snapshot, DisplayPoint::new(0, 2)).unwrap(), - DisplayPoint::new(0, 1) + DisplayPoint::new(0, 0) ); assert_eq!( prev_word_boundary(&snapshot, DisplayPoint::new(0, 1)).unwrap(), @@ -223,7 +239,7 @@ mod tests { ); assert_eq!( next_word_boundary(&snapshot, DisplayPoint::new(0, 1)).unwrap(), - DisplayPoint::new(0, 2) + DisplayPoint::new(0, 6) ); assert_eq!( next_word_boundary(&snapshot, DisplayPoint::new(0, 2)).unwrap(), @@ -231,7 +247,7 @@ mod tests { ); assert_eq!( next_word_boundary(&snapshot, DisplayPoint::new(0, 6)).unwrap(), - DisplayPoint::new(0, 7) + DisplayPoint::new(0, 12) ); assert_eq!( next_word_boundary(&snapshot, DisplayPoint::new(0, 7)).unwrap(), diff --git a/zed/src/file_finder.rs b/zed/src/file_finder.rs index 8f3217b2e58521b6dcb9f45ac67e992d90c9b420..2749bbe59d7e3dd9d4132b1d3654a01fb91fe50a 100644 --- a/zed/src/file_finder.rs +++ b/zed/src/file_finder.rs @@ -438,29 +438,37 @@ mod tests { use crate::{ editor::{self, Insert}, fs::FakeFs, - test::{temp_tree, test_app_state}, + test::test_app_state, workspace::Workspace, }; use serde_json::json; - use std::fs; - use tempdir::TempDir; + use std::path::PathBuf; #[gpui::test] async fn test_matching_paths(mut cx: gpui::TestAppContext) { - let tmp_dir = TempDir::new("example").unwrap(); - fs::create_dir(tmp_dir.path().join("a")).unwrap(); - fs::write(tmp_dir.path().join("a/banana"), "banana").unwrap(); - fs::write(tmp_dir.path().join("a/bandana"), "bandana").unwrap(); + let app_state = cx.update(test_app_state); + app_state + .fs + .as_fake() + .insert_tree( + "/root", + json!({ + "a": { + "banana": "", + "bandana": "", + } + }), + ) + .await; cx.update(|cx| { super::init(cx); editor::init(cx); }); - let app_state = cx.update(test_app_state); let (window_id, workspace) = cx.add_window(|cx| Workspace::new(&app_state, cx)); workspace .update(&mut cx, |workspace, cx| { - workspace.add_worktree(tmp_dir.path(), cx) + workspace.add_worktree(Path::new("/root"), cx) }) .await .unwrap(); @@ -572,17 +580,17 @@ mod tests { #[gpui::test] async fn test_single_file_worktrees(mut cx: gpui::TestAppContext) { - let temp_dir = TempDir::new("test-single-file-worktrees").unwrap(); - let dir_path = temp_dir.path().join("the-parent-dir"); - let file_path = dir_path.join("the-file"); - fs::create_dir(&dir_path).unwrap(); - fs::write(&file_path, "").unwrap(); - let app_state = cx.update(test_app_state); + app_state + .fs + .as_fake() + .insert_tree("/root", json!({ "the-parent-dir": { "the-file": "" } })) + .await; + let (_, workspace) = cx.add_window(|cx| Workspace::new(&app_state, cx)); workspace .update(&mut cx, |workspace, cx| { - workspace.add_worktree(&file_path, cx) + workspace.add_worktree(Path::new("/root/the-parent-dir/the-file"), cx) }) .await .unwrap(); @@ -620,18 +628,25 @@ mod tests { #[gpui::test(retries = 5)] async fn test_multiple_matches_with_same_relative_path(mut cx: gpui::TestAppContext) { - let tmp_dir = temp_tree(json!({ - "dir1": { "a.txt": "" }, - "dir2": { "a.txt": "" } - })); - let app_state = cx.update(test_app_state); + app_state + .fs + .as_fake() + .insert_tree( + "/root", + json!({ + "dir1": { "a.txt": "" }, + "dir2": { "a.txt": "" } + }), + ) + .await; + let (_, workspace) = cx.add_window(|cx| Workspace::new(&app_state, cx)); workspace .update(&mut cx, |workspace, cx| { workspace.open_paths( - &[tmp_dir.path().join("dir1"), tmp_dir.path().join("dir2")], + &[PathBuf::from("/root/dir1"), PathBuf::from("/root/dir2")], cx, ) }) diff --git a/zed/src/fs.rs b/zed/src/fs.rs index e9d6d9230be0619577a57d4af2bae225b4d44927..7f67d6ef0225249efbaa3262b2a3e4d413a9ff96 100644 --- a/zed/src/fs.rs +++ b/zed/src/fs.rs @@ -29,6 +29,8 @@ pub trait Fs: Send + Sync { latency: Duration, ) -> Pin>>>; fn is_fake(&self) -> bool; + #[cfg(any(test, feature = "test-support"))] + fn as_fake(&self) -> &FakeFs; } #[derive(Clone, Debug)] @@ -125,6 +127,11 @@ impl Fs for RealFs { fn is_fake(&self) -> bool { false } + + #[cfg(any(test, feature = "test-support"))] + fn as_fake(&self) -> &FakeFs { + panic!("called `RealFs::as_fake`") + } } #[derive(Clone, Debug)] @@ -413,4 +420,9 @@ impl Fs for FakeFs { fn is_fake(&self) -> bool { true } + + #[cfg(any(test, feature = "test-support"))] + fn as_fake(&self) -> &FakeFs { + self + } } diff --git a/zed/src/rpc.rs b/zed/src/rpc.rs index 5dc2b49b76d9ac8ca958f6dc3903a8da3603dbcd..86f484607d6a2d15f9414f7ddaf66964b7ab4510 100644 --- a/zed/src/rpc.rs +++ b/zed/src/rpc.rs @@ -55,6 +55,8 @@ pub struct Client { #[derive(Error, Debug)] pub enum EstablishConnectionError { + #[error("upgrade required")] + UpgradeRequired, #[error("unauthorized")] Unauthorized, #[error("{0}")] @@ -68,8 +70,10 @@ pub enum EstablishConnectionError { impl From for EstablishConnectionError { fn from(error: WebsocketError) -> Self { if let WebsocketError::Http(response) = &error { - if response.status() == StatusCode::UNAUTHORIZED { - return EstablishConnectionError::Unauthorized; + match response.status() { + StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized, + StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired, + _ => {} } } EstablishConnectionError::Other(error.into()) @@ -85,6 +89,7 @@ impl EstablishConnectionError { #[derive(Copy, Clone, Debug)] pub enum Status { SignedOut, + UpgradeRequired, Authenticating, Connecting, ConnectionError, @@ -227,7 +232,7 @@ impl Client { } })); } - Status::SignedOut => { + Status::SignedOut | Status::UpgradeRequired => { state._maintain_connection.take(); } _ => {} @@ -346,6 +351,7 @@ impl Client { | Status::Reconnecting { .. } | Status::Authenticating | Status::Reauthenticating => return Ok(()), + Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?, }; if was_disconnected { @@ -388,22 +394,25 @@ impl Client { self.set_connection(conn, cx).await; Ok(()) } - Err(err) => { - if matches!(err, EstablishConnectionError::Unauthorized) { - self.state.write().credentials.take(); - if used_keychain { - cx.platform().delete_credentials(&ZED_SERVER_URL).log_err(); - self.set_status(Status::SignedOut, cx); - self.authenticate_and_connect(cx).await - } else { - self.set_status(Status::ConnectionError, cx); - Err(err)? - } + Err(EstablishConnectionError::Unauthorized) => { + self.state.write().credentials.take(); + if used_keychain { + cx.platform().delete_credentials(&ZED_SERVER_URL).log_err(); + self.set_status(Status::SignedOut, cx); + self.authenticate_and_connect(cx).await } else { self.set_status(Status::ConnectionError, cx); - Err(err)? + Err(EstablishConnectionError::Unauthorized)? } } + Err(EstablishConnectionError::UpgradeRequired) => { + self.set_status(Status::UpgradeRequired, cx); + Err(EstablishConnectionError::UpgradeRequired)? + } + Err(error) => { + self.set_status(Status::ConnectionError, cx); + Err(error)? + } } } @@ -489,10 +498,12 @@ impl Client { credentials: &Credentials, cx: &AsyncAppContext, ) -> Task> { - let request = Request::builder().header( - "Authorization", - format!("{} {}", credentials.user_id, credentials.access_token), - ); + let request = Request::builder() + .header( + "Authorization", + format!("{} {}", credentials.user_id, credentials.access_token), + ) + .header("X-Zed-Protocol-Version", zrpc::PROTOCOL_VERSION); cx.background().spawn(async move { if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") { let stream = smol::net::TcpStream::connect(host).await?; diff --git a/zed/src/test.rs b/zed/src/test.rs index eb6bf20f456236a5158c3a5e8d6a9311a7eddc40..59edd380d599007db32d964b9dbe3b2d119e53bb 100644 --- a/zed/src/test.rs +++ b/zed/src/test.rs @@ -1,7 +1,7 @@ use crate::{ assets::Assets, channel::ChannelList, - fs::RealFs, + fs::FakeFs, http::{HttpClient, Request, Response, ServerResponse}, language::LanguageRegistry, rpc::{self, Client, Credentials, EstablishConnectionError}, @@ -177,7 +177,7 @@ pub fn test_app_state(cx: &mut MutableAppContext) -> Arc { channel_list: cx.add_model(|cx| ChannelList::new(user_store.clone(), rpc.clone(), cx)), rpc, user_store, - fs: Arc::new(RealFs), + fs: Arc::new(FakeFs::new()), }) } diff --git a/zed/src/theme.rs b/zed/src/theme.rs index 8b43b09f13d28fdd1cafdf0871a12123efa12f96..a5378fe033c2f70c7767b75a7fe9cc6ec434b47d 100644 --- a/zed/src/theme.rs +++ b/zed/src/theme.rs @@ -54,6 +54,7 @@ pub struct Titlebar { pub offline_icon: OfflineIcon, pub icon_color: Color, pub avatar: ImageStyle, + pub outdated_warning: ContainedText, } #[derive(Clone, Deserialize)] @@ -169,7 +170,7 @@ pub struct Selector { pub active_item: ContainedLabel, } -#[derive(Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize)] pub struct ContainedText { #[serde(flatten)] pub container: ContainerStyle, diff --git a/zed/src/workspace.rs b/zed/src/workspace.rs index df4afb020e3246d09f462fb7ad85cd7fe72d13a2..456b9e7042ab648be9e895d71f518718facba654 100644 --- a/zed/src/workspace.rs +++ b/zed/src/workspace.rs @@ -1041,6 +1041,16 @@ impl Workspace { .with_style(theme.workspace.titlebar.offline_icon.container) .boxed(), ), + rpc::Status::UpgradeRequired => Some( + Label::new( + "Please update Zed to collaborate".to_string(), + theme.workspace.titlebar.outdated_warning.text.clone(), + ) + .contained() + .with_style(theme.workspace.titlebar.outdated_warning.container) + .aligned() + .boxed(), + ), _ => None, } } @@ -1195,11 +1205,9 @@ mod tests { editor::{Editor, Insert}, fs::FakeFs, test::{temp_tree, test_app_state}, - worktree::WorktreeHandle, }; use serde_json::json; - use std::{collections::HashSet, fs}; - use tempdir::TempDir; + use std::collections::HashSet; #[gpui::test] async fn test_open_paths_action(mut cx: gpui::TestAppContext) { @@ -1268,20 +1276,26 @@ mod tests { #[gpui::test] async fn test_open_entry(mut cx: gpui::TestAppContext) { - let dir = temp_tree(json!({ - "a": { - "file1": "contents 1", - "file2": "contents 2", - "file3": "contents 3", - }, - })); - let app_state = cx.update(test_app_state); + app_state + .fs + .as_fake() + .insert_tree( + "/root", + json!({ + "a": { + "file1": "contents 1", + "file2": "contents 2", + "file3": "contents 3", + }, + }), + ) + .await; let (_, workspace) = cx.add_window(|cx| Workspace::new(&app_state, cx)); workspace .update(&mut cx, |workspace, cx| { - workspace.add_worktree(dir.path(), cx) + workspace.add_worktree(Path::new("/root"), cx) }) .await .unwrap(); @@ -1445,28 +1459,30 @@ mod tests { #[gpui::test] async fn test_save_conflicting_item(mut cx: gpui::TestAppContext) { - let dir = temp_tree(json!({ - "a.txt": "", - })); - let app_state = cx.update(test_app_state); + app_state + .fs + .as_fake() + .insert_tree( + "/root", + json!({ + "a.txt": "", + }), + ) + .await; + let (window_id, workspace) = cx.add_window(|cx| Workspace::new(&app_state, cx)); workspace .update(&mut cx, |workspace, cx| { - workspace.add_worktree(dir.path(), cx) + workspace.add_worktree(Path::new("/root"), cx) }) .await .unwrap(); - let tree = cx.read(|cx| { - let mut trees = workspace.read(cx).worktrees().iter(); - trees.next().unwrap().clone() - }); - tree.flush_fs_events(&cx).await; // Open a file within an existing worktree. cx.update(|cx| { workspace.update(cx, |view, cx| { - view.open_paths(&[dir.path().join("a.txt")], cx) + view.open_paths(&[PathBuf::from("/root/a.txt")], cx) }) }) .await; @@ -1477,7 +1493,12 @@ mod tests { }); cx.update(|cx| editor.update(cx, |editor, cx| editor.insert(&Insert("x".into()), cx))); - fs::write(dir.path().join("a.txt"), "changed").unwrap(); + app_state + .fs + .as_fake() + .insert_file("/root/a.txt", "changed".to_string()) + .await + .unwrap(); editor .condition(&cx, |editor, cx| editor.has_conflict(cx)) .await; @@ -1493,12 +1514,12 @@ mod tests { #[gpui::test] async fn test_open_and_save_new_file(mut cx: gpui::TestAppContext) { - let dir = TempDir::new("test-new-file").unwrap(); let app_state = cx.update(test_app_state); + app_state.fs.as_fake().insert_dir("/root").await.unwrap(); let (_, workspace) = cx.add_window(|cx| Workspace::new(&app_state, cx)); workspace .update(&mut cx, |workspace, cx| { - workspace.add_worktree(dir.path(), cx) + workspace.add_worktree(Path::new("/root"), cx) }) .await .unwrap(); @@ -1511,7 +1532,6 @@ mod tests { .unwrap() .clone() }); - tree.flush_fs_events(&cx).await; // Create a new untitled buffer let editor = workspace.update(&mut cx, |workspace, cx| { @@ -1537,7 +1557,7 @@ mod tests { workspace.save_active_item(&Save, cx) }); cx.simulate_new_path_selection(|parent_dir| { - assert_eq!(parent_dir, dir.path()); + assert_eq!(parent_dir, Path::new("/root")); Some(parent_dir.join("the-new-name.rs")) }); cx.read(|cx| { @@ -1598,8 +1618,8 @@ mod tests { async fn test_setting_language_when_saving_as_single_file_worktree( mut cx: gpui::TestAppContext, ) { - let dir = TempDir::new("test-new-file").unwrap(); let app_state = cx.update(test_app_state); + app_state.fs.as_fake().insert_dir("/root").await.unwrap(); let (_, workspace) = cx.add_window(|cx| Workspace::new(&app_state, cx)); // Create a new untitled buffer @@ -1623,7 +1643,7 @@ mod tests { workspace.update(&mut cx, |workspace, cx| { workspace.save_active_item(&Save, cx) }); - cx.simulate_new_path_selection(|_| Some(dir.path().join("the-new-name.rs"))); + cx.simulate_new_path_selection(|_| Some(PathBuf::from("/root/the-new-name.rs"))); editor .condition(&cx, |editor, cx| !editor.is_dirty(cx)) @@ -1640,7 +1660,7 @@ mod tests { cx.update(init); let app_state = cx.update(test_app_state); - cx.dispatch_global_action(OpenNew(app_state)); + cx.dispatch_global_action(OpenNew(app_state.clone())); let window_id = *cx.window_ids().first().unwrap(); let workspace = cx.root_view::(window_id).unwrap(); let editor = workspace.update(&mut cx, |workspace, cx| { @@ -1660,10 +1680,8 @@ mod tests { workspace.save_active_item(&Save, cx) }); - let dir = TempDir::new("test-new-empty-workspace").unwrap(); - cx.simulate_new_path_selection(|_| { - Some(dir.path().canonicalize().unwrap().join("the-new-name")) - }); + app_state.fs.as_fake().insert_dir("/root").await.unwrap(); + cx.simulate_new_path_selection(|_| Some(PathBuf::from("/root/the-new-name"))); editor .condition(&cx, |editor, cx| editor.title(cx) == "the-new-name") @@ -1676,20 +1694,26 @@ mod tests { #[gpui::test] async fn test_pane_actions(mut cx: gpui::TestAppContext) { cx.update(|cx| pane::init(cx)); - - let dir = temp_tree(json!({ - "a": { - "file1": "contents 1", - "file2": "contents 2", - "file3": "contents 3", - }, - })); - let app_state = cx.update(test_app_state); + app_state + .fs + .as_fake() + .insert_tree( + "/root", + json!({ + "a": { + "file1": "contents 1", + "file2": "contents 2", + "file3": "contents 3", + }, + }), + ) + .await; + let (window_id, workspace) = cx.add_window(|cx| Workspace::new(&app_state, cx)); workspace .update(&mut cx, |workspace, cx| { - workspace.add_worktree(dir.path(), cx) + workspace.add_worktree(Path::new("/root"), cx) }) .await .unwrap(); diff --git a/zed/src/worktree.rs b/zed/src/worktree.rs index de819af1411ca3abaef9604f9e4b4ad6c09d41e1..15cfd29463747ccc4e0b25b50f7cb6fc014f9cb2 100644 --- a/zed/src/worktree.rs +++ b/zed/src/worktree.rs @@ -161,6 +161,7 @@ impl Worktree { entries_by_id_edits.push(Edit::Insert(PathEntry { id: entry.id, path: entry.path.clone(), + is_ignored: entry.is_ignored, scan_id: 0, })); entries_by_path_edits.push(Edit::Insert(entry)); @@ -1083,7 +1084,7 @@ impl LocalWorktree { async move { let mut prev_snapshot = snapshot; while let Ok(snapshot) = snapshots_to_send_rx.recv().await { - let message = snapshot.build_update(&prev_snapshot, remote_id); + let message = snapshot.build_update(&prev_snapshot, remote_id, false); match rpc.send(message).await { Ok(()) => prev_snapshot = snapshot, Err(err) => log::error!("error sending snapshot diff {}", err), @@ -1140,6 +1141,7 @@ impl LocalWorktree { let entries = snapshot .entries_by_path .cursor::<(), ()>() + .filter(|e| !e.is_ignored) .map(Into::into) .collect(); proto::ShareWorktree { @@ -1373,11 +1375,24 @@ impl Snapshot { self.id } - pub fn build_update(&self, other: &Self, worktree_id: u64) -> proto::UpdateWorktree { + pub fn build_update( + &self, + other: &Self, + worktree_id: u64, + include_ignored: bool, + ) -> proto::UpdateWorktree { let mut updated_entries = Vec::new(); let mut removed_entries = Vec::new(); - let mut self_entries = self.entries_by_id.cursor::<(), ()>().peekable(); - let mut other_entries = other.entries_by_id.cursor::<(), ()>().peekable(); + let mut self_entries = self + .entries_by_id + .cursor::<(), ()>() + .filter(|e| include_ignored || !e.is_ignored) + .peekable(); + let mut other_entries = other + .entries_by_id + .cursor::<(), ()>() + .filter(|e| include_ignored || !e.is_ignored) + .peekable(); loop { match (self_entries.peek(), other_entries.peek()) { (Some(self_entry), Some(other_entry)) => match self_entry.id.cmp(&other_entry.id) { @@ -1443,6 +1458,7 @@ impl Snapshot { entries_by_id_edits.push(Edit::Insert(PathEntry { id: entry.id, path: entry.path.clone(), + is_ignored: entry.is_ignored, scan_id, })); entries_by_path_edits.push(Edit::Insert(entry)); @@ -1526,6 +1542,7 @@ impl Snapshot { PathEntry { id: entry.id, path: entry.path.clone(), + is_ignored: entry.is_ignored, scan_id: self.scan_id, }, &(), @@ -1561,6 +1578,7 @@ impl Snapshot { entries_by_id_edits.push(Edit::Insert(PathEntry { id: entry.id, path: entry.path.clone(), + is_ignored: entry.is_ignored, scan_id: self.scan_id, })); entries_by_path_edits.push(Edit::Insert(entry)); @@ -1933,6 +1951,7 @@ impl sum_tree::Summary for EntrySummary { struct PathEntry { id: usize, path: Arc, + is_ignored: bool, scan_id: usize, } @@ -2412,7 +2431,8 @@ impl BackgroundScanner { ignore_stack = ignore_stack.append(job.path.clone(), ignore.clone()); } - let mut edits = Vec::new(); + let mut entries_by_id_edits = Vec::new(); + let mut entries_by_path_edits = Vec::new(); for mut entry in snapshot.child_entries(&job.path).cloned() { let was_ignored = entry.is_ignored; entry.is_ignored = ignore_stack.is_path_ignored(&entry.path, entry.is_dir()); @@ -2433,10 +2453,17 @@ impl BackgroundScanner { } if entry.is_ignored != was_ignored { - edits.push(Edit::Insert(entry)); + let mut path_entry = snapshot.entries_by_id.get(&entry.id, &()).unwrap().clone(); + path_entry.scan_id = snapshot.scan_id; + path_entry.is_ignored = entry.is_ignored; + entries_by_id_edits.push(Edit::Insert(path_entry)); + entries_by_path_edits.push(Edit::Insert(entry)); } } - self.snapshot.lock().entries_by_path.edit(edits, &()); + + let mut snapshot = self.snapshot.lock(); + snapshot.entries_by_path.edit(entries_by_path_edits, &()); + snapshot.entries_by_id.edit(entries_by_id_edits, &()); } } @@ -3000,10 +3027,10 @@ mod tests { // Update the remote worktree. Check that it becomes consistent with the // local worktree. remote.update(&mut cx, |remote, cx| { - let update_message = tree - .read(cx) - .snapshot() - .build_update(&initial_snapshot, worktree_id); + let update_message = + tree.read(cx) + .snapshot() + .build_update(&initial_snapshot, worktree_id, true); remote .as_remote_mut() .unwrap() @@ -3175,6 +3202,7 @@ mod tests { scanner.snapshot().check_invariants(); let mut events = Vec::new(); + let mut snapshots = Vec::new(); let mut mutations_len = operations; while mutations_len > 1 { if !events.is_empty() && rng.gen_bool(0.4) { @@ -3187,6 +3215,10 @@ mod tests { events.extend(randomly_mutate_tree(root_dir.path(), 0.6, &mut rng).unwrap()); mutations_len -= 1; } + + if rng.gen_bool(0.2) { + snapshots.push(scanner.snapshot()); + } } log::info!("Quiescing: {:#?}", events); smol::block_on(scanner.process_events(events)); @@ -3200,7 +3232,40 @@ mod tests { scanner.executor.clone(), ); smol::block_on(new_scanner.scan_dirs()).unwrap(); - assert_eq!(scanner.snapshot().to_vec(), new_scanner.snapshot().to_vec()); + assert_eq!( + scanner.snapshot().to_vec(true), + new_scanner.snapshot().to_vec(true) + ); + + for mut prev_snapshot in snapshots { + let include_ignored = rng.gen::(); + if !include_ignored { + let mut entries_by_path_edits = Vec::new(); + let mut entries_by_id_edits = Vec::new(); + for entry in prev_snapshot + .entries_by_id + .cursor::<(), ()>() + .filter(|e| e.is_ignored) + { + entries_by_path_edits.push(Edit::Remove(PathKey(entry.path.clone()))); + entries_by_id_edits.push(Edit::Remove(entry.id)); + } + + prev_snapshot + .entries_by_path + .edit(entries_by_path_edits, &()); + prev_snapshot.entries_by_id.edit(entries_by_id_edits, &()); + } + + let update = scanner + .snapshot() + .build_update(&prev_snapshot, 0, include_ignored); + prev_snapshot.apply_update(update).unwrap(); + assert_eq!( + prev_snapshot.to_vec(true), + scanner.snapshot().to_vec(include_ignored) + ); + } } fn randomly_mutate_tree( @@ -3390,10 +3455,12 @@ mod tests { } } - fn to_vec(&self) -> Vec<(&Path, u64, bool)> { + fn to_vec(&self, include_ignored: bool) -> Vec<(&Path, u64, bool)> { let mut paths = Vec::new(); for entry in self.entries_by_path.cursor::<(), ()>() { - paths.push((entry.path.as_ref(), entry.inode, entry.is_ignored)); + if include_ignored || !entry.is_ignored { + paths.push((entry.path.as_ref(), entry.inode, entry.is_ignored)); + } } paths.sort_by(|a, b| a.0.cmp(&b.0)); paths diff --git a/zrpc/Cargo.toml b/zrpc/Cargo.toml index 5d78f46bb13591ccde8ea93633ea386026596c62..ee1b1433daad18c1f77a98c1b2db554173b399bd 100644 --- a/zrpc/Cargo.toml +++ b/zrpc/Cargo.toml @@ -20,6 +20,7 @@ prost = "0.7" rand = "0.8" rsa = "0.4" serde = { version = "1", features = ["derive"] } +zstd = "0.9" [build-dependencies] prost-build = { git = "https://github.com/tokio-rs/prost", rev = "6cf97ea422b09d98de34643c4dda2d4f8b7e23e6" } diff --git a/zrpc/src/lib.rs b/zrpc/src/lib.rs index a7bb44774b8e700443f753e3fb47c1176ef80142..ccaf50135003576fd98c43cf26aff4e5202e7621 100644 --- a/zrpc/src/lib.rs +++ b/zrpc/src/lib.rs @@ -4,3 +4,5 @@ mod peer; pub mod proto; pub use conn::Connection; pub use peer::*; + +pub const PROTOCOL_VERSION: u32 = 1; diff --git a/zrpc/src/peer.rs b/zrpc/src/peer.rs index eeda034e9581ce215ee01821cff3e82bab70ed25..251ffb5bb512e2a603b57922b9097edbd408fecc 100644 --- a/zrpc/src/peer.rs +++ b/zrpc/src/peer.rs @@ -87,7 +87,7 @@ pub struct Peer { struct ConnectionState { outgoing_tx: mpsc::Sender, next_message_id: Arc, - response_channels: Arc>>>, + response_channels: Arc>>>>, } impl Peer { @@ -115,7 +115,7 @@ impl Peer { let connection_state = ConnectionState { outgoing_tx, next_message_id: Default::default(), - response_channels: Default::default(), + response_channels: Arc::new(Mutex::new(Some(Default::default()))), }; let mut writer = MessageStream::new(connection.tx); let mut reader = MessageStream::new(connection.rx); @@ -123,7 +123,7 @@ impl Peer { let this = self.clone(); let response_channels = connection_state.response_channels.clone(); let handle_io = async move { - loop { + let result = 'outer: loop { let read_message = reader.read_message().fuse(); futures::pin_mut!(read_message); loop { @@ -131,7 +131,7 @@ impl Peer { incoming = read_message => match incoming { Ok(incoming) => { if let Some(responding_to) = incoming.responding_to { - let channel = response_channels.lock().await.remove(&responding_to); + let channel = response_channels.lock().await.as_mut().unwrap().remove(&responding_to); if let Some(mut tx) = channel { tx.send(incoming).await.ok(); } else { @@ -140,9 +140,7 @@ impl Peer { } else { if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) { if incoming_tx.send(envelope).await.is_err() { - response_channels.lock().await.clear(); - this.connections.write().await.remove(&connection_id); - return Ok(()) + break 'outer Ok(()) } } else { log::error!("unable to construct a typed envelope"); @@ -152,28 +150,24 @@ impl Peer { break; } Err(error) => { - response_channels.lock().await.clear(); - this.connections.write().await.remove(&connection_id); - Err(error).context("received invalid RPC message")?; + break 'outer Err(error).context("received invalid RPC message") } }, outgoing = outgoing_rx.recv().fuse() => match outgoing { Some(outgoing) => { if let Err(result) = writer.write_message(&outgoing).await { - response_channels.lock().await.clear(); - this.connections.write().await.remove(&connection_id); - Err(result).context("failed to write RPC message")?; + break 'outer Err(result).context("failed to write RPC message") } } - None => { - response_channels.lock().await.clear(); - this.connections.write().await.remove(&connection_id); - return Ok(()) - } + None => break 'outer Ok(()), } } } - } + }; + + response_channels.lock().await.take(); + this.connections.write().await.remove(&connection_id); + result }; self.connections @@ -226,6 +220,8 @@ impl Peer { .response_channels .lock() .await + .as_mut() + .ok_or_else(|| anyhow!("connection was closed"))? .insert(message_id, tx); connection .outgoing_tx @@ -520,8 +516,7 @@ mod tests { #[test] fn test_io_error() { smol::block_on(async move { - let (client_conn, server_conn, _) = Connection::in_memory(); - drop(server_conn); + let (client_conn, mut server_conn, _) = Connection::in_memory(); let client = Peer::new(); let (connection_id, io_handler, mut incoming) = @@ -529,11 +524,14 @@ mod tests { smol::spawn(io_handler).detach(); smol::spawn(async move { incoming.next().await }).detach(); - let err = client - .request(connection_id, proto::Ping {}) - .await - .unwrap_err(); - assert_eq!(err.to_string(), "connection was closed"); + let response = smol::spawn(client.request(connection_id, proto::Ping {})); + let _request = server_conn.rx.next().await.unwrap().unwrap(); + + drop(server_conn); + assert_eq!( + response.await.unwrap_err().to_string(), + "connection was closed" + ); }); } } diff --git a/zrpc/src/proto.rs b/zrpc/src/proto.rs index 92fca53e28335680f2cf7227e0eea32a68a54e8b..e9de319a1c6abf8c23efcbfd3e47622d133908dc 100644 --- a/zrpc/src/proto.rs +++ b/zrpc/src/proto.rs @@ -192,11 +192,15 @@ entity_messages!(channel_id, ChannelMessageSent); /// A stream of protobuf messages. pub struct MessageStream { stream: S, + encoding_buffer: Vec, } impl MessageStream { pub fn new(stream: S) -> Self { - Self { stream } + Self { + stream, + encoding_buffer: Vec::new(), + } } pub fn inner_mut(&mut self) -> &mut S { @@ -210,10 +214,12 @@ where { /// Write a given protobuf message to the stream. pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> { - let mut buffer = Vec::with_capacity(message.encoded_len()); + self.encoding_buffer.resize(message.encoded_len(), 0); + self.encoding_buffer.clear(); message - .encode(&mut buffer) + .encode(&mut self.encoding_buffer) .map_err(|err| io::Error::from(err))?; + let buffer = zstd::stream::encode_all(self.encoding_buffer.as_slice(), 4).unwrap(); self.stream.send(WebSocketMessage::Binary(buffer)).await?; Ok(()) } @@ -228,7 +234,10 @@ where while let Some(bytes) = self.stream.next().await { match bytes? { WebSocketMessage::Binary(bytes) => { - let envelope = Envelope::decode(bytes.as_slice()).map_err(io::Error::from)?; + self.encoding_buffer.clear(); + zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap(); + let envelope = Envelope::decode(self.encoding_buffer.as_slice()) + .map_err(io::Error::from)?; return Ok(envelope); } WebSocketMessage::Close(_) => break,