Detailed changes
@@ -867,6 +867,9 @@ name = "cc"
version = "1.0.67"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3c69b077ad434294d3ce9f1f6143a2a4b89a8a2d54ef813d85003a4fd1137fd"
+dependencies = [
+ "jobserver",
+]
[[package]]
name = "cexpr"
@@ -2614,6 +2617,15 @@ version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd25036021b0de88a0aff6b850051563c6516d0bf53f8638938edbb9de732736"
+[[package]]
+name = "jobserver"
+version = "0.1.24"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "af25a77299a7f711a01975c35a6a424eb6862092cc2d6c72c4ed6cbc56dfc1fa"
+dependencies = [
+ "libc",
+]
+
[[package]]
name = "jpeg-decoder"
version = "0.1.22"
@@ -6037,4 +6049,34 @@ dependencies = [
"serde 1.0.125",
"smol",
"tempdir",
+ "zstd",
+]
+
+[[package]]
+name = "zstd"
+version = "0.9.0+zstd.1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "07749a5dc2cb6b36661290245e350f15ec3bbb304e493db54a1d354480522ccd"
+dependencies = [
+ "zstd-safe",
+]
+
+[[package]]
+name = "zstd-safe"
+version = "4.1.1+zstd.1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c91c90f2c593b003603e5e0493c837088df4469da25aafff8bce42ba48caf079"
+dependencies = [
+ "libc",
+ "zstd-sys",
+]
+
+[[package]]
+name = "zstd-sys"
+version = "1.6.1+zstd.1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "615120c7a2431d16cf1cf979e7fc31ba7a5b5e5707b29c8a99e5dbf8a8392a33"
+dependencies = [
+ "cc",
+ "libc",
]
@@ -669,8 +669,9 @@ mod tests {
);
}
- #[crate::test(self)]
+ #[crate::test(self, retries = 5)]
fn test_wrap_shaped_line(cx: &mut crate::MutableAppContext) {
+ // This is failing intermittently on CI and we don't have time to figure it out
let font_cache = cx.font_cache().clone();
let font_system = cx.platform().fonts();
let text_layout_cache = TextLayoutCache::new(font_system.clone());
@@ -18,7 +18,7 @@ use scrypt::{
use serde::{Deserialize, Serialize};
use std::{borrow::Cow, convert::TryFrom, sync::Arc};
use surf::{StatusCode, Url};
-use tide::{log, Server};
+use tide::{log, Error, Server};
use zrpc::auth as zed_auth;
static CURRENT_GITHUB_USER: &'static str = "current_github_user";
@@ -33,51 +33,48 @@ pub struct User {
pub is_admin: bool,
}
-pub struct VerifyToken;
-
-#[async_trait]
-impl tide::Middleware<Arc<AppState>> for VerifyToken {
- async fn handle(
- &self,
- mut request: Request,
- next: tide::Next<'_, Arc<AppState>>,
- ) -> tide::Result {
- let mut auth_header = request
- .header("Authorization")
- .ok_or_else(|| anyhow!("no authorization header"))?
- .last()
- .as_str()
- .split_whitespace();
-
- let user_id = UserId(
- auth_header
- .next()
- .ok_or_else(|| anyhow!("missing user id in authorization header"))?
- .parse()?,
- );
- let access_token = auth_header
- .next()
- .ok_or_else(|| anyhow!("missing access token in authorization header"))?;
-
- let state = request.state().clone();
-
- let mut credentials_valid = false;
- for password_hash in state.db.get_access_token_hashes(user_id).await? {
- if verify_access_token(&access_token, &password_hash)? {
- credentials_valid = true;
- break;
- }
+pub async fn process_auth_header(request: &Request) -> tide::Result<UserId> {
+ let mut auth_header = request
+ .header("Authorization")
+ .ok_or_else(|| {
+ Error::new(
+ StatusCode::BadRequest,
+ anyhow!("missing authorization header"),
+ )
+ })?
+ .last()
+ .as_str()
+ .split_whitespace();
+ let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| {
+ Error::new(
+ StatusCode::BadRequest,
+ anyhow!("missing user id in authorization header"),
+ )
+ })?);
+ let access_token = auth_header.next().ok_or_else(|| {
+ Error::new(
+ StatusCode::BadRequest,
+ anyhow!("missing access token in authorization header"),
+ )
+ })?;
+
+ let state = request.state().clone();
+ let mut credentials_valid = false;
+ for password_hash in state.db.get_access_token_hashes(user_id).await? {
+ if verify_access_token(&access_token, &password_hash)? {
+ credentials_valid = true;
+ break;
}
+ }
- if credentials_valid {
- request.set_ext(user_id);
- Ok(next.run(request).await)
- } else {
- let mut response = tide::Response::new(StatusCode::Unauthorized);
- response.set_body("invalid credentials");
- Ok(response)
- }
+ if !credentials_valid {
+ Err(Error::new(
+ StatusCode::Unauthorized,
+ anyhow!("invalid credentials"),
+ ))?;
}
+
+ Ok(user_id)
}
#[async_trait]
@@ -263,11 +260,13 @@ async fn post_sign_out(mut request: Request) -> tide::Result {
Ok(tide::Redirect::new("/").into())
}
+const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
+
pub async fn create_access_token(db: &db::Db, user_id: UserId) -> tide::Result<String> {
let access_token = zed_auth::random_token();
let access_token_hash =
hash_access_token(&access_token).context("failed to hash access token")?;
- db.create_access_token_hash(user_id, access_token_hash)
+ db.create_access_token_hash(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
.await?;
Ok(access_token)
}
@@ -175,25 +175,48 @@ impl Db {
pub async fn create_access_token_hash(
&self,
user_id: UserId,
- access_token_hash: String,
+ access_token_hash: &str,
+ max_access_token_count: usize,
) -> Result<()> {
test_support!(self, {
- let query = "
- INSERT INTO access_tokens (user_id, hash)
- VALUES ($1, $2)
- ";
- sqlx::query(query)
+ let insert_query = "
+ INSERT INTO access_tokens (user_id, hash)
+ VALUES ($1, $2);
+ ";
+ let cleanup_query = "
+ DELETE FROM access_tokens
+ WHERE id IN (
+ SELECT id from access_tokens
+ WHERE user_id = $1
+ ORDER BY id DESC
+ OFFSET $3
+ )
+ ";
+
+ let mut tx = self.pool.begin().await?;
+ sqlx::query(insert_query)
.bind(user_id.0)
.bind(access_token_hash)
- .execute(&self.pool)
- .await
- .map(drop)
+ .execute(&mut tx)
+ .await?;
+ sqlx::query(cleanup_query)
+ .bind(user_id.0)
+ .bind(access_token_hash)
+ .bind(max_access_token_count as u32)
+ .execute(&mut tx)
+ .await?;
+ tx.commit().await
})
}
pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
test_support!(self, {
- let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
+ let query = "
+ SELECT hash
+ FROM access_tokens
+ WHERE user_id = $1
+ ORDER BY id DESC
+ ";
sqlx::query_scalar(query)
.bind(user_id.0)
.fetch_all(&self.pool)
@@ -652,4 +675,36 @@ pub mod tests {
assert_eq!(msg1_id, msg3_id);
assert_eq!(msg2_id, msg4_id);
}
-}
+
+ #[gpui::test]
+ async fn test_create_access_tokens() {
+ let test_db = TestDb::new();
+ let db = test_db.db();
+ let user = db.create_user("the-user", false).await.unwrap();
+
+ db.create_access_token_hash(user, "h1", 3).await.unwrap();
+ db.create_access_token_hash(user, "h2", 3).await.unwrap();
+ assert_eq!(
+ db.get_access_token_hashes(user).await.unwrap(),
+ &["h2".to_string(), "h1".to_string()]
+ );
+
+ db.create_access_token_hash(user, "h3", 3).await.unwrap();
+ assert_eq!(
+ db.get_access_token_hashes(user).await.unwrap(),
+ &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
+ );
+
+ db.create_access_token_hash(user, "h4", 3).await.unwrap();
+ assert_eq!(
+ db.get_access_token_hashes(user).await.unwrap(),
+ &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
+ );
+
+ db.create_access_token_hash(user, "h5", 3).await.unwrap();
+ assert_eq!(
+ db.get_access_token_hashes(user).await.unwrap(),
+ &["h5".to_string(), "h4".to_string(), "h3".to_string()]
+ );
+ }
+}
@@ -1,7 +1,7 @@
mod store;
use super::{
- auth,
+ auth::process_auth_header,
db::{ChannelId, MessageId, UserId},
AppState,
};
@@ -885,8 +885,7 @@ where
pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
let server = Server::new(app.state().clone(), rpc.clone(), None);
- app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
- let user_id = request.ext::<UserId>().copied();
+ app.at("/rpc").get(move |request: Request<Arc<AppState>>| {
let server = server.clone();
async move {
const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
@@ -894,8 +893,11 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
let upgrade_requested = connection_upgrade && upgrade_to_websocket;
+ let client_protocol_version: Option<u32> = request
+ .header("X-Zed-Protocol-Version")
+ .and_then(|v| v.as_str().parse().ok());
- if !upgrade_requested {
+ if !upgrade_requested || client_protocol_version != Some(zrpc::PROTOCOL_VERSION) {
return Ok(Response::new(StatusCode::UpgradeRequired));
}
@@ -904,6 +906,8 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
None => return Err(anyhow!("expected sec-websocket-key"))?,
};
+ let user_id = process_auth_header(&request).await?;
+
let mut response = Response::new(StatusCode::SwitchingProtocols);
response.insert_header(UPGRADE, "websocket");
response.insert_header(CONNECTION, "Upgrade");
@@ -914,10 +918,17 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
let http_res: &mut tide::http::Response = response.as_mut();
let upgrade_receiver = http_res.recv_upgrade().await;
let addr = request.remote().unwrap_or("unknown").to_string();
- let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
task::spawn(async move {
if let Some(stream) = upgrade_receiver.await {
- server.handle_connection(Connection::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
+ server
+ .handle_connection(
+ Connection::new(
+ WebSocketStream::from_raw_socket(stream, Role::Server, None).await,
+ ),
+ addr,
+ user_id,
+ )
+ .await;
}
});
@@ -11,6 +11,7 @@ title = "$text.0"
avatar_width = 20
icon_color = "$text.2.color"
avatar = { corner_radius = 10, border = { width = 1, color = "#00000088" } }
+outdated_warning = { extends = "$text.2", size = 13 }
[workspace.titlebar.offline_icon]
padding = { right = 4 }
@@ -29,7 +29,6 @@ use smol::Timer;
use std::{
cell::RefCell,
cmp::{self, Ordering},
- iter::FromIterator,
mem,
ops::{Range, RangeInclusive},
path::Path,
@@ -299,7 +298,7 @@ pub struct Editor {
pending_selection: Option<Selection>,
next_selection_id: usize,
add_selections_state: Option<AddSelectionsState>,
- select_larger_syntax_node_stack: Vec<Vec<Selection>>,
+ select_larger_syntax_node_stack: Vec<Arc<[Selection]>>,
scroll_position: Vector2F,
scroll_top_anchor: Anchor,
autoscroll_requested: bool,
@@ -511,15 +510,14 @@ impl Editor {
return false;
}
- let first_cursor_top = self
- .selections(cx)
+ let selections = self.selections(cx);
+ let first_cursor_top = selections
.first()
.unwrap()
.head()
.to_display_point(&display_map, Bias::Left)
.row() as f32;
- let last_cursor_bottom = self
- .selections(cx)
+ let last_cursor_bottom = selections
.last()
.unwrap()
.head()
@@ -561,12 +559,13 @@ impl Editor {
scroll_width: f32,
max_glyph_width: f32,
layouts: &[text_layout::Line],
- cx: &mut MutableAppContext,
+ cx: &mut ViewContext<Self>,
) -> bool {
+ let selections = self.selections(cx);
let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
let mut target_left = std::f32::INFINITY;
let mut target_right = 0.0_f32;
- for selection in self.selections(cx) {
+ for selection in selections.iter() {
let head = selection.head().to_display_point(&display_map, Bias::Left);
let start_column = head.column().saturating_sub(3);
let end_column = cmp::min(display_map.line_len(head.row()), head.column() + 3);
@@ -655,12 +654,10 @@ impl Editor {
fn end_selection(&mut self, cx: &mut ViewContext<Self>) {
if let Some(selection) = self.pending_selection.take() {
- let mut selections = self.selections(cx.as_ref()).to_vec();
+ let mut selections = self.selections(cx).to_vec();
let ix = self.selection_insertion_index(&selections, &selection.start, cx.as_ref());
selections.insert(ix, selection);
self.update_selections(selections, false, cx);
- } else {
- log::error!("end_selection dispatched with no pending selection");
}
}
@@ -669,12 +666,13 @@ impl Editor {
}
pub fn cancel(&mut self, _: &Cancel, cx: &mut ViewContext<Self>) {
- let selections = self.selections(cx.as_ref());
if let Some(pending_selection) = self.pending_selection.take() {
+ let selections = self.selections(cx);
if selections.is_empty() {
self.update_selections(vec![pending_selection], true, cx);
}
} else {
+ let selections = self.selections(cx);
let mut oldest_selection = selections.iter().min_by_key(|s| s.id).unwrap().clone();
if selections.len() == 1 {
oldest_selection.start = oldest_selection.head().clone();
@@ -743,8 +741,9 @@ impl Editor {
pub fn insert(&mut self, action: &Insert, cx: &mut ViewContext<Self>) {
let mut old_selections = SmallVec::<[_; 32]>::new();
{
+ let selections = self.selections(cx);
let buffer = self.buffer.read(cx);
- for selection in self.selections(cx.as_ref()) {
+ for selection in selections.iter() {
let start = selection.start.to_offset(buffer);
let end = selection.end.to_offset(buffer);
old_selections.push((selection.id, start..end));
@@ -790,7 +789,7 @@ impl Editor {
pub fn backspace(&mut self, _: &Backspace, cx: &mut ViewContext<Self>) {
self.start_transaction(cx);
- let mut selections = self.selections(cx.as_ref()).to_vec();
+ let mut selections = self.selections(cx).to_vec();
let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
{
let buffer = self.buffer.read(cx);
@@ -814,7 +813,7 @@ impl Editor {
pub fn delete(&mut self, _: &Delete, cx: &mut ViewContext<Self>) {
self.start_transaction(cx);
let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
- let mut selections = self.selections(cx.as_ref()).to_vec();
+ let mut selections = self.selections(cx).to_vec();
{
let buffer = self.buffer.read(cx);
for selection in &mut selections {
@@ -837,14 +836,14 @@ impl Editor {
pub fn delete_line(&mut self, _: &DeleteLine, cx: &mut ViewContext<Self>) {
self.start_transaction(cx);
+ let selections = self.selections(cx);
let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
- let app = cx.as_ref();
- let buffer = self.buffer.read(app);
+ let buffer = self.buffer.read(cx);
let mut new_cursors = Vec::new();
let mut edit_ranges = Vec::new();
- let mut selections = self.selections(app).iter().peekable();
+ let mut selections = selections.iter().peekable();
while let Some(selection) = selections.next() {
let mut rows = selection.spanned_rows(false, &display_map).buffer_rows;
let goal_display_column = selection
@@ -914,7 +913,7 @@ impl Editor {
pub fn duplicate_line(&mut self, _: &DuplicateLine, cx: &mut ViewContext<Self>) {
self.start_transaction(cx);
- let mut selections = self.selections(cx.as_ref()).to_vec();
+ let mut selections = self.selections(cx).to_vec();
{
// Temporarily bias selections right to allow newly duplicate lines to push them down
// when the selections are at the beginning of a line.
@@ -974,8 +973,8 @@ impl Editor {
pub fn move_line_up(&mut self, _: &MoveLineUp, cx: &mut ViewContext<Self>) {
self.start_transaction(cx);
+ let selections = self.selections(cx);
let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
- let app = cx.as_ref();
let buffer = self.buffer.read(cx);
let mut edits = Vec::new();
@@ -983,7 +982,7 @@ impl Editor {
let mut old_folds = Vec::new();
let mut new_folds = Vec::new();
- let mut selections = self.selections(app).iter().peekable();
+ let mut selections = selections.iter().peekable();
let mut contiguous_selections = Vec::new();
while let Some(selection) = selections.next() {
// Accumulate contiguous regions of rows that we want to move.
@@ -1064,8 +1063,8 @@ impl Editor {
pub fn move_line_down(&mut self, _: &MoveLineDown, cx: &mut ViewContext<Self>) {
self.start_transaction(cx);
+ let selections = self.selections(cx);
let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
- let app = cx.as_ref();
let buffer = self.buffer.read(cx);
let mut edits = Vec::new();
@@ -1073,7 +1072,7 @@ impl Editor {
let mut old_folds = Vec::new();
let mut new_folds = Vec::new();
- let mut selections = self.selections(app).iter().peekable();
+ let mut selections = selections.iter().peekable();
let mut contiguous_selections = Vec::new();
while let Some(selection) = selections.next() {
// Accumulate contiguous regions of rows that we want to move.
@@ -1151,7 +1150,7 @@ impl Editor {
pub fn cut(&mut self, _: &Cut, cx: &mut ViewContext<Self>) {
self.start_transaction(cx);
let mut text = String::new();
- let mut selections = self.selections(cx.as_ref()).to_vec();
+ let mut selections = self.selections(cx).to_vec();
let mut clipboard_selections = Vec::with_capacity(selections.len());
{
let buffer = self.buffer.read(cx);
@@ -1186,12 +1185,12 @@ impl Editor {
}
pub fn copy(&mut self, _: &Copy, cx: &mut ViewContext<Self>) {
+ let selections = self.selections(cx);
let buffer = self.buffer.read(cx);
let max_point = buffer.max_point();
let mut text = String::new();
- let selections = self.selections(cx.as_ref());
let mut clipboard_selections = Vec::with_capacity(selections.len());
- for selection in selections {
+ for selection in selections.iter() {
let mut start = selection.start.to_point(buffer);
let mut end = selection.end.to_point(buffer);
let is_entire_line = start == end;
@@ -1218,24 +1217,28 @@ impl Editor {
if let Some(item) = cx.as_mut().read_from_clipboard() {
let clipboard_text = item.text();
if let Some(mut clipboard_selections) = item.metadata::<Vec<ClipboardSelection>>() {
- let selections = self.selections(cx.as_ref()).to_vec();
+ let selections = self.selections(cx);
+ let all_selections_were_entire_line =
+ clipboard_selections.iter().all(|s| s.is_entire_line);
if clipboard_selections.len() != selections.len() {
- let merged_selection = ClipboardSelection {
- len: clipboard_selections.iter().map(|s| s.len).sum(),
- is_entire_line: clipboard_selections.iter().all(|s| s.is_entire_line),
- };
clipboard_selections.clear();
- clipboard_selections.push(merged_selection);
}
self.start_transaction(cx);
+ let mut start_offset = 0;
let mut new_selections = Vec::with_capacity(selections.len());
- let mut clipboard_chars = clipboard_text.chars().cycle();
- for (selection, clipboard_selection) in
- selections.iter().zip(clipboard_selections.iter().cycle())
- {
- let to_insert =
- String::from_iter(clipboard_chars.by_ref().take(clipboard_selection.len));
+ for (i, selection) in selections.iter().enumerate() {
+ let to_insert;
+ let entire_line;
+ if let Some(clipboard_selection) = clipboard_selections.get(i) {
+ let end_offset = start_offset + clipboard_selection.len;
+ to_insert = &clipboard_text[start_offset..end_offset];
+ entire_line = clipboard_selection.is_entire_line;
+ start_offset = end_offset
+ } else {
+ to_insert = clipboard_text.as_str();
+ entire_line = all_selections_were_entire_line;
+ }
self.buffer.update(cx, |buffer, cx| {
let selection_start = selection.start.to_point(&*buffer);
@@ -1246,7 +1249,7 @@ impl Editor {
// selection was copied. If this selection is also currently empty,
// then paste the line before the current line of the buffer.
let new_selection_start = selection.end.bias_right(buffer);
- if selection_start == selection_end && clipboard_selection.is_entire_line {
+ if selection_start == selection_end && entire_line {
let line_start = Point::new(selection_start.row, 0);
buffer.edit(Some(line_start..line_start), to_insert, cx);
} else {
@@ -1281,8 +1284,7 @@ impl Editor {
pub fn move_left(&mut self, _: &MoveLeft, cx: &mut ViewContext<Self>) {
let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
- let app = cx.as_ref();
- let mut selections = self.selections(app).to_vec();
+ let mut selections = self.selections(cx).to_vec();
{
for selection in &mut selections {
let start = selection.start.to_display_point(&display_map, Bias::Left);
@@ -1305,7 +1307,7 @@ impl Editor {
pub fn select_left(&mut self, _: &SelectLeft, cx: &mut ViewContext<Self>) {
let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
- let mut selections = self.selections(cx.as_ref()).to_vec();
+ let mut selections = self.selections(cx).to_vec();
{
let buffer = self.buffer.read(cx);
for selection in &mut selections {
@@ -1321,7 +1323,7 @@ impl Editor {
pub fn move_right(&mut self, _: &MoveRight, cx: &mut ViewContext<Self>) {
let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
- let mut selections = self.selections(cx.as_ref()).to_vec();
+ let mut selections = self.selections(cx).to_vec();
{
for selection in &mut selections {
let start = selection.start.to_display_point(&display_map, Bias::Left);
@@ -1344,7 +1346,7 @@ impl Editor {
pub fn select_right(&mut self, _: &SelectRight, cx: &mut ViewContext<Self>) {
let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
- let mut selections = self.selections(cx.as_ref()).to_vec();
+ let mut selections = self.selections(cx).to_vec();
{
let app = cx.as_ref();
let buffer = self.buffer.read(app);
@@ -1364,7 +1366,7 @@ impl Editor {
if matches!(self.mode, EditorMode::SingleLine) {
cx.propagate_action();
} else {
- let mut selections = self.selections(cx.as_ref()).to_vec();
+ let mut selections = self.selections(cx).to_vec();
{
for selection in &mut selections {
let start = selection.start.to_display_point(&display_map, Bias::Left);
@@ -1387,7 +1389,7 @@ impl Editor {
pub fn select_up(&mut self, _: &SelectUp, cx: &mut ViewContext<Self>) {
let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
- let mut selections = self.selections(cx.as_ref()).to_vec();
+ let mut selections = self.selections(cx).to_vec();
{
let app = cx.as_ref();
let buffer = self.buffer.read(app);
@@ -1406,7 +1408,7 @@ impl Editor {
cx.propagate_action();
} else {
let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
- let mut selections = self.selections(cx.as_ref()).to_vec();
+ let mut selections = self.selections(cx).to_vec();
{
for selection in &mut selections {
let start = selection.start.to_display_point(&display_map, Bias::Left);
@@ -1490,8 +1492,26 @@ impl Editor {
cx: &mut ViewContext<Self>,
) {
self.start_transaction(cx);
- self.select_to_previous_word_boundary(&SelectToPreviousWordBoundary, cx);
- self.backspace(&Backspace, cx);
+ let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
+ let mut selections = self.selections(cx).to_vec();
+ {
+ let buffer = self.buffer.read(cx);
+ for selection in &mut selections {
+ let range = selection.point_range(buffer);
+ if range.start == range.end {
+ let head = selection.head().to_display_point(&display_map, Bias::Left);
+ let cursor = display_map.anchor_before(
+ movement::prev_word_boundary(&display_map, head).unwrap(),
+ Bias::Right,
+ );
+ selection.set_head(&buffer, cursor);
+ selection.goal = SelectionGoal::None;
+ }
+ }
+ }
+
+ self.update_selections(selections, true, cx);
+ self.insert(&Insert(String::new()), cx);
self.end_transaction(cx);
}
@@ -1542,8 +1562,26 @@ impl Editor {
cx: &mut ViewContext<Self>,
) {
self.start_transaction(cx);
- self.select_to_next_word_boundary(&SelectToNextWordBoundary, cx);
- self.delete(&Delete, cx);
+ let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
+ let mut selections = self.selections(cx).to_vec();
+ {
+ let buffer = self.buffer.read(cx);
+ for selection in &mut selections {
+ let range = selection.point_range(buffer);
+ if range.start == range.end {
+ let head = selection.head().to_display_point(&display_map, Bias::Left);
+ let cursor = display_map.anchor_before(
+ movement::next_word_boundary(&display_map, head).unwrap(),
+ Bias::Right,
+ );
+ selection.set_head(&buffer, cursor);
+ selection.goal = SelectionGoal::None;
+ }
+ }
+ }
+
+ self.update_selections(selections, true, cx);
+ self.insert(&Insert(String::new()), cx);
self.end_transaction(cx);
}
@@ -1661,7 +1699,7 @@ impl Editor {
}
pub fn select_to_beginning(&mut self, _: &SelectToBeginning, cx: &mut ViewContext<Self>) {
- let mut selection = self.selections(cx.as_ref()).last().unwrap().clone();
+ let mut selection = self.selections(cx).last().unwrap().clone();
selection.set_head(self.buffer.read(cx), Anchor::min());
self.update_selections(vec![selection], true, cx);
}
@@ -1680,7 +1718,7 @@ impl Editor {
}
pub fn select_to_end(&mut self, _: &SelectToEnd, cx: &mut ViewContext<Self>) {
- let mut selection = self.selections(cx.as_ref()).last().unwrap().clone();
+ let mut selection = self.selections(cx).last().unwrap().clone();
selection.set_head(self.buffer.read(cx), Anchor::max());
self.update_selections(vec![selection], true, cx);
}
@@ -1698,8 +1736,8 @@ impl Editor {
pub fn select_line(&mut self, _: &SelectLine, cx: &mut ViewContext<Self>) {
let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
- let buffer = self.buffer.read(cx);
let mut selections = self.selections(cx).to_vec();
+ let buffer = self.buffer.read(cx);
let max_point = buffer.max_point();
for selection in &mut selections {
let rows = selection.spanned_rows(true, &display_map).buffer_rows;
@@ -1715,12 +1753,12 @@ impl Editor {
_: &SplitSelectionIntoLines,
cx: &mut ViewContext<Self>,
) {
- let app = cx.as_ref();
- let buffer = self.buffer.read(app);
+ let selections = self.selections(cx);
+ let buffer = self.buffer.read(cx);
let mut to_unfold = Vec::new();
let mut new_selections = Vec::new();
- for selection in self.selections(app) {
+ for selection in selections.iter() {
let range = selection.point_range(buffer).sorted();
if range.start.row != range.end.row {
new_selections.push(Selection {
@@ -1860,14 +1898,14 @@ impl Editor {
_: &SelectLargerSyntaxNode,
cx: &mut ViewContext<Self>,
) {
+ let old_selections = self.selections(cx);
let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
let buffer = self.buffer.read(cx);
let mut stack = mem::take(&mut self.select_larger_syntax_node_stack);
let mut selected_larger_node = false;
- let old_selections = self.selections(cx).to_vec();
let mut new_selection_ranges = Vec::new();
- for selection in &old_selections {
+ for selection in old_selections.iter() {
let old_range = selection.start.to_offset(buffer)..selection.end.to_offset(buffer);
let mut new_range = old_range.clone();
while let Some(containing_range) = buffer.range_for_syntax_ancestor(new_range.clone()) {
@@ -1908,7 +1946,7 @@ impl Editor {
) {
let mut stack = mem::take(&mut self.select_larger_syntax_node_stack);
if let Some(selections) = stack.pop() {
- self.update_selections(selections, true, cx);
+ self.update_selections(selections.to_vec(), true, cx);
}
self.select_larger_syntax_node_stack = stack;
}
@@ -1918,8 +1956,8 @@ impl Editor {
_: &MoveToEnclosingBracket,
cx: &mut ViewContext<Self>,
) {
+ let mut selections = self.selections(cx).to_vec();
let buffer = self.buffer.read(cx.as_ref());
- let mut selections = self.selections(cx.as_ref()).to_vec();
for selection in &mut selections {
let selection_range = selection.offset_range(buffer);
if let Some((open_range, close_range)) =
@@ -2033,12 +2071,14 @@ impl Editor {
}
}
- fn selections<'a>(&self, cx: &'a AppContext) -> &'a [Selection] {
+ fn selections(&mut self, cx: &mut ViewContext<Self>) -> Arc<[Selection]> {
+ self.end_selection(cx);
let buffer = self.buffer.read(cx);
- &buffer
+ buffer
.selection_set(self.selection_set_id)
.unwrap()
.selections
+ .clone()
}
fn update_selections(
@@ -2112,8 +2152,9 @@ impl Editor {
pub fn fold(&mut self, _: &Fold, cx: &mut ViewContext<Self>) {
let mut fold_ranges = Vec::new();
+ let selections = self.selections(cx);
let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
- for selection in self.selections(cx) {
+ for selection in selections.iter() {
let range = selection.display_range(&display_map).sorted();
let buffer_start_row = range.start.to_buffer_point(&display_map, Bias::Left).row;
@@ -2134,10 +2175,10 @@ impl Editor {
}
pub fn unfold(&mut self, _: &Unfold, cx: &mut ViewContext<Self>) {
+ let selections = self.selections(cx);
let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
let buffer = self.buffer.read(cx);
- let ranges = self
- .selections(cx)
+ let ranges = selections
.iter()
.map(|s| {
let range = s.display_range(&display_map).sorted();
@@ -2195,9 +2236,9 @@ impl Editor {
}
pub fn fold_selected_ranges(&mut self, _: &FoldSelectedRanges, cx: &mut ViewContext<Self>) {
+ let selections = self.selections(cx);
let buffer = self.buffer.read(cx);
- let ranges = self
- .selections(cx.as_ref())
+ let ranges = selections
.iter()
.map(|s| s.point_range(buffer).sorted())
.collect();
@@ -3223,24 +3264,13 @@ mod tests {
);
});
- view.update(cx, |view, cx| {
- view.move_to_previous_word_boundary(&MoveToPreviousWordBoundary, cx);
- assert_eq!(
- view.selection_ranges(cx),
- &[
- DisplayPoint::new(0, 3)..DisplayPoint::new(0, 3),
- DisplayPoint::new(1, 0)..DisplayPoint::new(1, 0),
- ]
- );
- });
-
view.update(cx, |view, cx| {
view.move_to_previous_word_boundary(&MoveToPreviousWordBoundary, cx);
assert_eq!(
view.selection_ranges(cx),
&[
DisplayPoint::new(0, 0)..DisplayPoint::new(0, 0),
- DisplayPoint::new(0, 24)..DisplayPoint::new(0, 24),
+ DisplayPoint::new(1, 0)..DisplayPoint::new(1, 0),
]
);
});
@@ -3267,24 +3297,13 @@ mod tests {
);
});
- view.update(cx, |view, cx| {
- view.move_to_next_word_boundary(&MoveToNextWordBoundary, cx);
- assert_eq!(
- view.selection_ranges(cx),
- &[
- DisplayPoint::new(0, 4)..DisplayPoint::new(0, 4),
- DisplayPoint::new(1, 0)..DisplayPoint::new(1, 0),
- ]
- );
- });
-
view.update(cx, |view, cx| {
view.move_to_next_word_boundary(&MoveToNextWordBoundary, cx);
assert_eq!(
view.selection_ranges(cx),
&[
DisplayPoint::new(0, 7)..DisplayPoint::new(0, 7),
- DisplayPoint::new(2, 0)..DisplayPoint::new(2, 0),
+ DisplayPoint::new(1, 0)..DisplayPoint::new(1, 0),
]
);
});
@@ -3295,7 +3314,7 @@ mod tests {
view.selection_ranges(cx),
&[
DisplayPoint::new(0, 9)..DisplayPoint::new(0, 9),
- DisplayPoint::new(2, 2)..DisplayPoint::new(2, 2),
+ DisplayPoint::new(2, 3)..DisplayPoint::new(2, 3),
]
);
});
@@ -3307,7 +3326,7 @@ mod tests {
view.selection_ranges(cx),
&[
DisplayPoint::new(0, 10)..DisplayPoint::new(0, 9),
- DisplayPoint::new(2, 3)..DisplayPoint::new(2, 2),
+ DisplayPoint::new(2, 4)..DisplayPoint::new(2, 3),
]
);
});
@@ -3318,7 +3337,7 @@ mod tests {
view.selection_ranges(cx),
&[
DisplayPoint::new(0, 10)..DisplayPoint::new(0, 7),
- DisplayPoint::new(2, 3)..DisplayPoint::new(2, 0),
+ DisplayPoint::new(2, 4)..DisplayPoint::new(2, 2),
]
);
});
@@ -3329,37 +3348,7 @@ mod tests {
view.selection_ranges(cx),
&[
DisplayPoint::new(0, 10)..DisplayPoint::new(0, 9),
- DisplayPoint::new(2, 3)..DisplayPoint::new(2, 2),
- ]
- );
- });
-
- view.update(cx, |view, cx| {
- view.delete_to_next_word_boundary(&DeleteToNextWordBoundary, cx);
- assert_eq!(
- view.display_text(cx),
- "use std::s::{foo, bar}\n\n {az.qux()}"
- );
- assert_eq!(
- view.selection_ranges(cx),
- &[
- DisplayPoint::new(0, 10)..DisplayPoint::new(0, 10),
- DisplayPoint::new(2, 3)..DisplayPoint::new(2, 3),
- ]
- );
- });
-
- view.update(cx, |view, cx| {
- view.delete_to_previous_word_boundary(&DeleteToPreviousWordBoundary, cx);
- assert_eq!(
- view.display_text(cx),
- "use std::::{foo, bar}\n\n az.qux()}"
- );
- assert_eq!(
- view.selection_ranges(cx),
- &[
- DisplayPoint::new(0, 9)..DisplayPoint::new(0, 9),
- DisplayPoint::new(2, 2)..DisplayPoint::new(2, 2),
+ DisplayPoint::new(2, 4)..DisplayPoint::new(2, 3),
]
);
});
@@ -3415,11 +3404,52 @@ mod tests {
view.move_to_previous_word_boundary(&MoveToPreviousWordBoundary, cx);
assert_eq!(
view.selection_ranges(cx),
- &[DisplayPoint::new(1, 15)..DisplayPoint::new(1, 15)]
+ &[DisplayPoint::new(1, 14)..DisplayPoint::new(1, 14)]
);
});
}
+ #[gpui::test]
+ fn test_delete_to_word_boundary(cx: &mut gpui::MutableAppContext) {
+ let buffer = cx.add_model(|cx| Buffer::new(0, "one two three four", cx));
+ let settings = settings::test(&cx).1;
+ let (_, view) = cx.add_window(Default::default(), |cx| {
+ build_editor(buffer.clone(), settings, cx)
+ });
+
+ view.update(cx, |view, cx| {
+ view.select_display_ranges(
+ &[
+ // an empty selection - the preceding word fragment is deleted
+ DisplayPoint::new(0, 2)..DisplayPoint::new(0, 2),
+ // characters selected - they are deleted
+ DisplayPoint::new(0, 9)..DisplayPoint::new(0, 12),
+ ],
+ cx,
+ )
+ .unwrap();
+ view.delete_to_previous_word_boundary(&DeleteToPreviousWordBoundary, cx);
+ });
+
+ assert_eq!(buffer.read(cx).text(), "e two te four");
+
+ view.update(cx, |view, cx| {
+ view.select_display_ranges(
+ &[
+ // an empty selection - the following word fragment is deleted
+ DisplayPoint::new(0, 3)..DisplayPoint::new(0, 3),
+ // characters selected - they are deleted
+ DisplayPoint::new(0, 9)..DisplayPoint::new(0, 10),
+ ],
+ cx,
+ )
+ .unwrap();
+ view.delete_to_next_word_boundary(&DeleteToNextWordBoundary, cx);
+ });
+
+ assert_eq!(buffer.read(cx).text(), "e t te our");
+ }
+
#[gpui::test]
fn test_backspace(cx: &mut gpui::MutableAppContext) {
let buffer = cx.add_model(|cx| {
@@ -3685,7 +3715,7 @@ mod tests {
#[gpui::test]
fn test_clipboard(cx: &mut gpui::MutableAppContext) {
- let buffer = cx.add_model(|cx| Buffer::new(0, "one two three four five six ", cx));
+ let buffer = cx.add_model(|cx| Buffer::new(0, "oneโ
two three four five six ", cx));
let settings = settings::test(&cx).1;
let view = cx
.add_window(Default::default(), |cx| {
@@ -3695,7 +3725,7 @@ mod tests {
// Cut with three selections. Clipboard text is divided into three slices.
view.update(cx, |view, cx| {
- view.select_ranges(vec![0..4, 8..14, 19..24], false, cx);
+ view.select_ranges(vec![0..7, 11..17, 22..27], false, cx);
view.cut(&Cut, cx);
assert_eq!(view.display_text(cx), "two four six ");
});
@@ -3704,13 +3734,13 @@ mod tests {
view.update(cx, |view, cx| {
view.select_ranges(vec![4..4, 9..9, 13..13], false, cx);
view.paste(&Paste, cx);
- assert_eq!(view.display_text(cx), "two one four three six five ");
+ assert_eq!(view.display_text(cx), "two oneโ
four three six five ");
assert_eq!(
view.selection_ranges(cx),
&[
- DisplayPoint::new(0, 8)..DisplayPoint::new(0, 8),
- DisplayPoint::new(0, 19)..DisplayPoint::new(0, 19),
- DisplayPoint::new(0, 28)..DisplayPoint::new(0, 28)
+ DisplayPoint::new(0, 11)..DisplayPoint::new(0, 11),
+ DisplayPoint::new(0, 22)..DisplayPoint::new(0, 22),
+ DisplayPoint::new(0, 31)..DisplayPoint::new(0, 31)
]
);
});
@@ -3719,13 +3749,13 @@ mod tests {
// match the number of slices in the clipboard, the entire clipboard text
// is pasted at each cursor.
view.update(cx, |view, cx| {
- view.select_ranges(vec![0..0, 28..28], false, cx);
+ view.select_ranges(vec![0..0, 31..31], false, cx);
view.insert(&Insert("( ".into()), cx);
view.paste(&Paste, cx);
view.insert(&Insert(") ".into()), cx);
assert_eq!(
view.display_text(cx),
- "( one three five ) two one four three six five ( one three five ) "
+ "( oneโ
three five ) two oneโ
four three six five ( oneโ
three five ) "
);
});
@@ -3734,7 +3764,7 @@ mod tests {
view.insert(&Insert("123\n4567\n89\n".into()), cx);
assert_eq!(
view.display_text(cx),
- "123\n4567\n89\n( one three five ) two one four three six five ( one three five ) "
+ "123\n4567\n89\n( oneโ
three five ) two oneโ
four three six five ( oneโ
three five ) "
);
});
@@ -3752,7 +3782,7 @@ mod tests {
view.cut(&Cut, cx);
assert_eq!(
view.display_text(cx),
- "13\n9\n( one three five ) two one four three six five ( one three five ) "
+ "13\n9\n( oneโ
three five ) two oneโ
four three six five ( oneโ
three five ) "
);
});
@@ -3771,7 +3801,7 @@ mod tests {
view.paste(&Paste, cx);
assert_eq!(
view.display_text(cx),
- "123\n4567\n9\n( 8ne three five ) two one four three six five ( one three five ) "
+ "123\n4567\n9\n( 8neโ
three five ) two oneโ
four three six five ( oneโ
three five ) "
);
assert_eq!(
view.selection_ranges(cx),
@@ -3805,7 +3835,7 @@ mod tests {
view.paste(&Paste, cx);
assert_eq!(
view.display_text(cx),
- "123\n123\n123\n67\n123\n9\n( 8ne three five ) two one four three six five ( one three five ) "
+ "123\n123\n123\n67\n123\n9\n( 8neโ
three five ) two oneโ
four three six five ( oneโ
three five ) "
);
assert_eq!(
view.selection_ranges(cx),
@@ -101,7 +101,10 @@ pub fn line_end(map: &DisplayMapSnapshot, point: DisplayPoint) -> Result<Display
Ok(map.clip_point(line_end, Bias::Left))
}
-pub fn prev_word_boundary(map: &DisplayMapSnapshot, point: DisplayPoint) -> Result<DisplayPoint> {
+pub fn prev_word_boundary(
+ map: &DisplayMapSnapshot,
+ mut point: DisplayPoint,
+) -> Result<DisplayPoint> {
let mut line_start = 0;
if point.row() > 0 {
if let Some(indent) = map.soft_wrap_indent(point.row() - 1) {
@@ -111,39 +114,52 @@ pub fn prev_word_boundary(map: &DisplayMapSnapshot, point: DisplayPoint) -> Resu
if point.column() == line_start {
if point.row() == 0 {
- Ok(DisplayPoint::new(0, 0))
+ return Ok(DisplayPoint::new(0, 0));
} else {
let row = point.row() - 1;
- Ok(map.clip_point(DisplayPoint::new(row, map.line_len(row)), Bias::Left))
+ point = map.clip_point(DisplayPoint::new(row, map.line_len(row)), Bias::Left);
}
- } else {
- let mut boundary = DisplayPoint::new(point.row(), 0);
- let mut column = 0;
- let mut prev_c = None;
- for c in map.chars_at(DisplayPoint::new(point.row(), 0)) {
- if column >= point.column() {
- break;
- }
+ }
- if prev_c.is_none() || char_kind(prev_c.unwrap()) != char_kind(c) {
- *boundary.column_mut() = column;
- }
+ let mut boundary = DisplayPoint::new(point.row(), 0);
+ let mut column = 0;
+ let mut prev_char_kind = CharKind::Newline;
+ for c in map.chars_at(DisplayPoint::new(point.row(), 0)) {
+ if column >= point.column() {
+ break;
+ }
- prev_c = Some(c);
- column += c.len_utf8() as u32;
+ let char_kind = char_kind(c);
+ if char_kind != prev_char_kind
+ && char_kind != CharKind::Whitespace
+ && char_kind != CharKind::Newline
+ {
+ *boundary.column_mut() = column;
}
- Ok(boundary)
+
+ prev_char_kind = char_kind;
+ column += c.len_utf8() as u32;
}
+ Ok(boundary)
}
pub fn next_word_boundary(
map: &DisplayMapSnapshot,
mut point: DisplayPoint,
) -> Result<DisplayPoint> {
- let mut prev_c = None;
+ let mut prev_char_kind = None;
for c in map.chars_at(point) {
- if prev_c.is_some() && (c == '\n' || char_kind(prev_c.unwrap()) != char_kind(c)) {
- break;
+ let char_kind = char_kind(c);
+ if let Some(prev_char_kind) = prev_char_kind {
+ if c == '\n' {
+ break;
+ }
+ if prev_char_kind != char_kind
+ && prev_char_kind != CharKind::Whitespace
+ && prev_char_kind != CharKind::Newline
+ {
+ break;
+ }
}
if c == '\n' {
@@ -152,7 +168,7 @@ pub fn next_word_boundary(
} else {
*point.column_mut() += c.len_utf8() as u32;
}
- prev_c = Some(c);
+ prev_char_kind = Some(char_kind);
}
Ok(point)
}
@@ -192,7 +208,7 @@ mod tests {
.unwrap();
let font_size = 14.0;
- let buffer = cx.add_model(|cx| Buffer::new(0, "a bcฮ defฮณ", cx));
+ let buffer = cx.add_model(|cx| Buffer::new(0, "a bcฮ defฮณ hiโjk", cx));
let display_map =
cx.add_model(|cx| DisplayMap::new(buffer, tab_size, font_id, font_size, None, cx));
let snapshot = display_map.update(cx, |map, cx| map.snapshot(cx));
@@ -202,7 +218,7 @@ mod tests {
);
assert_eq!(
prev_word_boundary(&snapshot, DisplayPoint::new(0, 7)).unwrap(),
- DisplayPoint::new(0, 6)
+ DisplayPoint::new(0, 2)
);
assert_eq!(
prev_word_boundary(&snapshot, DisplayPoint::new(0, 6)).unwrap(),
@@ -210,7 +226,7 @@ mod tests {
);
assert_eq!(
prev_word_boundary(&snapshot, DisplayPoint::new(0, 2)).unwrap(),
- DisplayPoint::new(0, 1)
+ DisplayPoint::new(0, 0)
);
assert_eq!(
prev_word_boundary(&snapshot, DisplayPoint::new(0, 1)).unwrap(),
@@ -223,7 +239,7 @@ mod tests {
);
assert_eq!(
next_word_boundary(&snapshot, DisplayPoint::new(0, 1)).unwrap(),
- DisplayPoint::new(0, 2)
+ DisplayPoint::new(0, 6)
);
assert_eq!(
next_word_boundary(&snapshot, DisplayPoint::new(0, 2)).unwrap(),
@@ -231,7 +247,7 @@ mod tests {
);
assert_eq!(
next_word_boundary(&snapshot, DisplayPoint::new(0, 6)).unwrap(),
- DisplayPoint::new(0, 7)
+ DisplayPoint::new(0, 12)
);
assert_eq!(
next_word_boundary(&snapshot, DisplayPoint::new(0, 7)).unwrap(),
@@ -438,29 +438,37 @@ mod tests {
use crate::{
editor::{self, Insert},
fs::FakeFs,
- test::{temp_tree, test_app_state},
+ test::test_app_state,
workspace::Workspace,
};
use serde_json::json;
- use std::fs;
- use tempdir::TempDir;
+ use std::path::PathBuf;
#[gpui::test]
async fn test_matching_paths(mut cx: gpui::TestAppContext) {
- let tmp_dir = TempDir::new("example").unwrap();
- fs::create_dir(tmp_dir.path().join("a")).unwrap();
- fs::write(tmp_dir.path().join("a/banana"), "banana").unwrap();
- fs::write(tmp_dir.path().join("a/bandana"), "bandana").unwrap();
+ let app_state = cx.update(test_app_state);
+ app_state
+ .fs
+ .as_fake()
+ .insert_tree(
+ "/root",
+ json!({
+ "a": {
+ "banana": "",
+ "bandana": "",
+ }
+ }),
+ )
+ .await;
cx.update(|cx| {
super::init(cx);
editor::init(cx);
});
- let app_state = cx.update(test_app_state);
let (window_id, workspace) = cx.add_window(|cx| Workspace::new(&app_state, cx));
workspace
.update(&mut cx, |workspace, cx| {
- workspace.add_worktree(tmp_dir.path(), cx)
+ workspace.add_worktree(Path::new("/root"), cx)
})
.await
.unwrap();
@@ -572,17 +580,17 @@ mod tests {
#[gpui::test]
async fn test_single_file_worktrees(mut cx: gpui::TestAppContext) {
- let temp_dir = TempDir::new("test-single-file-worktrees").unwrap();
- let dir_path = temp_dir.path().join("the-parent-dir");
- let file_path = dir_path.join("the-file");
- fs::create_dir(&dir_path).unwrap();
- fs::write(&file_path, "").unwrap();
-
let app_state = cx.update(test_app_state);
+ app_state
+ .fs
+ .as_fake()
+ .insert_tree("/root", json!({ "the-parent-dir": { "the-file": "" } }))
+ .await;
+
let (_, workspace) = cx.add_window(|cx| Workspace::new(&app_state, cx));
workspace
.update(&mut cx, |workspace, cx| {
- workspace.add_worktree(&file_path, cx)
+ workspace.add_worktree(Path::new("/root/the-parent-dir/the-file"), cx)
})
.await
.unwrap();
@@ -620,18 +628,25 @@ mod tests {
#[gpui::test(retries = 5)]
async fn test_multiple_matches_with_same_relative_path(mut cx: gpui::TestAppContext) {
- let tmp_dir = temp_tree(json!({
- "dir1": { "a.txt": "" },
- "dir2": { "a.txt": "" }
- }));
-
let app_state = cx.update(test_app_state);
+ app_state
+ .fs
+ .as_fake()
+ .insert_tree(
+ "/root",
+ json!({
+ "dir1": { "a.txt": "" },
+ "dir2": { "a.txt": "" }
+ }),
+ )
+ .await;
+
let (_, workspace) = cx.add_window(|cx| Workspace::new(&app_state, cx));
workspace
.update(&mut cx, |workspace, cx| {
workspace.open_paths(
- &[tmp_dir.path().join("dir1"), tmp_dir.path().join("dir2")],
+ &[PathBuf::from("/root/dir1"), PathBuf::from("/root/dir2")],
cx,
)
})
@@ -29,6 +29,8 @@ pub trait Fs: Send + Sync {
latency: Duration,
) -> Pin<Box<dyn Send + Stream<Item = Vec<fsevent::Event>>>>;
fn is_fake(&self) -> bool;
+ #[cfg(any(test, feature = "test-support"))]
+ fn as_fake(&self) -> &FakeFs;
}
#[derive(Clone, Debug)]
@@ -125,6 +127,11 @@ impl Fs for RealFs {
fn is_fake(&self) -> bool {
false
}
+
+ #[cfg(any(test, feature = "test-support"))]
+ fn as_fake(&self) -> &FakeFs {
+ panic!("called `RealFs::as_fake`")
+ }
}
#[derive(Clone, Debug)]
@@ -413,4 +420,9 @@ impl Fs for FakeFs {
fn is_fake(&self) -> bool {
true
}
+
+ #[cfg(any(test, feature = "test-support"))]
+ fn as_fake(&self) -> &FakeFs {
+ self
+ }
}
@@ -55,6 +55,8 @@ pub struct Client {
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
+ #[error("upgrade required")]
+ UpgradeRequired,
#[error("unauthorized")]
Unauthorized,
#[error("{0}")]
@@ -68,8 +70,10 @@ pub enum EstablishConnectionError {
impl From<WebsocketError> for EstablishConnectionError {
fn from(error: WebsocketError) -> Self {
if let WebsocketError::Http(response) = &error {
- if response.status() == StatusCode::UNAUTHORIZED {
- return EstablishConnectionError::Unauthorized;
+ match response.status() {
+ StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
+ StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
+ _ => {}
}
}
EstablishConnectionError::Other(error.into())
@@ -85,6 +89,7 @@ impl EstablishConnectionError {
#[derive(Copy, Clone, Debug)]
pub enum Status {
SignedOut,
+ UpgradeRequired,
Authenticating,
Connecting,
ConnectionError,
@@ -227,7 +232,7 @@ impl Client {
}
}));
}
- Status::SignedOut => {
+ Status::SignedOut | Status::UpgradeRequired => {
state._maintain_connection.take();
}
_ => {}
@@ -346,6 +351,7 @@ impl Client {
| Status::Reconnecting { .. }
| Status::Authenticating
| Status::Reauthenticating => return Ok(()),
+ Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
};
if was_disconnected {
@@ -388,22 +394,25 @@ impl Client {
self.set_connection(conn, cx).await;
Ok(())
}
- Err(err) => {
- if matches!(err, EstablishConnectionError::Unauthorized) {
- self.state.write().credentials.take();
- if used_keychain {
- cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
- self.set_status(Status::SignedOut, cx);
- self.authenticate_and_connect(cx).await
- } else {
- self.set_status(Status::ConnectionError, cx);
- Err(err)?
- }
+ Err(EstablishConnectionError::Unauthorized) => {
+ self.state.write().credentials.take();
+ if used_keychain {
+ cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
+ self.set_status(Status::SignedOut, cx);
+ self.authenticate_and_connect(cx).await
} else {
self.set_status(Status::ConnectionError, cx);
- Err(err)?
+ Err(EstablishConnectionError::Unauthorized)?
}
}
+ Err(EstablishConnectionError::UpgradeRequired) => {
+ self.set_status(Status::UpgradeRequired, cx);
+ Err(EstablishConnectionError::UpgradeRequired)?
+ }
+ Err(error) => {
+ self.set_status(Status::ConnectionError, cx);
+ Err(error)?
+ }
}
}
@@ -489,10 +498,12 @@ impl Client {
credentials: &Credentials,
cx: &AsyncAppContext,
) -> Task<Result<Connection, EstablishConnectionError>> {
- let request = Request::builder().header(
- "Authorization",
- format!("{} {}", credentials.user_id, credentials.access_token),
- );
+ let request = Request::builder()
+ .header(
+ "Authorization",
+ format!("{} {}", credentials.user_id, credentials.access_token),
+ )
+ .header("X-Zed-Protocol-Version", zrpc::PROTOCOL_VERSION);
cx.background().spawn(async move {
if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
let stream = smol::net::TcpStream::connect(host).await?;
@@ -1,7 +1,7 @@
use crate::{
assets::Assets,
channel::ChannelList,
- fs::RealFs,
+ fs::FakeFs,
http::{HttpClient, Request, Response, ServerResponse},
language::LanguageRegistry,
rpc::{self, Client, Credentials, EstablishConnectionError},
@@ -177,7 +177,7 @@ pub fn test_app_state(cx: &mut MutableAppContext) -> Arc<AppState> {
channel_list: cx.add_model(|cx| ChannelList::new(user_store.clone(), rpc.clone(), cx)),
rpc,
user_store,
- fs: Arc::new(RealFs),
+ fs: Arc::new(FakeFs::new()),
})
}
@@ -54,6 +54,7 @@ pub struct Titlebar {
pub offline_icon: OfflineIcon,
pub icon_color: Color,
pub avatar: ImageStyle,
+ pub outdated_warning: ContainedText,
}
#[derive(Clone, Deserialize)]
@@ -169,7 +170,7 @@ pub struct Selector {
pub active_item: ContainedLabel,
}
-#[derive(Debug, Deserialize)]
+#[derive(Clone, Debug, Deserialize)]
pub struct ContainedText {
#[serde(flatten)]
pub container: ContainerStyle,
@@ -1041,6 +1041,16 @@ impl Workspace {
.with_style(theme.workspace.titlebar.offline_icon.container)
.boxed(),
),
+ rpc::Status::UpgradeRequired => Some(
+ Label::new(
+ "Please update Zed to collaborate".to_string(),
+ theme.workspace.titlebar.outdated_warning.text.clone(),
+ )
+ .contained()
+ .with_style(theme.workspace.titlebar.outdated_warning.container)
+ .aligned()
+ .boxed(),
+ ),
_ => None,
}
}
@@ -1195,11 +1205,9 @@ mod tests {
editor::{Editor, Insert},
fs::FakeFs,
test::{temp_tree, test_app_state},
- worktree::WorktreeHandle,
};
use serde_json::json;
- use std::{collections::HashSet, fs};
- use tempdir::TempDir;
+ use std::collections::HashSet;
#[gpui::test]
async fn test_open_paths_action(mut cx: gpui::TestAppContext) {
@@ -1268,20 +1276,26 @@ mod tests {
#[gpui::test]
async fn test_open_entry(mut cx: gpui::TestAppContext) {
- let dir = temp_tree(json!({
- "a": {
- "file1": "contents 1",
- "file2": "contents 2",
- "file3": "contents 3",
- },
- }));
-
let app_state = cx.update(test_app_state);
+ app_state
+ .fs
+ .as_fake()
+ .insert_tree(
+ "/root",
+ json!({
+ "a": {
+ "file1": "contents 1",
+ "file2": "contents 2",
+ "file3": "contents 3",
+ },
+ }),
+ )
+ .await;
let (_, workspace) = cx.add_window(|cx| Workspace::new(&app_state, cx));
workspace
.update(&mut cx, |workspace, cx| {
- workspace.add_worktree(dir.path(), cx)
+ workspace.add_worktree(Path::new("/root"), cx)
})
.await
.unwrap();
@@ -1445,28 +1459,30 @@ mod tests {
#[gpui::test]
async fn test_save_conflicting_item(mut cx: gpui::TestAppContext) {
- let dir = temp_tree(json!({
- "a.txt": "",
- }));
-
let app_state = cx.update(test_app_state);
+ app_state
+ .fs
+ .as_fake()
+ .insert_tree(
+ "/root",
+ json!({
+ "a.txt": "",
+ }),
+ )
+ .await;
+
let (window_id, workspace) = cx.add_window(|cx| Workspace::new(&app_state, cx));
workspace
.update(&mut cx, |workspace, cx| {
- workspace.add_worktree(dir.path(), cx)
+ workspace.add_worktree(Path::new("/root"), cx)
})
.await
.unwrap();
- let tree = cx.read(|cx| {
- let mut trees = workspace.read(cx).worktrees().iter();
- trees.next().unwrap().clone()
- });
- tree.flush_fs_events(&cx).await;
// Open a file within an existing worktree.
cx.update(|cx| {
workspace.update(cx, |view, cx| {
- view.open_paths(&[dir.path().join("a.txt")], cx)
+ view.open_paths(&[PathBuf::from("/root/a.txt")], cx)
})
})
.await;
@@ -1477,7 +1493,12 @@ mod tests {
});
cx.update(|cx| editor.update(cx, |editor, cx| editor.insert(&Insert("x".into()), cx)));
- fs::write(dir.path().join("a.txt"), "changed").unwrap();
+ app_state
+ .fs
+ .as_fake()
+ .insert_file("/root/a.txt", "changed".to_string())
+ .await
+ .unwrap();
editor
.condition(&cx, |editor, cx| editor.has_conflict(cx))
.await;
@@ -1493,12 +1514,12 @@ mod tests {
#[gpui::test]
async fn test_open_and_save_new_file(mut cx: gpui::TestAppContext) {
- let dir = TempDir::new("test-new-file").unwrap();
let app_state = cx.update(test_app_state);
+ app_state.fs.as_fake().insert_dir("/root").await.unwrap();
let (_, workspace) = cx.add_window(|cx| Workspace::new(&app_state, cx));
workspace
.update(&mut cx, |workspace, cx| {
- workspace.add_worktree(dir.path(), cx)
+ workspace.add_worktree(Path::new("/root"), cx)
})
.await
.unwrap();
@@ -1511,7 +1532,6 @@ mod tests {
.unwrap()
.clone()
});
- tree.flush_fs_events(&cx).await;
// Create a new untitled buffer
let editor = workspace.update(&mut cx, |workspace, cx| {
@@ -1537,7 +1557,7 @@ mod tests {
workspace.save_active_item(&Save, cx)
});
cx.simulate_new_path_selection(|parent_dir| {
- assert_eq!(parent_dir, dir.path());
+ assert_eq!(parent_dir, Path::new("/root"));
Some(parent_dir.join("the-new-name.rs"))
});
cx.read(|cx| {
@@ -1598,8 +1618,8 @@ mod tests {
async fn test_setting_language_when_saving_as_single_file_worktree(
mut cx: gpui::TestAppContext,
) {
- let dir = TempDir::new("test-new-file").unwrap();
let app_state = cx.update(test_app_state);
+ app_state.fs.as_fake().insert_dir("/root").await.unwrap();
let (_, workspace) = cx.add_window(|cx| Workspace::new(&app_state, cx));
// Create a new untitled buffer
@@ -1623,7 +1643,7 @@ mod tests {
workspace.update(&mut cx, |workspace, cx| {
workspace.save_active_item(&Save, cx)
});
- cx.simulate_new_path_selection(|_| Some(dir.path().join("the-new-name.rs")));
+ cx.simulate_new_path_selection(|_| Some(PathBuf::from("/root/the-new-name.rs")));
editor
.condition(&cx, |editor, cx| !editor.is_dirty(cx))
@@ -1640,7 +1660,7 @@ mod tests {
cx.update(init);
let app_state = cx.update(test_app_state);
- cx.dispatch_global_action(OpenNew(app_state));
+ cx.dispatch_global_action(OpenNew(app_state.clone()));
let window_id = *cx.window_ids().first().unwrap();
let workspace = cx.root_view::<Workspace>(window_id).unwrap();
let editor = workspace.update(&mut cx, |workspace, cx| {
@@ -1660,10 +1680,8 @@ mod tests {
workspace.save_active_item(&Save, cx)
});
- let dir = TempDir::new("test-new-empty-workspace").unwrap();
- cx.simulate_new_path_selection(|_| {
- Some(dir.path().canonicalize().unwrap().join("the-new-name"))
- });
+ app_state.fs.as_fake().insert_dir("/root").await.unwrap();
+ cx.simulate_new_path_selection(|_| Some(PathBuf::from("/root/the-new-name")));
editor
.condition(&cx, |editor, cx| editor.title(cx) == "the-new-name")
@@ -1676,20 +1694,26 @@ mod tests {
#[gpui::test]
async fn test_pane_actions(mut cx: gpui::TestAppContext) {
cx.update(|cx| pane::init(cx));
-
- let dir = temp_tree(json!({
- "a": {
- "file1": "contents 1",
- "file2": "contents 2",
- "file3": "contents 3",
- },
- }));
-
let app_state = cx.update(test_app_state);
+ app_state
+ .fs
+ .as_fake()
+ .insert_tree(
+ "/root",
+ json!({
+ "a": {
+ "file1": "contents 1",
+ "file2": "contents 2",
+ "file3": "contents 3",
+ },
+ }),
+ )
+ .await;
+
let (window_id, workspace) = cx.add_window(|cx| Workspace::new(&app_state, cx));
workspace
.update(&mut cx, |workspace, cx| {
- workspace.add_worktree(dir.path(), cx)
+ workspace.add_worktree(Path::new("/root"), cx)
})
.await
.unwrap();
@@ -161,6 +161,7 @@ impl Worktree {
entries_by_id_edits.push(Edit::Insert(PathEntry {
id: entry.id,
path: entry.path.clone(),
+ is_ignored: entry.is_ignored,
scan_id: 0,
}));
entries_by_path_edits.push(Edit::Insert(entry));
@@ -1083,7 +1084,7 @@ impl LocalWorktree {
async move {
let mut prev_snapshot = snapshot;
while let Ok(snapshot) = snapshots_to_send_rx.recv().await {
- let message = snapshot.build_update(&prev_snapshot, remote_id);
+ let message = snapshot.build_update(&prev_snapshot, remote_id, false);
match rpc.send(message).await {
Ok(()) => prev_snapshot = snapshot,
Err(err) => log::error!("error sending snapshot diff {}", err),
@@ -1140,6 +1141,7 @@ impl LocalWorktree {
let entries = snapshot
.entries_by_path
.cursor::<(), ()>()
+ .filter(|e| !e.is_ignored)
.map(Into::into)
.collect();
proto::ShareWorktree {
@@ -1373,11 +1375,24 @@ impl Snapshot {
self.id
}
- pub fn build_update(&self, other: &Self, worktree_id: u64) -> proto::UpdateWorktree {
+ pub fn build_update(
+ &self,
+ other: &Self,
+ worktree_id: u64,
+ include_ignored: bool,
+ ) -> proto::UpdateWorktree {
let mut updated_entries = Vec::new();
let mut removed_entries = Vec::new();
- let mut self_entries = self.entries_by_id.cursor::<(), ()>().peekable();
- let mut other_entries = other.entries_by_id.cursor::<(), ()>().peekable();
+ let mut self_entries = self
+ .entries_by_id
+ .cursor::<(), ()>()
+ .filter(|e| include_ignored || !e.is_ignored)
+ .peekable();
+ let mut other_entries = other
+ .entries_by_id
+ .cursor::<(), ()>()
+ .filter(|e| include_ignored || !e.is_ignored)
+ .peekable();
loop {
match (self_entries.peek(), other_entries.peek()) {
(Some(self_entry), Some(other_entry)) => match self_entry.id.cmp(&other_entry.id) {
@@ -1443,6 +1458,7 @@ impl Snapshot {
entries_by_id_edits.push(Edit::Insert(PathEntry {
id: entry.id,
path: entry.path.clone(),
+ is_ignored: entry.is_ignored,
scan_id,
}));
entries_by_path_edits.push(Edit::Insert(entry));
@@ -1526,6 +1542,7 @@ impl Snapshot {
PathEntry {
id: entry.id,
path: entry.path.clone(),
+ is_ignored: entry.is_ignored,
scan_id: self.scan_id,
},
&(),
@@ -1561,6 +1578,7 @@ impl Snapshot {
entries_by_id_edits.push(Edit::Insert(PathEntry {
id: entry.id,
path: entry.path.clone(),
+ is_ignored: entry.is_ignored,
scan_id: self.scan_id,
}));
entries_by_path_edits.push(Edit::Insert(entry));
@@ -1933,6 +1951,7 @@ impl sum_tree::Summary for EntrySummary {
struct PathEntry {
id: usize,
path: Arc<Path>,
+ is_ignored: bool,
scan_id: usize,
}
@@ -2412,7 +2431,8 @@ impl BackgroundScanner {
ignore_stack = ignore_stack.append(job.path.clone(), ignore.clone());
}
- let mut edits = Vec::new();
+ let mut entries_by_id_edits = Vec::new();
+ let mut entries_by_path_edits = Vec::new();
for mut entry in snapshot.child_entries(&job.path).cloned() {
let was_ignored = entry.is_ignored;
entry.is_ignored = ignore_stack.is_path_ignored(&entry.path, entry.is_dir());
@@ -2433,10 +2453,17 @@ impl BackgroundScanner {
}
if entry.is_ignored != was_ignored {
- edits.push(Edit::Insert(entry));
+ let mut path_entry = snapshot.entries_by_id.get(&entry.id, &()).unwrap().clone();
+ path_entry.scan_id = snapshot.scan_id;
+ path_entry.is_ignored = entry.is_ignored;
+ entries_by_id_edits.push(Edit::Insert(path_entry));
+ entries_by_path_edits.push(Edit::Insert(entry));
}
}
- self.snapshot.lock().entries_by_path.edit(edits, &());
+
+ let mut snapshot = self.snapshot.lock();
+ snapshot.entries_by_path.edit(entries_by_path_edits, &());
+ snapshot.entries_by_id.edit(entries_by_id_edits, &());
}
}
@@ -3000,10 +3027,10 @@ mod tests {
// Update the remote worktree. Check that it becomes consistent with the
// local worktree.
remote.update(&mut cx, |remote, cx| {
- let update_message = tree
- .read(cx)
- .snapshot()
- .build_update(&initial_snapshot, worktree_id);
+ let update_message =
+ tree.read(cx)
+ .snapshot()
+ .build_update(&initial_snapshot, worktree_id, true);
remote
.as_remote_mut()
.unwrap()
@@ -3175,6 +3202,7 @@ mod tests {
scanner.snapshot().check_invariants();
let mut events = Vec::new();
+ let mut snapshots = Vec::new();
let mut mutations_len = operations;
while mutations_len > 1 {
if !events.is_empty() && rng.gen_bool(0.4) {
@@ -3187,6 +3215,10 @@ mod tests {
events.extend(randomly_mutate_tree(root_dir.path(), 0.6, &mut rng).unwrap());
mutations_len -= 1;
}
+
+ if rng.gen_bool(0.2) {
+ snapshots.push(scanner.snapshot());
+ }
}
log::info!("Quiescing: {:#?}", events);
smol::block_on(scanner.process_events(events));
@@ -3200,7 +3232,40 @@ mod tests {
scanner.executor.clone(),
);
smol::block_on(new_scanner.scan_dirs()).unwrap();
- assert_eq!(scanner.snapshot().to_vec(), new_scanner.snapshot().to_vec());
+ assert_eq!(
+ scanner.snapshot().to_vec(true),
+ new_scanner.snapshot().to_vec(true)
+ );
+
+ for mut prev_snapshot in snapshots {
+ let include_ignored = rng.gen::<bool>();
+ if !include_ignored {
+ let mut entries_by_path_edits = Vec::new();
+ let mut entries_by_id_edits = Vec::new();
+ for entry in prev_snapshot
+ .entries_by_id
+ .cursor::<(), ()>()
+ .filter(|e| e.is_ignored)
+ {
+ entries_by_path_edits.push(Edit::Remove(PathKey(entry.path.clone())));
+ entries_by_id_edits.push(Edit::Remove(entry.id));
+ }
+
+ prev_snapshot
+ .entries_by_path
+ .edit(entries_by_path_edits, &());
+ prev_snapshot.entries_by_id.edit(entries_by_id_edits, &());
+ }
+
+ let update = scanner
+ .snapshot()
+ .build_update(&prev_snapshot, 0, include_ignored);
+ prev_snapshot.apply_update(update).unwrap();
+ assert_eq!(
+ prev_snapshot.to_vec(true),
+ scanner.snapshot().to_vec(include_ignored)
+ );
+ }
}
fn randomly_mutate_tree(
@@ -3390,10 +3455,12 @@ mod tests {
}
}
- fn to_vec(&self) -> Vec<(&Path, u64, bool)> {
+ fn to_vec(&self, include_ignored: bool) -> Vec<(&Path, u64, bool)> {
let mut paths = Vec::new();
for entry in self.entries_by_path.cursor::<(), ()>() {
- paths.push((entry.path.as_ref(), entry.inode, entry.is_ignored));
+ if include_ignored || !entry.is_ignored {
+ paths.push((entry.path.as_ref(), entry.inode, entry.is_ignored));
+ }
}
paths.sort_by(|a, b| a.0.cmp(&b.0));
paths
@@ -20,6 +20,7 @@ prost = "0.7"
rand = "0.8"
rsa = "0.4"
serde = { version = "1", features = ["derive"] }
+zstd = "0.9"
[build-dependencies]
prost-build = { git = "https://github.com/tokio-rs/prost", rev = "6cf97ea422b09d98de34643c4dda2d4f8b7e23e6" }
@@ -4,3 +4,5 @@ mod peer;
pub mod proto;
pub use conn::Connection;
pub use peer::*;
+
+pub const PROTOCOL_VERSION: u32 = 1;
@@ -87,7 +87,7 @@ pub struct Peer {
struct ConnectionState {
outgoing_tx: mpsc::Sender<proto::Envelope>,
next_message_id: Arc<AtomicU32>,
- response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
+ response_channels: Arc<Mutex<Option<HashMap<u32, mpsc::Sender<proto::Envelope>>>>>,
}
impl Peer {
@@ -115,7 +115,7 @@ impl Peer {
let connection_state = ConnectionState {
outgoing_tx,
next_message_id: Default::default(),
- response_channels: Default::default(),
+ response_channels: Arc::new(Mutex::new(Some(Default::default()))),
};
let mut writer = MessageStream::new(connection.tx);
let mut reader = MessageStream::new(connection.rx);
@@ -123,7 +123,7 @@ impl Peer {
let this = self.clone();
let response_channels = connection_state.response_channels.clone();
let handle_io = async move {
- loop {
+ let result = 'outer: loop {
let read_message = reader.read_message().fuse();
futures::pin_mut!(read_message);
loop {
@@ -131,7 +131,7 @@ impl Peer {
incoming = read_message => match incoming {
Ok(incoming) => {
if let Some(responding_to) = incoming.responding_to {
- let channel = response_channels.lock().await.remove(&responding_to);
+ let channel = response_channels.lock().await.as_mut().unwrap().remove(&responding_to);
if let Some(mut tx) = channel {
tx.send(incoming).await.ok();
} else {
@@ -140,9 +140,7 @@ impl Peer {
} else {
if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
if incoming_tx.send(envelope).await.is_err() {
- response_channels.lock().await.clear();
- this.connections.write().await.remove(&connection_id);
- return Ok(())
+ break 'outer Ok(())
}
} else {
log::error!("unable to construct a typed envelope");
@@ -152,28 +150,24 @@ impl Peer {
break;
}
Err(error) => {
- response_channels.lock().await.clear();
- this.connections.write().await.remove(&connection_id);
- Err(error).context("received invalid RPC message")?;
+ break 'outer Err(error).context("received invalid RPC message")
}
},
outgoing = outgoing_rx.recv().fuse() => match outgoing {
Some(outgoing) => {
if let Err(result) = writer.write_message(&outgoing).await {
- response_channels.lock().await.clear();
- this.connections.write().await.remove(&connection_id);
- Err(result).context("failed to write RPC message")?;
+ break 'outer Err(result).context("failed to write RPC message")
}
}
- None => {
- response_channels.lock().await.clear();
- this.connections.write().await.remove(&connection_id);
- return Ok(())
- }
+ None => break 'outer Ok(()),
}
}
}
- }
+ };
+
+ response_channels.lock().await.take();
+ this.connections.write().await.remove(&connection_id);
+ result
};
self.connections
@@ -226,6 +220,8 @@ impl Peer {
.response_channels
.lock()
.await
+ .as_mut()
+ .ok_or_else(|| anyhow!("connection was closed"))?
.insert(message_id, tx);
connection
.outgoing_tx
@@ -520,8 +516,7 @@ mod tests {
#[test]
fn test_io_error() {
smol::block_on(async move {
- let (client_conn, server_conn, _) = Connection::in_memory();
- drop(server_conn);
+ let (client_conn, mut server_conn, _) = Connection::in_memory();
let client = Peer::new();
let (connection_id, io_handler, mut incoming) =
@@ -529,11 +524,14 @@ mod tests {
smol::spawn(io_handler).detach();
smol::spawn(async move { incoming.next().await }).detach();
- let err = client
- .request(connection_id, proto::Ping {})
- .await
- .unwrap_err();
- assert_eq!(err.to_string(), "connection was closed");
+ let response = smol::spawn(client.request(connection_id, proto::Ping {}));
+ let _request = server_conn.rx.next().await.unwrap().unwrap();
+
+ drop(server_conn);
+ assert_eq!(
+ response.await.unwrap_err().to_string(),
+ "connection was closed"
+ );
});
}
}
@@ -192,11 +192,15 @@ entity_messages!(channel_id, ChannelMessageSent);
/// A stream of protobuf messages.
pub struct MessageStream<S> {
stream: S,
+ encoding_buffer: Vec<u8>,
}
impl<S> MessageStream<S> {
pub fn new(stream: S) -> Self {
- Self { stream }
+ Self {
+ stream,
+ encoding_buffer: Vec::new(),
+ }
}
pub fn inner_mut(&mut self) -> &mut S {
@@ -210,10 +214,12 @@ where
{
/// Write a given protobuf message to the stream.
pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
- let mut buffer = Vec::with_capacity(message.encoded_len());
+ self.encoding_buffer.resize(message.encoded_len(), 0);
+ self.encoding_buffer.clear();
message
- .encode(&mut buffer)
+ .encode(&mut self.encoding_buffer)
.map_err(|err| io::Error::from(err))?;
+ let buffer = zstd::stream::encode_all(self.encoding_buffer.as_slice(), 4).unwrap();
self.stream.send(WebSocketMessage::Binary(buffer)).await?;
Ok(())
}
@@ -228,7 +234,10 @@ where
while let Some(bytes) = self.stream.next().await {
match bytes? {
WebSocketMessage::Binary(bytes) => {
- let envelope = Envelope::decode(bytes.as_slice()).map_err(io::Error::from)?;
+ self.encoding_buffer.clear();
+ zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
+ let envelope = Envelope::decode(self.encoding_buffer.as_slice())
+ .map_err(io::Error::from)?;
return Ok(envelope);
}
WebSocketMessage::Close(_) => break,