diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3bebd0f0beabb4d6a2670ff7700e863e88e00771..cfbdc2ca02f89119c8953c0f9733daa2b60402ee 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -33,7 +33,7 @@ jobs: clean: false - name: Run tests - run: cargo test --no-fail-fast + run: cargo test --workspace --no-fail-fast bundle: name: Bundle app diff --git a/Cargo.lock b/Cargo.lock index 1d420bb992eb811f999740b90d6fc9005e030c43..7973316c9e337532685bfc546f22dae9d8a7bee1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -836,7 +836,7 @@ dependencies = [ "target_build_utils", "term", "toml 0.4.10", - "uuid", + "uuid 0.5.1", "walkdir", ] @@ -884,7 +884,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e7fb075b9b54e939006aa12e1f6cd2d3194041ff4ebe7f2efcbedf18f25b667" dependencies = [ "byteorder", - "uuid", + "uuid 0.5.1", ] [[package]] @@ -2963,7 +2963,7 @@ dependencies = [ "byteorder", "cfb", "encoding", - "uuid", + "uuid 0.5.1", ] [[package]] @@ -4784,6 +4784,7 @@ dependencies = [ "thiserror", "time 0.2.25", "url", + "uuid 0.8.2", "webpki", "webpki-roots", "whoami", @@ -5606,6 +5607,12 @@ dependencies = [ "sha1 0.2.0", ] +[[package]] +name = "uuid" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7" + [[package]] name = "value-bag" version = "1.0.0-alpha.7" @@ -5917,6 +5924,7 @@ dependencies = [ "http-auth-basic", "ignore", "image 0.23.14", + "indexmap", "lazy_static", "libc", "log", diff --git a/gpui/src/elements/list.rs b/gpui/src/elements/list.rs index 1a86e2935cd774837d2dbf03f12acd089c4e487b..3864bf3c80daf4f9e2a6c119351838c2aabb2bb3 100644 --- a/gpui/src/elements/list.rs +++ b/gpui/src/elements/list.rs @@ -603,7 +603,7 @@ mod tests { offset_in_item: 0., }, 40., - vec2f(0., 54.), + vec2f(0., -54.), true, &mut presenter.build_event_context(cx), ); @@ -654,7 +654,7 @@ mod tests { assert_eq!(state.0.borrow().scroll_top(&logical_scroll_top), 114.); } - #[crate::test(self, iterations = 10000, seed = 0)] + #[crate::test(self, iterations = 10, seed = 0)] fn test_random(cx: &mut crate::MutableAppContext, mut rng: StdRng) { let operations = env::var("OPERATIONS") .map(|i| i.parse().expect("invalid `OPERATIONS` variable")) diff --git a/server/Cargo.toml b/server/Cargo.toml index b73c70102a311ecf1813a5bd85efa315e344dd93..b295ff21acf4d7d578b1f5a2ba4ccfed234684eb 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -5,6 +5,9 @@ edition = "2018" name = "zed-server" version = "0.1.0" +[[bin]] +name = "zed-server" + [[bin]] name = "seed" required-features = ["seed-support"] @@ -47,7 +50,7 @@ default-features = false [dependencies.sqlx] version = "0.5.2" -features = ["runtime-async-std-rustls", "postgres", "time"] +features = ["runtime-async-std-rustls", "postgres", "time", "uuid"] [dev-dependencies] gpui = { path = "../gpui" } diff --git a/server/migrations/20210916123647_add_nonce_to_channel_messages.sql b/server/migrations/20210916123647_add_nonce_to_channel_messages.sql new file mode 100644 index 0000000000000000000000000000000000000000..ee4d4aa319f6417e854137332011115570153eae --- /dev/null +++ b/server/migrations/20210916123647_add_nonce_to_channel_messages.sql @@ -0,0 +1,4 @@ +ALTER TABLE "channel_messages" +ADD "nonce" UUID NOT NULL DEFAULT gen_random_uuid(); + +CREATE UNIQUE INDEX "index_channel_messages_nonce" ON "channel_messages" ("nonce"); diff --git a/server/src/bin/seed.rs b/server/src/bin/seed.rs index b259dc4c14b24ea8b1278be56a6610f2e5fa1f64..d2427d495c451497df0644dc0fc4d36e7ecaa4ea 100644 --- a/server/src/bin/seed.rs +++ b/server/src/bin/seed.rs @@ -73,7 +73,7 @@ async fn main() { for timestamp in timestamps { let sender_id = *zed_user_ids.choose(&mut rng).unwrap(); let body = lipsum::lipsum_words(rng.gen_range(1..=50)); - db.create_channel_message(channel_id, sender_id, &body, timestamp) + db.create_channel_message(channel_id, sender_id, &body, timestamp, rng.gen()) .await .expect("failed to insert message"); } diff --git a/server/src/db.rs b/server/src/db.rs index 8d2199a9f33a39c2b0f65d0e5eb60560c3ad2ab0..14ad85b68af2e06148c02d12dc74790fa2b5b0c9 100644 --- a/server/src/db.rs +++ b/server/src/db.rs @@ -1,7 +1,7 @@ use anyhow::Context; use async_std::task::{block_on, yield_now}; use serde::Serialize; -use sqlx::{FromRow, Result}; +use sqlx::{types::Uuid, FromRow, Result}; use time::OffsetDateTime; pub use async_sqlx_session::PostgresSessionStore as SessionStore; @@ -128,10 +128,23 @@ impl Db { requester_id: UserId, ids: impl Iterator, ) -> Result> { + let mut include_requester = false; + let ids = ids + .map(|id| { + if id == requester_id { + include_requester = true; + } + id.0 + }) + .collect::>(); + test_support!(self, { // Only return users that are in a common channel with the requesting user. + // Also allow the requesting user to return their own data, even if they aren't + // in any channels. let query = " - SELECT users.* + SELECT + users.* FROM users, channel_memberships WHERE @@ -142,11 +155,19 @@ impl Db { FROM channel_memberships WHERE channel_memberships.user_id = $2 ) + UNION + SELECT + users.* + FROM + users + WHERE + $3 AND users.id = $2 "; sqlx::query_as(query) - .bind(&ids.map(|id| id.0).collect::>()) + .bind(&ids) .bind(requester_id) + .bind(include_requester) .fetch_all(&self.pool) .await }) @@ -381,11 +402,13 @@ impl Db { sender_id: UserId, body: &str, timestamp: OffsetDateTime, + nonce: u128, ) -> Result { test_support!(self, { let query = " - INSERT INTO channel_messages (channel_id, sender_id, body, sent_at) - VALUES ($1, $2, $3, $4) + INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce RETURNING id "; sqlx::query_scalar(query) @@ -393,6 +416,7 @@ impl Db { .bind(sender_id.0) .bind(body) .bind(timestamp) + .bind(Uuid::from_u128(nonce)) .fetch_one(&self.pool) .await .map(MessageId) @@ -409,7 +433,7 @@ impl Db { let query = r#" SELECT * FROM ( SELECT - id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at + id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce FROM channel_messages WHERE @@ -455,7 +479,7 @@ macro_rules! id_type { } id_type!(UserId); -#[derive(Debug, FromRow, Serialize)] +#[derive(Debug, FromRow, Serialize, PartialEq)] pub struct User { pub id: UserId, pub github_login: String, @@ -493,6 +517,7 @@ pub struct ChannelMessage { pub sender_id: UserId, pub body: String, pub sent_at: OffsetDateTime, + pub nonce: Uuid, } #[cfg(test)] @@ -563,6 +588,91 @@ pub mod tests { } } + #[gpui::test] + async fn test_get_users_by_ids() { + let test_db = TestDb::new(); + let db = test_db.db(); + + let user = db.create_user("user", false).await.unwrap(); + let friend1 = db.create_user("friend-1", false).await.unwrap(); + let friend2 = db.create_user("friend-2", false).await.unwrap(); + let friend3 = db.create_user("friend-3", false).await.unwrap(); + let stranger = db.create_user("stranger", false).await.unwrap(); + + // A user can read their own info, even if they aren't in any channels. + assert_eq!( + db.get_users_by_ids( + user, + [user, friend1, friend2, friend3, stranger].iter().copied() + ) + .await + .unwrap(), + vec![User { + id: user, + github_login: "user".to_string(), + admin: false, + },], + ); + + // A user can read the info of any other user who is in a shared channel + // with them. + let org = db.create_org("test org", "test-org").await.unwrap(); + let chan1 = db.create_org_channel(org, "channel-1").await.unwrap(); + let chan2 = db.create_org_channel(org, "channel-2").await.unwrap(); + let chan3 = db.create_org_channel(org, "channel-3").await.unwrap(); + + db.add_channel_member(chan1, user, false).await.unwrap(); + db.add_channel_member(chan2, user, false).await.unwrap(); + db.add_channel_member(chan1, friend1, false).await.unwrap(); + db.add_channel_member(chan1, friend2, false).await.unwrap(); + db.add_channel_member(chan2, friend2, false).await.unwrap(); + db.add_channel_member(chan2, friend3, false).await.unwrap(); + db.add_channel_member(chan3, stranger, false).await.unwrap(); + + assert_eq!( + db.get_users_by_ids( + user, + [user, friend1, friend2, friend3, stranger].iter().copied() + ) + .await + .unwrap(), + vec![ + User { + id: user, + github_login: "user".to_string(), + admin: false, + }, + User { + id: friend1, + github_login: "friend-1".to_string(), + admin: false, + }, + User { + id: friend2, + github_login: "friend-2".to_string(), + admin: false, + }, + User { + id: friend3, + github_login: "friend-3".to_string(), + admin: false, + } + ] + ); + + // The user's own info is only returned if they request it. + assert_eq!( + db.get_users_by_ids(user, [friend1].iter().copied()) + .await + .unwrap(), + vec![User { + id: friend1, + github_login: "friend-1".to_string(), + admin: false, + },] + ) + } + #[gpui::test] async fn test_recent_channel_messages() { let test_db = TestDb::new(); @@ -571,7 +681,7 @@ pub mod tests { let org = db.create_org("org", "org").await.unwrap(); let channel = db.create_org_channel(org, "channel").await.unwrap(); for i in 0..10 { - db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc()) + db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i) .await .unwrap(); } @@ -591,4 +701,34 @@ pub mod tests { ["1", "2", "3", "4"] ); } + + #[gpui::test] + async fn test_channel_message_nonces() { + let test_db = TestDb::new(); + let db = test_db.db(); + let user = db.create_user("user", false).await.unwrap(); + let org = db.create_org("org", "org").await.unwrap(); + let channel = db.create_org_channel(org, "channel").await.unwrap(); + + let msg1_id = db + .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1) + .await + .unwrap(); + let msg2_id = db + .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2) + .await + .unwrap(); + let msg3_id = db + .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1) + .await + .unwrap(); + let msg4_id = db + .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2) + .await + .unwrap(); + + assert_ne!(msg1_id, msg2_id); + assert_eq!(msg1_id, msg3_id); + assert_eq!(msg2_id, msg4_id); + } } diff --git a/server/src/rpc.rs b/server/src/rpc.rs index 1e0fe2465cafbe2d7034e941500145cda25ebd1d..debd982366c7a4b9d1339963612d9e101dfcff0f 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -602,6 +602,7 @@ impl Server { body: msg.body, timestamp: msg.sent_at.unix_timestamp() as u64, sender_id: msg.sender_id.to_proto(), + nonce: Some(msg.nonce.as_u128().into()), }) .collect::>(); self.peer @@ -687,10 +688,24 @@ impl Server { } let timestamp = OffsetDateTime::now_utc(); + let nonce = if let Some(nonce) = request.payload.nonce { + nonce + } else { + self.peer + .respond_with_error( + receipt, + proto::Error { + message: "nonce can't be blank".to_string(), + }, + ) + .await?; + return Ok(()); + }; + let message_id = self .app_state .db - .create_channel_message(channel_id, user_id, &body, timestamp) + .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into()) .await? .to_proto(); let message = proto::ChannelMessage { @@ -698,6 +713,7 @@ impl Server { id: message_id, body, timestamp: timestamp.unix_timestamp() as u64, + nonce: Some(nonce), }; broadcast(request.sender_id, connection_ids, |conn_id| { self.peer.send( @@ -754,6 +770,7 @@ impl Server { body: msg.body, timestamp: msg.sent_at.unix_timestamp() as u64, sender_id: msg.sender_id.to_proto(), + nonce: Some(msg.nonce.as_u128().into()), }) .collect::>(); self.peer @@ -1039,8 +1056,8 @@ mod tests { // Connect to a server as 2 clients. let mut server = TestServer::start().await; - let (_, client_a) = server.create_client(&mut cx_a, "user_a").await; - let (_, client_b) = server.create_client(&mut cx_b, "user_b").await; + let (client_a, _) = server.create_client(&mut cx_a, "user_a").await; + let (client_b, _) = server.create_client(&mut cx_b, "user_b").await; cx_a.foreground().forbid_parking(); @@ -1124,7 +1141,7 @@ mod tests { .await; // Close the buffer as client A, see that the buffer is closed. - drop(buffer_a); + cx_a.update(move |_| drop(buffer_a)); worktree_a .condition(&cx_a, |tree, cx| !tree.has_open_buffer("b.txt", cx)) .await; @@ -1147,9 +1164,9 @@ mod tests { // Connect to a server as 3 clients. let mut server = TestServer::start().await; - let (_, client_a) = server.create_client(&mut cx_a, "user_a").await; - let (_, client_b) = server.create_client(&mut cx_b, "user_b").await; - let (_, client_c) = server.create_client(&mut cx_c, "user_c").await; + let (client_a, _) = server.create_client(&mut cx_a, "user_a").await; + let (client_b, _) = server.create_client(&mut cx_b, "user_b").await; + let (client_c, _) = server.create_client(&mut cx_c, "user_c").await; let fs = Arc::new(FakeFs::new()); @@ -1288,8 +1305,8 @@ mod tests { // Connect to a server as 2 clients. let mut server = TestServer::start().await; - let (_, client_a) = server.create_client(&mut cx_a, "user_a").await; - let (_, client_b) = server.create_client(&mut cx_b, "user_b").await; + let (client_a, _) = server.create_client(&mut cx_a, "user_a").await; + let (client_b, _) = server.create_client(&mut cx_b, "user_b").await; // Share a local worktree as client A let fs = Arc::new(FakeFs::new()); @@ -1369,8 +1386,8 @@ mod tests { // Connect to a server as 2 clients. let mut server = TestServer::start().await; - let (_, client_a) = server.create_client(&mut cx_a, "user_a").await; - let (_, client_b) = server.create_client(&mut cx_b, "user_b").await; + let (client_a, _) = server.create_client(&mut cx_a, "user_a").await; + let (client_b, _) = server.create_client(&mut cx_b, "user_b").await; // Share a local worktree as client A let fs = Arc::new(FakeFs::new()); @@ -1429,8 +1446,8 @@ mod tests { // Connect to a server as 2 clients. let mut server = TestServer::start().await; - let (_, client_a) = server.create_client(&mut cx_a, "user_a").await; - let (_, client_b) = server.create_client(&mut cx_a, "user_b").await; + let (client_a, _) = server.create_client(&mut cx_a, "user_a").await; + let (client_b, _) = server.create_client(&mut cx_a, "user_b").await; // Share a local worktree as client A let fs = Arc::new(FakeFs::new()); @@ -1484,38 +1501,40 @@ mod tests { #[gpui::test] async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) { cx_a.foreground().forbid_parking(); - let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) }); // Connect to a server as 2 clients. let mut server = TestServer::start().await; - let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await; - let (user_id_b, client_b) = server.create_client(&mut cx_b, "user_b").await; + let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await; + let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await; // Create an org that includes these 2 users. let db = &server.app_state.db; let org_id = db.create_org("Test Org", "test-org").await.unwrap(); - db.add_org_member(org_id, user_id_a, false).await.unwrap(); - db.add_org_member(org_id, user_id_b, false).await.unwrap(); + db.add_org_member(org_id, current_user_id(&user_store_a), false) + .await + .unwrap(); + db.add_org_member(org_id, current_user_id(&user_store_b), false) + .await + .unwrap(); // Create a channel that includes all the users. let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap(); - db.add_channel_member(channel_id, user_id_a, false) + db.add_channel_member(channel_id, current_user_id(&user_store_a), false) .await .unwrap(); - db.add_channel_member(channel_id, user_id_b, false) + db.add_channel_member(channel_id, current_user_id(&user_store_b), false) .await .unwrap(); db.create_channel_message( channel_id, - user_id_b, + current_user_id(&user_store_b), "hello A, it's B.", OffsetDateTime::now_utc(), + 1, ) .await .unwrap(); - let user_store_a = - UserStore::new(client_a.clone(), http.clone(), cx_a.background().as_ref()); let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx)); channels_a .condition(&mut cx_a, |list, _| list.available_channels().is_some()) @@ -1536,12 +1555,10 @@ mod tests { channel_a .condition(&cx_a, |channel, _| { channel_messages(channel) - == [("user_b".to_string(), "hello A, it's B.".to_string())] + == [("user_b".to_string(), "hello A, it's B.".to_string(), false)] }) .await; - let user_store_b = - UserStore::new(client_b.clone(), http.clone(), cx_b.background().as_ref()); let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx)); channels_b .condition(&mut cx_b, |list, _| list.available_channels().is_some()) @@ -1563,7 +1580,7 @@ mod tests { channel_b .condition(&cx_b, |channel, _| { channel_messages(channel) - == [("user_b".to_string(), "hello A, it's B.".to_string())] + == [("user_b".to_string(), "hello A, it's B.".to_string(), false)] }) .await; @@ -1575,28 +1592,25 @@ mod tests { .detach(); let task = channel.send_message("sup".to_string(), cx).unwrap(); assert_eq!( - channel - .pending_messages() - .iter() - .map(|m| &m.body) - .collect::>(), - &["oh, hi B.", "sup"] + channel_messages(channel), + &[ + ("user_b".to_string(), "hello A, it's B.".to_string(), false), + ("user_a".to_string(), "oh, hi B.".to_string(), true), + ("user_a".to_string(), "sup".to_string(), true) + ] ); task }) .await .unwrap(); - channel_a - .condition(&cx_a, |channel, _| channel.pending_messages().is_empty()) - .await; channel_b .condition(&cx_b, |channel, _| { channel_messages(channel) == [ - ("user_b".to_string(), "hello A, it's B.".to_string()), - ("user_a".to_string(), "oh, hi B.".to_string()), - ("user_a".to_string(), "sup".to_string()), + ("user_b".to_string(), "hello A, it's B.".to_string(), false), + ("user_a".to_string(), "oh, hi B.".to_string(), false), + ("user_a".to_string(), "sup".to_string(), false), ] }) .await; @@ -1616,33 +1630,25 @@ mod tests { server .condition(|state| !state.channels.contains_key(&channel_id)) .await; - - fn channel_messages(channel: &Channel) -> Vec<(String, String)> { - channel - .messages() - .cursor::<(), ()>() - .map(|m| (m.sender.github_login.clone(), m.body.clone())) - .collect() - } } #[gpui::test] async fn test_chat_message_validation(mut cx_a: TestAppContext) { cx_a.foreground().forbid_parking(); - let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) }); let mut server = TestServer::start().await; - let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await; + let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await; let db = &server.app_state.db; let org_id = db.create_org("Test Org", "test-org").await.unwrap(); let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap(); - db.add_org_member(org_id, user_id_a, false).await.unwrap(); - db.add_channel_member(channel_id, user_id_a, false) + db.add_org_member(org_id, current_user_id(&user_store_a), false) + .await + .unwrap(); + db.add_channel_member(channel_id, current_user_id(&user_store_a), false) .await .unwrap(); - let user_store_a = UserStore::new(client_a.clone(), http, cx_a.background().as_ref()); let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx)); channels_a .condition(&mut cx_a, |list, _| list.available_channels().is_some()) @@ -1692,29 +1698,34 @@ mod tests { // Connect to a server as 2 clients. let mut server = TestServer::start().await; - let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await; - let (user_id_b, client_b) = server.create_client(&mut cx_b, "user_b").await; + let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await; + let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await; let mut status_b = client_b.status(); // Create an org that includes these 2 users. let db = &server.app_state.db; let org_id = db.create_org("Test Org", "test-org").await.unwrap(); - db.add_org_member(org_id, user_id_a, false).await.unwrap(); - db.add_org_member(org_id, user_id_b, false).await.unwrap(); + db.add_org_member(org_id, current_user_id(&user_store_a), false) + .await + .unwrap(); + db.add_org_member(org_id, current_user_id(&user_store_b), false) + .await + .unwrap(); // Create a channel that includes all the users. let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap(); - db.add_channel_member(channel_id, user_id_a, false) + db.add_channel_member(channel_id, current_user_id(&user_store_a), false) .await .unwrap(); - db.add_channel_member(channel_id, user_id_b, false) + db.add_channel_member(channel_id, current_user_id(&user_store_b), false) .await .unwrap(); db.create_channel_message( channel_id, - user_id_b, + current_user_id(&user_store_b), "hello A, it's B.", OffsetDateTime::now_utc(), + 2, ) .await .unwrap(); @@ -1742,13 +1753,11 @@ mod tests { channel_a .condition(&cx_a, |channel, _| { channel_messages(channel) - == [("user_b".to_string(), "hello A, it's B.".to_string())] + == [("user_b".to_string(), "hello A, it's B.".to_string(), false)] }) .await; - let user_store_b = - UserStore::new(client_b.clone(), http.clone(), cx_b.background().as_ref()); - let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx)); + let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b.clone(), client_b, cx)); channels_b .condition(&mut cx_b, |list, _| list.available_channels().is_some()) .await; @@ -1769,13 +1778,13 @@ mod tests { channel_b .condition(&cx_b, |channel, _| { channel_messages(channel) - == [("user_b".to_string(), "hello A, it's B.".to_string())] + == [("user_b".to_string(), "hello A, it's B.".to_string(), false)] }) .await; // Disconnect client B, ensuring we can still access its cached channel data. server.forbid_connections(); - server.disconnect_client(user_id_b); + server.disconnect_client(current_user_id(&user_store_b)); while !matches!( status_b.recv().await, Some(rpc::Status::ReconnectionError { .. }) @@ -1793,10 +1802,28 @@ mod tests { channel_b.read_with(&cx_b, |channel, _| { assert_eq!( channel_messages(channel), - [("user_b".to_string(), "hello A, it's B.".to_string())] + [("user_b".to_string(), "hello A, it's B.".to_string(), false)] ) }); + // Send a message from client B while it is disconnected. + channel_b + .update(&mut cx_b, |channel, cx| { + let task = channel + .send_message("can you see this?".to_string(), cx) + .unwrap(); + assert_eq!( + channel_messages(channel), + &[ + ("user_b".to_string(), "hello A, it's B.".to_string(), false), + ("user_b".to_string(), "can you see this?".to_string(), true) + ] + ); + task + }) + .await + .unwrap_err(); + // Send a message from client A while B is disconnected. channel_a .update(&mut cx_a, |channel, cx| { @@ -1806,12 +1833,12 @@ mod tests { .detach(); let task = channel.send_message("sup".to_string(), cx).unwrap(); assert_eq!( - channel - .pending_messages() - .iter() - .map(|m| &m.body) - .collect::>(), - &["oh, hi B.", "sup"] + channel_messages(channel), + &[ + ("user_b".to_string(), "hello A, it's B.".to_string(), false), + ("user_a".to_string(), "oh, hi B.".to_string(), true), + ("user_a".to_string(), "sup".to_string(), true) + ] ); task }) @@ -1822,14 +1849,16 @@ mod tests { server.allow_connections(); cx_b.foreground().advance_clock(Duration::from_secs(10)); - // Verify that B sees the new messages upon reconnection. + // Verify that B sees the new messages upon reconnection, as well as the message client B + // sent while offline. channel_b .condition(&cx_b, |channel, _| { channel_messages(channel) == [ - ("user_b".to_string(), "hello A, it's B.".to_string()), - ("user_a".to_string(), "oh, hi B.".to_string()), - ("user_a".to_string(), "sup".to_string()), + ("user_b".to_string(), "hello A, it's B.".to_string(), false), + ("user_a".to_string(), "oh, hi B.".to_string(), false), + ("user_a".to_string(), "sup".to_string(), false), + ("user_b".to_string(), "can you see this?".to_string(), false), ] }) .await; @@ -1845,10 +1874,11 @@ mod tests { .condition(&cx_b, |channel, _| { channel_messages(channel) == [ - ("user_b".to_string(), "hello A, it's B.".to_string()), - ("user_a".to_string(), "oh, hi B.".to_string()), - ("user_a".to_string(), "sup".to_string()), - ("user_a".to_string(), "you online?".to_string()), + ("user_b".to_string(), "hello A, it's B.".to_string(), false), + ("user_a".to_string(), "oh, hi B.".to_string(), false), + ("user_a".to_string(), "sup".to_string(), false), + ("user_b".to_string(), "can you see this?".to_string(), false), + ("user_a".to_string(), "you online?".to_string(), false), ] }) .await; @@ -1863,22 +1893,15 @@ mod tests { .condition(&cx_a, |channel, _| { channel_messages(channel) == [ - ("user_b".to_string(), "hello A, it's B.".to_string()), - ("user_a".to_string(), "oh, hi B.".to_string()), - ("user_a".to_string(), "sup".to_string()), - ("user_a".to_string(), "you online?".to_string()), - ("user_b".to_string(), "yep".to_string()), + ("user_b".to_string(), "hello A, it's B.".to_string(), false), + ("user_a".to_string(), "oh, hi B.".to_string(), false), + ("user_a".to_string(), "sup".to_string(), false), + ("user_b".to_string(), "can you see this?".to_string(), false), + ("user_a".to_string(), "you online?".to_string(), false), + ("user_b".to_string(), "yep".to_string(), false), ] }) .await; - - fn channel_messages(channel: &Channel) -> Vec<(String, String)> { - channel - .messages() - .cursor::<(), ()>() - .map(|m| (m.sender.github_login.clone(), m.body.clone())) - .collect() - } } struct TestServer { @@ -1913,8 +1936,8 @@ mod tests { &mut self, cx: &mut TestAppContext, name: &str, - ) -> (UserId, Arc) { - let client_user_id = self.app_state.db.create_user(name, false).await.unwrap(); + ) -> (Arc, Arc) { + let user_id = self.app_state.db.create_user(name, false).await.unwrap(); let client_name = name.to_string(); let mut client = Client::new(); let server = self.server.clone(); @@ -1926,13 +1949,13 @@ mod tests { cx.spawn(|_| async move { let access_token = "the-token".to_string(); Ok(Credentials { - user_id: client_user_id.0 as u64, + user_id: user_id.0 as u64, access_token, }) }) }) .override_establish_connection(move |credentials, cx| { - assert_eq!(credentials.user_id, client_user_id.0 as u64); + assert_eq!(credentials.user_id, user_id.0 as u64); assert_eq!(credentials.access_token, "the-token"); let server = server.clone(); @@ -1946,24 +1969,26 @@ mod tests { ))) } else { let (client_conn, server_conn, kill_conn) = Connection::in_memory(); - connection_killers.lock().insert(client_user_id, kill_conn); + connection_killers.lock().insert(user_id, kill_conn); cx.background() - .spawn(server.handle_connection( - server_conn, - client_name, - client_user_id, - )) + .spawn(server.handle_connection(server_conn, client_name, user_id)) .detach(); Ok(client_conn) } }) }); + let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) }); client .authenticate_and_connect(&cx.to_async()) .await .unwrap(); - (client_user_id, client) + + let user_store = UserStore::new(client.clone(), http, &cx.background()); + let mut authed_user = user_store.watch_current_user(); + while authed_user.recv().await.unwrap().is_none() {} + + (client, user_store) } fn disconnect_client(&self, user_id: UserId) { @@ -2019,6 +2044,24 @@ mod tests { } } + fn current_user_id(user_store: &Arc) -> UserId { + UserId::from_proto(user_store.current_user().unwrap().id) + } + + fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> { + channel + .messages() + .cursor::<(), ()>() + .map(|m| { + ( + m.sender.github_login.clone(), + m.body.clone(), + m.is_pending(), + ) + }) + .collect() + } + struct EmptyView; impl gpui::Entity for EmptyView { diff --git a/zed/Cargo.toml b/zed/Cargo.toml index 63f64e4a5f26a991231611f678cb0433ca6a8e0b..8d27fcd4c540b58544dd6223960b74409afcae86 100644 --- a/zed/Cargo.toml +++ b/zed/Cargo.toml @@ -32,6 +32,7 @@ gpui = { path = "../gpui" } http-auth-basic = "0.1.3" ignore = "0.4" image = "0.23" +indexmap = "1.6.2" lazy_static = "1.4.0" libc = "0.2" log = "0.4" diff --git a/zed/assets/themes/_base.toml b/zed/assets/themes/_base.toml index 485d3bf2a79e7269fb24462c1160d370c7bf1b89..1a2999379c4e46f82514349cddeee38ecb697467 100644 --- a/zed/assets/themes/_base.toml +++ b/zed/assets/themes/_base.toml @@ -67,6 +67,12 @@ sender = { extends = "$text.0", weight = "bold", margin.right = 8 } timestamp = "$text.2" padding.bottom = 6 +[chat_panel.pending_message] +extends = "$chat_panel.message" +body = { color = "$text.3.color" } +sender = { color = "$text.3.color" } +timestamp = { color = "$text.3.color" } + [chat_panel.channel_select.item] padding = 4 name = "$text.1" diff --git a/zed/assets/themes/black.toml b/zed/assets/themes/black.toml index 53d9957f4b20b124f01460f5ccdc2a687444db87..3a7319e2a9e6af9b1f101981bfc8f107267b0c50 100644 --- a/zed/assets/themes/black.toml +++ b/zed/assets/themes/black.toml @@ -9,7 +9,6 @@ extends = "_base" 0 = "#0F1011" [text] -base = { family = "Inconsolata", size = 15 } 0 = { extends = "$text.base", color = "#ffffff" } 1 = { extends = "$text.base", color = "#b3b3b3" } 2 = { extends = "$text.base", color = "#7b7d80" } @@ -49,4 +48,4 @@ number = "#b5cea8" comment = "#6a9955" property = "#4e94ce" variant = "#4fc1ff" -constant = "#9cdcfe" \ No newline at end of file +constant = "#9cdcfe" diff --git a/zed/assets/themes/dark.toml b/zed/assets/themes/dark.toml index cf17c62fdbed355397b727fcad1c5de9e02ec9d3..f9c5a97f2acd9a3b40bf92e254ce9b16ff9b9688 100644 --- a/zed/assets/themes/dark.toml +++ b/zed/assets/themes/dark.toml @@ -9,7 +9,6 @@ extends = "_base" 0 = "#1B222B" [text] -base = { family = "Inconsolata", size = 15 } 0 = { extends = "$text.base", color = "#FFFFFF" } 1 = { extends = "$text.base", color = "#CDD1E2" } 2 = { extends = "$text.base", color = "#9BA8BE" } diff --git a/zed/assets/themes/light.toml b/zed/assets/themes/light.toml index 80f84f998c1981d453b3d793298b3d5afdba0397..fe3262b12ca295168d14fe7e37cea069932562f0 100644 --- a/zed/assets/themes/light.toml +++ b/zed/assets/themes/light.toml @@ -9,7 +9,6 @@ extends = "_base" 0 = "#DDDDDC" [text] -base = { family = "Inconsolata", size = 15 } 0 = { extends = "$text.base", color = "#000000" } 1 = { extends = "$text.base", color = "#29292B" } 2 = { extends = "$text.base", color = "#7E7E83" } diff --git a/zed/src/channel.rs b/zed/src/channel.rs index 0c3cb8bd2496dbdff7c618871237f72ff0027dd7..c43cf2e6f7b28a45e5a69dfa67c0383e065f6143 100644 --- a/zed/src/channel.rs +++ b/zed/src/channel.rs @@ -1,7 +1,7 @@ use crate::{ rpc::{self, Client}, user::{User, UserStore}, - util::TryFutureExt, + util::{post_inc, TryFutureExt}, }; use anyhow::{anyhow, Context, Result}; use gpui::{ @@ -9,6 +9,7 @@ use gpui::{ Entity, ModelContext, ModelHandle, MutableAppContext, Task, WeakModelHandle, }; use postage::prelude::Stream; +use rand::prelude::*; use std::{ collections::{HashMap, HashSet}, mem, @@ -39,29 +40,31 @@ pub struct Channel { details: ChannelDetails, messages: SumTree, loaded_all_messages: bool, - pending_messages: Vec, - next_local_message_id: u64, + next_pending_message_id: usize, user_store: Arc, rpc: Arc, + rng: StdRng, _subscription: rpc::Subscription, } #[derive(Clone, Debug)] pub struct ChannelMessage { - pub id: u64, + pub id: ChannelMessageId, pub body: String, pub timestamp: OffsetDateTime, pub sender: Arc, + pub nonce: u128, } -pub struct PendingChannelMessage { - pub body: String, - local_id: u64, +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum ChannelMessageId { + Saved(u64), + Pending(usize), } #[derive(Clone, Debug, Default)] pub struct ChannelMessageSummary { - max_id: u64, + max_id: ChannelMessageId, count: usize, } @@ -216,9 +219,9 @@ impl Channel { user_store, rpc, messages: Default::default(), - pending_messages: Default::default(), loaded_all_messages: false, - next_local_message_id: 0, + next_pending_message_id: 0, + rng: StdRng::from_entropy(), _subscription, } } @@ -236,17 +239,35 @@ impl Channel { Err(anyhow!("message body can't be empty"))?; } + let current_user = self + .user_store + .current_user() + .ok_or_else(|| anyhow!("current_user is not present"))?; + let channel_id = self.details.id; - let local_id = self.next_local_message_id; - self.next_local_message_id += 1; - self.pending_messages.push(PendingChannelMessage { - local_id, - body: body.clone(), - }); + let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id)); + let nonce = self.rng.gen(); + self.insert_messages( + SumTree::from_item( + ChannelMessage { + id: pending_id, + body: body.clone(), + sender: current_user, + timestamp: OffsetDateTime::now_utc(), + nonce, + }, + &(), + ), + cx, + ); let user_store = self.user_store.clone(); let rpc = self.rpc.clone(); Ok(cx.spawn(|this, mut cx| async move { - let request = rpc.request(proto::SendChannelMessage { channel_id, body }); + let request = rpc.request(proto::SendChannelMessage { + channel_id, + body, + nonce: Some(nonce.into()), + }); let response = request.await?; let message = ChannelMessage::from_proto( response.message.ok_or_else(|| anyhow!("invalid message"))?, @@ -254,13 +275,7 @@ impl Channel { ) .await?; this.update(&mut cx, |this, cx| { - if let Ok(i) = this - .pending_messages - .binary_search_by_key(&local_id, |msg| msg.local_id) - { - this.pending_messages.remove(i); - this.insert_messages(SumTree::from_item(message, &()), cx); - } + this.insert_messages(SumTree::from_item(message, &()), cx); Ok(()) }) })) @@ -271,7 +286,12 @@ impl Channel { let rpc = self.rpc.clone(); let user_store = self.user_store.clone(); let channel_id = self.details.id; - if let Some(before_message_id) = self.messages.first().map(|message| message.id) { + if let Some(before_message_id) = + self.messages.first().and_then(|message| match message.id { + ChannelMessageId::Saved(id) => Some(id), + ChannelMessageId::Pending(_) => None, + }) + { cx.spawn(|this, mut cx| { async move { let response = rpc @@ -301,32 +321,51 @@ impl Channel { let user_store = self.user_store.clone(); let rpc = self.rpc.clone(); let channel_id = self.details.id; - cx.spawn(|channel, mut cx| { + cx.spawn(|this, mut cx| { async move { let response = rpc.request(proto::JoinChannel { channel_id }).await?; let messages = messages_from_proto(response.messages, &user_store).await?; let loaded_all_messages = response.done; - channel.update(&mut cx, |channel, cx| { + let pending_messages = this.update(&mut cx, |this, cx| { if let Some((first_new_message, last_old_message)) = - messages.first().zip(channel.messages.last()) + messages.first().zip(this.messages.last()) { if first_new_message.id > last_old_message.id { - let old_messages = mem::take(&mut channel.messages); + let old_messages = mem::take(&mut this.messages); cx.emit(ChannelEvent::MessagesUpdated { old_range: 0..old_messages.summary().count, new_count: 0, }); - channel.loaded_all_messages = loaded_all_messages; + this.loaded_all_messages = loaded_all_messages; } } - channel.insert_messages(messages, cx); + this.insert_messages(messages, cx); if loaded_all_messages { - channel.loaded_all_messages = loaded_all_messages; + this.loaded_all_messages = loaded_all_messages; } + + this.pending_messages().cloned().collect::>() }); + for pending_message in pending_messages { + let request = rpc.request(proto::SendChannelMessage { + channel_id, + body: pending_message.body, + nonce: Some(pending_message.nonce.into()), + }); + let response = request.await?; + let message = ChannelMessage::from_proto( + response.message.ok_or_else(|| anyhow!("invalid message"))?, + &user_store, + ) + .await?; + this.update(&mut cx, |this, cx| { + this.insert_messages(SumTree::from_item(message, &()), cx); + }); + } + Ok(()) } .log_err() @@ -354,8 +393,10 @@ impl Channel { cursor.take(range.len()) } - pub fn pending_messages(&self) -> &[PendingChannelMessage] { - &self.pending_messages + pub fn pending_messages(&self) -> impl Iterator { + let mut cursor = self.messages.cursor::(); + cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &()); + cursor } fn handle_message_sent( @@ -386,7 +427,12 @@ impl Channel { fn insert_messages(&mut self, messages: SumTree, cx: &mut ModelContext) { if let Some((first_message, last_message)) = messages.first().zip(messages.last()) { - let mut old_cursor = self.messages.cursor::(); + let nonces = messages + .cursor::<(), ()>() + .map(|m| m.nonce) + .collect::>(); + + let mut old_cursor = self.messages.cursor::(); let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &()); let start_ix = old_cursor.sum_start().0; let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &()); @@ -395,10 +441,40 @@ impl Channel { let end_ix = start_ix + removed_count; new_messages.push_tree(messages, &()); - new_messages.push_tree(old_cursor.suffix(&()), &()); + + let mut ranges = Vec::>::new(); + if new_messages.last().unwrap().is_pending() { + new_messages.push_tree(old_cursor.suffix(&()), &()); + } else { + new_messages.push_tree( + old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()), + &(), + ); + + while let Some(message) = old_cursor.item() { + let message_ix = old_cursor.sum_start().0; + if nonces.contains(&message.nonce) { + if ranges.last().map_or(false, |r| r.end == message_ix) { + ranges.last_mut().unwrap().end += 1; + } else { + ranges.push(message_ix..message_ix + 1); + } + } else { + new_messages.push(message.clone(), &()); + } + old_cursor.next(&()); + } + } + drop(old_cursor); self.messages = new_messages; + for range in ranges.into_iter().rev() { + cx.emit(ChannelEvent::MessagesUpdated { + old_range: range, + new_count: 0, + }); + } cx.emit(ChannelEvent::MessagesUpdated { old_range: start_ix..end_ix, new_count, @@ -445,12 +521,20 @@ impl ChannelMessage { ) -> Result { let sender = user_store.fetch_user(message.sender_id).await?; Ok(ChannelMessage { - id: message.id, + id: ChannelMessageId::Saved(message.id), body: message.body, timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?, sender, + nonce: message + .nonce + .ok_or_else(|| anyhow!("nonce is required"))? + .into(), }) } + + pub fn is_pending(&self) -> bool { + matches!(self.id, ChannelMessageId::Pending(_)) + } } impl sum_tree::Item for ChannelMessage { @@ -464,6 +548,12 @@ impl sum_tree::Item for ChannelMessage { } } +impl Default for ChannelMessageId { + fn default() -> Self { + Self::Saved(0) + } +} + impl sum_tree::Summary for ChannelMessageSummary { type Context = (); @@ -473,7 +563,7 @@ impl sum_tree::Summary for ChannelMessageSummary { } } -impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for u64 { +impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for ChannelMessageId { fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) { debug_assert!(summary.max_id > *self); *self = summary.max_id; @@ -568,12 +658,14 @@ mod tests { body: "a".into(), timestamp: 1000, sender_id: 5, + nonce: Some(1.into()), }, proto::ChannelMessage { id: 11, body: "b".into(), timestamp: 1001, sender_id: 6, + nonce: Some(2.into()), }, ], done: false, @@ -627,6 +719,7 @@ mod tests { body: "c".into(), timestamp: 1002, sender_id: 7, + nonce: Some(3.into()), }), }) .await; @@ -682,12 +775,14 @@ mod tests { body: "y".into(), timestamp: 998, sender_id: 5, + nonce: Some(4.into()), }, proto::ChannelMessage { id: 9, body: "z".into(), timestamp: 999, sender_id: 6, + nonce: Some(5.into()), }, ], }, diff --git a/zed/src/chat_panel.rs b/zed/src/chat_panel.rs index 5ccc014ca787119f6866254b20adc1ae79cab98f..d7752b9a53e944b6e5c1776ab280fd3ea95f025b 100644 --- a/zed/src/chat_panel.rs +++ b/zed/src/chat_panel.rs @@ -230,7 +230,12 @@ impl ChatPanel { fn render_message(&self, message: &ChannelMessage) -> ElementBox { let now = OffsetDateTime::now_utc(); let settings = self.settings.borrow(); - let theme = &settings.theme.chat_panel.message; + let theme = if message.is_pending() { + &settings.theme.chat_panel.pending_message + } else { + &settings.theme.chat_panel.message + }; + Container::new( Flex::column() .with_child( @@ -381,9 +386,10 @@ impl View for ChatPanel { fn render(&mut self, cx: &mut RenderContext) -> ElementBox { let theme = &self.settings.borrow().theme; - let element = match *self.rpc.status().borrow() { - rpc::Status::Connected { .. } => self.render_channel(), - _ => self.render_sign_in_prompt(cx), + let element = if self.rpc.user_id().is_some() { + self.render_channel() + } else { + self.render_sign_in_prompt(cx) }; ConstrainedBox::new( Container::new(element) diff --git a/zed/src/theme.rs b/zed/src/theme.rs index 88f385a05461b18a0db3f6182e9b3d08effb97b2..a96945fecc1011d9c1de9ef941560f863350491d 100644 --- a/zed/src/theme.rs +++ b/zed/src/theme.rs @@ -1,4 +1,5 @@ mod highlight_map; +mod resolution; mod theme_registry; use anyhow::Result; @@ -95,6 +96,7 @@ pub struct ChatPanel { #[serde(flatten)] pub container: ContainerStyle, pub message: ChatMessage, + pub pending_message: ChatMessage, pub channel_select: ChannelSelect, pub input_editor: InputEditorStyle, pub sign_in_prompt: TextStyle, diff --git a/zed/src/theme/resolution.rs b/zed/src/theme/resolution.rs new file mode 100644 index 0000000000000000000000000000000000000000..fd3864e274af20f75fbfe4da54f43fcdcdecc6c3 --- /dev/null +++ b/zed/src/theme/resolution.rs @@ -0,0 +1,476 @@ +use anyhow::{anyhow, Result}; +use indexmap::IndexMap; +use serde_json::Value; +use std::{ + cell::RefCell, + mem, + rc::{Rc, Weak}, +}; + +pub fn resolve_references(value: Value) -> Result { + let tree = Tree::from_json(value)?; + tree.resolve()?; + tree.to_json() +} + +#[derive(Clone)] +enum Node { + Reference { + path: String, + parent: Option>>, + }, + Object { + base: Option, + children: IndexMap, + resolved: bool, + parent: Option>>, + }, + Array { + children: Vec, + resolved: bool, + parent: Option>>, + }, + String { + value: String, + parent: Option>>, + }, + Number { + value: serde_json::Number, + parent: Option>>, + }, + Bool { + value: bool, + parent: Option>>, + }, + Null { + parent: Option>>, + }, +} + +#[derive(Clone)] +struct Tree(Rc>); + +impl Tree { + pub fn new(node: Node) -> Self { + Self(Rc::new(RefCell::new(node))) + } + + fn from_json(value: Value) -> Result { + match value { + Value::String(value) => { + if let Some(path) = value.strip_prefix("$") { + Ok(Self::new(Node::Reference { + path: path.to_string(), + parent: None, + })) + } else { + Ok(Self::new(Node::String { + value, + parent: None, + })) + } + } + Value::Number(value) => Ok(Self::new(Node::Number { + value, + parent: None, + })), + Value::Bool(value) => Ok(Self::new(Node::Bool { + value, + parent: None, + })), + Value::Null => Ok(Self::new(Node::Null { parent: None })), + Value::Object(object) => { + let tree = Self::new(Node::Object { + base: Default::default(), + children: Default::default(), + resolved: false, + parent: None, + }); + let mut children = IndexMap::new(); + let mut resolved = true; + let mut base = None; + for (key, value) in object.into_iter() { + let value = if key == "extends" { + if value.is_string() { + if let Value::String(value) = value { + base = value.strip_prefix("$").map(str::to_string); + resolved = false; + Self::new(Node::String { + value, + parent: None, + }) + } else { + unreachable!() + } + } else { + Tree::from_json(value)? + } + } else { + Tree::from_json(value)? + }; + value + .0 + .borrow_mut() + .set_parent(Some(Rc::downgrade(&tree.0))); + resolved &= value.is_resolved(); + children.insert(key.clone(), value); + } + + *tree.0.borrow_mut() = Node::Object { + base, + children, + resolved, + parent: None, + }; + Ok(tree) + } + Value::Array(elements) => { + let tree = Self::new(Node::Array { + children: Default::default(), + resolved: false, + parent: None, + }); + + let mut children = Vec::new(); + let mut resolved = true; + for element in elements { + let child = Tree::from_json(element)?; + child + .0 + .borrow_mut() + .set_parent(Some(Rc::downgrade(&tree.0))); + resolved &= child.is_resolved(); + children.push(child); + } + + *tree.0.borrow_mut() = Node::Array { + children, + resolved, + parent: None, + }; + Ok(tree) + } + } + } + + fn to_json(&self) -> Result { + match &*self.0.borrow() { + Node::Reference { .. } => Err(anyhow!("unresolved tree")), + Node::String { value, .. } => Ok(Value::String(value.clone())), + Node::Number { value, .. } => Ok(Value::Number(value.clone())), + Node::Bool { value, .. } => Ok(Value::Bool(*value)), + Node::Null { .. } => Ok(Value::Null), + Node::Object { children, .. } => { + let mut json_children = serde_json::Map::new(); + for (key, value) in children { + json_children.insert(key.clone(), value.to_json()?); + } + Ok(Value::Object(json_children)) + } + Node::Array { children, .. } => { + let mut json_children = Vec::new(); + for child in children { + json_children.push(child.to_json()?); + } + Ok(Value::Array(json_children)) + } + } + } + + fn parent(&self) -> Option { + match &*self.0.borrow() { + Node::Reference { parent, .. } + | Node::Object { parent, .. } + | Node::Array { parent, .. } + | Node::String { parent, .. } + | Node::Number { parent, .. } + | Node::Bool { parent, .. } + | Node::Null { parent } => parent.as_ref().and_then(|p| p.upgrade()).map(Tree), + } + } + + fn get(&self, path: &str) -> Result> { + let mut tree = self.clone(); + for component in path.split('.') { + let node = tree.0.borrow(); + match &*node { + Node::Object { children, .. } => { + if let Some(subtree) = children.get(component).cloned() { + drop(node); + tree = subtree; + } else { + return Err(anyhow!( + "key \"{}\" does not exist in path \"{}\"", + component, + path + )); + } + } + Node::Reference { .. } => return Ok(None), + Node::Array { .. } + | Node::String { .. } + | Node::Number { .. } + | Node::Bool { .. } + | Node::Null { .. } => { + return Err(anyhow!( + "key \"{}\" in path \"{}\" is not an object", + component, + path + )) + } + } + } + + Ok(Some(tree)) + } + + fn is_resolved(&self) -> bool { + match &*self.0.borrow() { + Node::Reference { .. } => false, + Node::Object { resolved, .. } | Node::Array { resolved, .. } => *resolved, + Node::String { .. } | Node::Number { .. } | Node::Bool { .. } | Node::Null { .. } => { + true + } + } + } + + fn update_resolved(&self) { + match &mut *self.0.borrow_mut() { + Node::Object { + resolved, children, .. + } => { + *resolved = children.values().all(|c| c.is_resolved()); + } + Node::Array { + resolved, children, .. + } => { + *resolved = children.iter().all(|c| c.is_resolved()); + } + _ => {} + } + } + + pub fn resolve(&self) -> Result<()> { + let mut unresolved = vec![self.clone()]; + let mut made_progress = true; + + while made_progress && !unresolved.is_empty() { + made_progress = false; + for mut tree in mem::take(&mut unresolved) { + made_progress |= tree.resolve_subtree(self, &mut unresolved)?; + if tree.is_resolved() { + while let Some(parent) = tree.parent() { + parent.update_resolved(); + tree = parent; + } + } + } + } + + if unresolved.is_empty() { + Ok(()) + } else { + Err(anyhow!("tree contains cycles")) + } + } + + fn resolve_subtree(&self, root: &Tree, unresolved: &mut Vec) -> Result { + let node = self.0.borrow(); + match &*node { + Node::Reference { path, parent } => { + if let Some(subtree) = root.get(&path)? { + if subtree.is_resolved() { + let parent = parent.clone(); + drop(node); + let mut new_node = subtree.0.borrow().clone(); + new_node.set_parent(parent); + *self.0.borrow_mut() = new_node; + Ok(true) + } else { + unresolved.push(self.clone()); + Ok(false) + } + } else { + unresolved.push(self.clone()); + Ok(false) + } + } + Node::Object { + base, + children, + resolved, + .. + } => { + if *resolved { + Ok(false) + } else { + let mut made_progress = false; + let mut children_resolved = true; + for child in children.values() { + made_progress |= child.resolve_subtree(root, unresolved)?; + children_resolved &= child.is_resolved(); + } + + if children_resolved { + let mut has_base = false; + let mut resolved_base = None; + if let Some(base) = base { + has_base = true; + if let Some(base) = root.get(base)? { + if base.is_resolved() { + resolved_base = Some(base); + } + } + } + + drop(node); + + if let Some(base) = resolved_base.as_ref() { + self.extend_from(&base); + made_progress = true; + } + + if let Node::Object { resolved, .. } = &mut *self.0.borrow_mut() { + if has_base { + if resolved_base.is_some() { + *resolved = true; + } else { + unresolved.push(self.clone()); + } + } else { + *resolved = true; + } + } + } + + Ok(made_progress) + } + } + Node::Array { + children, resolved, .. + } => { + if *resolved { + Ok(false) + } else { + let mut made_progress = false; + let mut children_resolved = true; + for child in children.iter() { + made_progress |= child.resolve_subtree(root, unresolved)?; + children_resolved &= child.is_resolved(); + } + + if children_resolved { + drop(node); + + if let Node::Array { resolved, .. } = &mut *self.0.borrow_mut() { + *resolved = true; + } + } + + Ok(made_progress) + } + } + Node::String { .. } | Node::Number { .. } | Node::Bool { .. } | Node::Null { .. } => { + Ok(false) + } + } + } + + fn extend_from(&self, base: &Tree) { + if Rc::ptr_eq(&self.0, &base.0) { + return; + } + + if let ( + Node::Object { children, .. }, + Node::Object { + children: base_children, + .. + }, + ) = (&mut *self.0.borrow_mut(), &*base.0.borrow()) + { + for (key, base_value) in base_children { + if let Some(value) = children.get(key) { + value.extend_from(base_value); + } else { + let base_value = base_value.clone(); + base_value + .0 + .borrow_mut() + .set_parent(Some(Rc::downgrade(&self.0))); + children.insert(key.clone(), base_value); + } + } + } + } +} + +impl Node { + fn set_parent(&mut self, new_parent: Option>>) { + match self { + Node::Reference { parent, .. } + | Node::Object { parent, .. } + | Node::Array { parent, .. } + | Node::String { parent, .. } + | Node::Number { parent, .. } + | Node::Bool { parent, .. } + | Node::Null { parent } => *parent = new_parent, + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_references() { + let json = serde_json::json!({ + "a": { + "x": "$b.d" + }, + "b": { + "c": "$a", + "d": "$e.f" + }, + "e": { + "extends": "$a", + "f": "1" + } + }); + + assert_eq!( + resolve_references(json).unwrap(), + serde_json::json!({ + "a": { + "x": "1" + }, + "b": { + "c": { + "x": "1" + }, + "d": "1" + }, + "e": { + "extends": "$a", + "f": "1", + "x": "1" + }, + }) + ) + } + + #[test] + fn test_cycles() { + let json = serde_json::json!({ + "a": { + "b": "$c.d" + }, + "c": { + "d": "$a.b", + }, + }); + + assert!(resolve_references(json).is_err()); + } +} diff --git a/zed/src/theme/theme_registry.rs b/zed/src/theme/theme_registry.rs index cd9781afe942f4e5bee4d1c16b6aa5886ae376a0..c5cf8f2fcbd856a7e0a5419f5337e8c198aaca59 100644 --- a/zed/src/theme/theme_registry.rs +++ b/zed/src/theme/theme_registry.rs @@ -1,8 +1,9 @@ -use anyhow::{anyhow, Context, Result}; +use super::resolution::resolve_references; +use anyhow::{Context, Result}; use gpui::{fonts, AssetSource, FontCache}; use parking_lot::Mutex; use serde_json::{Map, Value}; -use std::{collections::HashMap, fmt, mem, sync::Arc}; +use std::{collections::HashMap, sync::Arc}; use super::Theme; @@ -13,30 +14,6 @@ pub struct ThemeRegistry { font_cache: Arc, } -#[derive(Default)] -struct KeyPathReferenceSet { - references: Vec, - reference_ids_by_source: Vec, - reference_ids_by_target: Vec, - dependencies: Vec<(usize, usize)>, - dependency_counts: Vec, -} - -#[derive(Clone, Default, PartialEq, Eq, PartialOrd, Ord)] -struct KeyPathReference { - target: KeyPath, - source: KeyPath, -} - -#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord)] -struct KeyPath(Vec); - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] -enum Key { - Array(usize), - Object(String), -} - impl ThemeRegistry { pub fn new(source: impl AssetSource, font_cache: Arc) -> Arc { Arc::new(Self { @@ -111,41 +88,15 @@ impl ThemeRegistry { } } + let mut theme_data = Value::Object(theme_data); + // Find all of the key path references in the object, and then sort them according // to their dependencies. if evaluate_references { - let mut key_path = KeyPath::default(); - let mut references = KeyPathReferenceSet::default(); - for (key, value) in theme_data.iter() { - key_path.0.push(Key::Object(key.clone())); - find_references(value, &mut key_path, &mut references); - key_path.0.pop(); - } - let sorted_references = references - .top_sort() - .map_err(|key_paths| anyhow!("cycle for key paths: {:?}", key_paths))?; - - // Now update objects to include the fields of objects they extend - for KeyPathReference { source, target } in sorted_references { - if let Some(source) = value_at(&mut theme_data, &source).cloned() { - let target = value_at(&mut theme_data, &target).unwrap(); - if let Value::Object(target_object) = target.take() { - if let Value::Object(mut source_object) = source { - deep_merge_json(&mut source_object, target_object); - *target = Value::Object(source_object); - } else { - Err(anyhow!("extended key path {} is not an object", source))?; - } - } else { - *target = source; - } - } else { - Err(anyhow!("invalid key path '{}'", source))?; - } - } + theme_data = resolve_references(theme_data)?; } - let result = Arc::new(Value::Object(theme_data)); + let result = Arc::new(theme_data); self.theme_data .lock() .insert(name.to_string(), result.clone()); @@ -154,293 +105,6 @@ impl ThemeRegistry { } } -impl KeyPathReferenceSet { - fn insert(&mut self, reference: KeyPathReference) { - let id = self.references.len(); - let source_ix = self - .reference_ids_by_source - .binary_search_by_key(&&reference.source, |id| &self.references[*id].source) - .unwrap_or_else(|i| i); - let target_ix = self - .reference_ids_by_target - .binary_search_by_key(&&reference.target, |id| &self.references[*id].target) - .unwrap_or_else(|i| i); - - self.populate_dependencies(id, &reference); - self.reference_ids_by_source.insert(source_ix, id); - self.reference_ids_by_target.insert(target_ix, id); - self.references.push(reference); - } - - fn top_sort(mut self) -> Result, Vec> { - let mut results = Vec::with_capacity(self.references.len()); - let mut root_ids = Vec::with_capacity(self.references.len()); - - // Find the initial set of references that have no dependencies. - for (id, dep_count) in self.dependency_counts.iter().enumerate() { - if *dep_count == 0 { - root_ids.push(id); - } - } - - while results.len() < root_ids.len() { - // Just to guarantee a stable result when the inputs are randomized, - // sort references lexicographically in absence of any dependency relationship. - root_ids[results.len()..].sort_by_key(|id| &self.references[*id]); - - let root_id = root_ids[results.len()]; - let root = mem::take(&mut self.references[root_id]); - results.push(root); - - // Remove this reference as a dependency from any of its dependent references. - if let Ok(dep_ix) = self - .dependencies - .binary_search_by_key(&root_id, |edge| edge.0) - { - let mut first_dep_ix = dep_ix; - let mut last_dep_ix = dep_ix + 1; - while first_dep_ix > 0 && self.dependencies[first_dep_ix - 1].0 == root_id { - first_dep_ix -= 1; - } - while last_dep_ix < self.dependencies.len() - && self.dependencies[last_dep_ix].0 == root_id - { - last_dep_ix += 1; - } - - // If any reference no longer has any dependencies, then then mark it as a root. - // Preserve the references' original order where possible. - for (_, successor_id) in self.dependencies.drain(first_dep_ix..last_dep_ix) { - self.dependency_counts[successor_id] -= 1; - if self.dependency_counts[successor_id] == 0 { - root_ids.push(successor_id); - } - } - } - } - - // If any references never became roots, then there are reference cycles - // in the set. Return an error containing all of the key paths that are - // directly involved in cycles. - if results.len() < self.references.len() { - let mut cycle_ref_ids = (0..self.references.len()) - .filter(|id| !root_ids.contains(id)) - .collect::>(); - - // Iteratively remove any references that have no dependencies, - // so that the error will only indicate which key paths are directly - // involved in the cycles. - let mut done = false; - while !done { - done = true; - cycle_ref_ids.retain(|id| { - if self.dependencies.iter().any(|dep| dep.0 == *id) { - true - } else { - done = false; - self.dependencies.retain(|dep| dep.1 != *id); - false - } - }); - } - - let mut cycle_key_paths = Vec::new(); - for id in cycle_ref_ids { - let reference = &self.references[id]; - cycle_key_paths.push(reference.target.clone()); - cycle_key_paths.push(reference.source.clone()); - } - cycle_key_paths.sort_unstable(); - return Err(cycle_key_paths); - } - - Ok(results) - } - - fn populate_dependencies(&mut self, new_id: usize, new_reference: &KeyPathReference) { - self.dependency_counts.push(0); - - // If an existing reference's source path starts with the new reference's - // target path, then insert this new reference before that existing reference. - for id in Self::reference_ids_for_key_path( - &new_reference.target.0, - &self.references, - &self.reference_ids_by_source, - KeyPathReference::source, - KeyPath::starts_with, - ) { - Self::add_dependency( - (new_id, id), - &mut self.dependencies, - &mut self.dependency_counts, - ); - } - - // If an existing reference's target path starts with the new reference's - // source path, then insert this new reference after that existing reference. - for id in Self::reference_ids_for_key_path( - &new_reference.source.0, - &self.references, - &self.reference_ids_by_target, - KeyPathReference::target, - KeyPath::starts_with, - ) { - Self::add_dependency( - (id, new_id), - &mut self.dependencies, - &mut self.dependency_counts, - ); - } - - // If an existing reference's source path is a prefix of the new reference's - // target path, then insert this new reference before that existing reference. - for prefix in new_reference.target.prefixes() { - for id in Self::reference_ids_for_key_path( - prefix, - &self.references, - &self.reference_ids_by_source, - KeyPathReference::source, - PartialEq::eq, - ) { - Self::add_dependency( - (new_id, id), - &mut self.dependencies, - &mut self.dependency_counts, - ); - } - } - - // If an existing reference's target path is a prefix of the new reference's - // source path, then insert this new reference after that existing reference. - for prefix in new_reference.source.prefixes() { - for id in Self::reference_ids_for_key_path( - prefix, - &self.references, - &self.reference_ids_by_target, - KeyPathReference::target, - PartialEq::eq, - ) { - Self::add_dependency( - (id, new_id), - &mut self.dependencies, - &mut self.dependency_counts, - ); - } - } - } - - // Find all existing references that satisfy a given predicate with respect - // to a given key path. Use a sorted array of reference ids in order to avoid - // performing unnecessary comparisons. - fn reference_ids_for_key_path<'a>( - key_path: &[Key], - references: &[KeyPathReference], - sorted_reference_ids: &'a [usize], - reference_attribute: impl Fn(&KeyPathReference) -> &KeyPath, - predicate: impl Fn(&KeyPath, &[Key]) -> bool, - ) -> impl Iterator + 'a { - let ix = sorted_reference_ids - .binary_search_by_key(&key_path, |id| &reference_attribute(&references[*id]).0) - .unwrap_or_else(|i| i); - - let mut start_ix = ix; - while start_ix > 0 { - let reference_id = sorted_reference_ids[start_ix - 1]; - let reference = &references[reference_id]; - if !predicate(&reference_attribute(reference), key_path) { - break; - } - start_ix -= 1; - } - - let mut end_ix = ix; - while end_ix < sorted_reference_ids.len() { - let reference_id = sorted_reference_ids[end_ix]; - let reference = &references[reference_id]; - if !predicate(&reference_attribute(reference), key_path) { - break; - } - end_ix += 1; - } - - sorted_reference_ids[start_ix..end_ix].iter().copied() - } - - fn add_dependency( - (predecessor, successor): (usize, usize), - dependencies: &mut Vec<(usize, usize)>, - dependency_counts: &mut Vec, - ) { - let dependency = (predecessor, successor); - if let Err(i) = dependencies.binary_search(&dependency) { - dependencies.insert(i, dependency); - } - dependency_counts[successor] += 1; - } -} - -impl KeyPathReference { - fn source(&self) -> &KeyPath { - &self.source - } - - fn target(&self) -> &KeyPath { - &self.target - } -} - -impl KeyPath { - fn new(string: &str) -> Self { - Self( - string - .split(".") - .map(|key| Key::Object(key.to_string())) - .collect(), - ) - } - - fn starts_with(&self, other: &[Key]) -> bool { - self.0.starts_with(&other) - } - - fn prefixes(&self) -> impl Iterator { - (1..self.0.len()).map(move |end_ix| &self.0[0..end_ix]) - } -} - -impl PartialEq<[Key]> for KeyPath { - fn eq(&self, other: &[Key]) -> bool { - self.0.eq(other) - } -} - -impl fmt::Debug for KeyPathReference { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "KeyPathReference {{ {} <- {} }}", - self.target, self.source - ) - } -} - -impl fmt::Display for KeyPath { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - for (i, key) in self.0.iter().enumerate() { - match key { - Key::Array(index) => write!(f, "[{}]", index)?, - Key::Object(key) => { - if i > 0 { - ".".fmt(f)?; - } - key.fmt(f)?; - } - } - } - Ok(()) - } -} - fn deep_merge_json(base: &mut Map, extension: Map) { for (key, extension_value) in extension { if let Value::Object(extension_object) = extension_value { @@ -455,69 +119,12 @@ fn deep_merge_json(base: &mut Map, extension: Map) } } -fn find_references(value: &Value, key_path: &mut KeyPath, references: &mut KeyPathReferenceSet) { - match value { - Value::Array(vec) => { - for (ix, value) in vec.iter().enumerate() { - key_path.0.push(Key::Array(ix)); - find_references(value, key_path, references); - key_path.0.pop(); - } - } - Value::Object(map) => { - for (key, value) in map.iter() { - if key == "extends" { - if let Some(source_path) = value.as_str().and_then(|s| s.strip_prefix("$")) { - references.insert(KeyPathReference { - source: KeyPath::new(source_path), - target: key_path.clone(), - }); - } - } else { - key_path.0.push(Key::Object(key.to_string())); - find_references(value, key_path, references); - key_path.0.pop(); - } - } - } - Value::String(string) => { - if let Some(source_path) = string.strip_prefix("$") { - references.insert(KeyPathReference { - source: KeyPath::new(source_path), - target: key_path.clone(), - }); - } - } - _ => {} - } -} - -fn value_at<'a>(object: &'a mut Map, key_path: &KeyPath) -> Option<&'a mut Value> { - let mut key_path = key_path.0.iter(); - if let Some(Key::Object(first_key)) = key_path.next() { - let mut cur_value = object.get_mut(first_key); - for key in key_path { - if let Some(value) = cur_value { - match key { - Key::Array(ix) => cur_value = value.get_mut(ix), - Key::Object(key) => cur_value = value.get_mut(key), - } - } else { - return None; - } - } - cur_value - } else { - None - } -} - #[cfg(test)] mod tests { use super::*; use crate::{test::test_app_state, theme::DEFAULT_THEME_NAME}; + use anyhow::anyhow; use gpui::MutableAppContext; - use rand::{prelude::StdRng, Rng}; #[gpui::test] fn test_bundled_themes(cx: &mut MutableAppContext) { @@ -575,6 +182,7 @@ mod tests { let registry = ThemeRegistry::new(assets, cx.font_cache().clone()); let theme_data = registry.load("light", true).unwrap(); + assert_eq!( theme_data.as_ref(), &serde_json::json!({ @@ -619,120 +227,38 @@ mod tests { ); } - #[test] - fn test_key_path_reference_set_simple() { - let input_references = build_refs(&[ - ("r", "a"), - ("a.b.c", "d"), - ("d.e", "f"), - ("t.u", "v"), - ("v.w", "x"), - ("v.y", "x"), - ("d.h", "i"), - ("v.z", "x"), - ("f.g", "d.h"), - ]); - let expected_references = build_refs(&[ - ("d.h", "i"), - ("f.g", "d.h"), - ("d.e", "f"), - ("a.b.c", "d"), - ("r", "a"), - ("v.w", "x"), - ("v.y", "x"), - ("v.z", "x"), - ("t.u", "v"), - ]) - .collect::>(); - - let mut reference_set = KeyPathReferenceSet::default(); - for reference in input_references { - reference_set.insert(reference); - } - assert_eq!(reference_set.top_sort().unwrap(), expected_references); - } - - #[test] - fn test_key_path_reference_set_with_cycles() { - let input_references = build_refs(&[ - ("x", "a.b"), - ("y", "x.c"), - ("a.b.c", "d.e"), - ("d.e.f", "g.h"), - ("g.h.i", "a"), - ]); - - let mut reference_set = KeyPathReferenceSet::default(); - for reference in input_references { - reference_set.insert(reference); - } + #[gpui::test] + fn test_nested_extension(cx: &mut MutableAppContext) { + let assets = TestAssets(&[( + "themes/theme.toml", + r##" + [a] + text = { extends = "$text.0" } + + [b] + extends = "$a" + text = { extends = "$text.1" } + + [text] + 0 = { color = "red" } + 1 = { color = "blue" } + "##, + )]); + let registry = ThemeRegistry::new(assets, cx.font_cache().clone()); + let theme_data = registry.load("theme", true).unwrap(); assert_eq!( - reference_set.top_sort().unwrap_err(), - &[ - KeyPath::new("a"), - KeyPath::new("a.b.c"), - KeyPath::new("d.e"), - KeyPath::new("d.e.f"), - KeyPath::new("g.h"), - KeyPath::new("g.h.i"), - ] + theme_data + .get("b") + .unwrap() + .get("text") + .unwrap() + .get("color") + .unwrap(), + "blue" ); } - #[gpui::test(iterations = 20)] - async fn test_key_path_reference_set_random(mut rng: StdRng) { - let examples: &[&[_]] = &[ - &[ - ("n.d.h", "i"), - ("f.g", "n.d.h"), - ("n.d.e", "f"), - ("a.b.c", "n.d"), - ("r", "a"), - ("q.q.q", "r.s"), - ("r.t", "q"), - ("x.x", "r.r"), - ("v.w", "x"), - ("v.y", "x"), - ("v.z", "x"), - ("t.u", "v"), - ], - &[ - ("w.x.y.z", "t.u.z"), - ("x", "w.x"), - ("a.b.c1", "x.b1.c"), - ("a.b.c2", "x.b2.c"), - ], - &[ - ("x.y", "m.n.n.o.q"), - ("x.y.z", "m.n.n.o.p"), - ("u.v.w", "x.y.z"), - ("a.b.c.d", "u.v"), - ("a.b.c.d.e", "u.v"), - ("a.b.c.d.f", "u.v"), - ("a.b.c.d.g", "u.v"), - ], - ]; - - for example in examples { - let expected_references = build_refs(example).collect::>(); - let mut input_references = expected_references.clone(); - input_references.sort_by_key(|_| rng.gen_range(0..1000)); - let mut reference_set = KeyPathReferenceSet::default(); - for reference in input_references { - reference_set.insert(reference); - } - assert_eq!(reference_set.top_sort().unwrap(), expected_references); - } - } - - fn build_refs<'a>(rows: &'a [(&str, &str)]) -> impl Iterator + 'a { - rows.iter().map(|(target, source)| KeyPathReference { - target: KeyPath::new(target), - source: KeyPath::new(source), - }) - } - struct TestAssets(&'static [(&'static str, &'static str)]); impl AssetSource for TestAssets { diff --git a/zed/src/user.rs b/zed/src/user.rs index 06aab321934dd1ffb985430157b7cfd02385dbc7..54e84d756ff81229a44dcb5291e12fec1618da27 100644 --- a/zed/src/user.rs +++ b/zed/src/user.rs @@ -111,8 +111,12 @@ impl UserStore { .ok_or_else(|| anyhow!("server responded with no users")) } - pub fn current_user(&self) -> &watch::Receiver>> { - &self.current_user + pub fn current_user(&self) -> Option> { + self.current_user.borrow().clone() + } + + pub fn watch_current_user(&self) -> watch::Receiver>> { + self.current_user.clone() } } diff --git a/zed/src/workspace.rs b/zed/src/workspace.rs index 9ce67c2f8a678eb344b91118cde78a5bd0000d3d..ff3666e0de077cc8667716482f2479ba220e0825 100644 --- a/zed/src/workspace.rs +++ b/zed/src/workspace.rs @@ -389,7 +389,7 @@ impl Workspace { ); right_sidebar.add_item("icons/user-16.svg", cx.add_view(|_| ProjectBrowser).into()); - let mut current_user = app_state.user_store.current_user().clone(); + let mut current_user = app_state.user_store.watch_current_user().clone(); let mut connection_status = app_state.rpc.status().clone(); let _observe_current_user = cx.spawn_weak(|this, mut cx| async move { current_user.recv().await; @@ -990,8 +990,6 @@ impl Workspace { let avatar = if let Some(avatar) = self .user_store .current_user() - .borrow() - .as_ref() .and_then(|user| user.avatar.clone()) { Image::new(avatar) diff --git a/zrpc/proto/zed.proto b/zrpc/proto/zed.proto index c9f1dc0f80dbddb01d37769a2cac35d11d455d30..4e42441eb276ad36e71938946f7c229cc9799e5f 100644 --- a/zrpc/proto/zed.proto +++ b/zrpc/proto/zed.proto @@ -151,6 +151,7 @@ message GetUsersResponse { message SendChannelMessage { uint64 channel_id = 1; string body = 2; + Nonce nonce = 3; } message SendChannelMessageResponse { @@ -296,6 +297,11 @@ message Range { uint64 end = 2; } +message Nonce { + uint64 upper_half = 1; + uint64 lower_half = 2; +} + message Channel { uint64 id = 1; string name = 2; @@ -306,4 +312,5 @@ message ChannelMessage { string body = 2; uint64 timestamp = 3; uint64 sender_id = 4; + Nonce nonce = 5; } diff --git a/zrpc/src/proto.rs b/zrpc/src/proto.rs index af9dbf3abcdf070d757635edc77bc4ebc78ed200..b2d4de3bbf501c2ce5c7e28fb0c7f7355171a790 100644 --- a/zrpc/src/proto.rs +++ b/zrpc/src/proto.rs @@ -248,3 +248,22 @@ impl From for Timestamp { } } } + +impl From for Nonce { + fn from(nonce: u128) -> Self { + let upper_half = (nonce >> 64) as u64; + let lower_half = nonce as u64; + Self { + upper_half, + lower_half, + } + } +} + +impl From for u128 { + fn from(nonce: Nonce) -> Self { + let upper_half = (nonce.upper_half as u128) << 64; + let lower_half = nonce.lower_half as u128; + upper_half | lower_half + } +}