Associate messages with their sender, fetching senders if necessary

Nathan Sobo and Max Brunsfeld created

Co-Authored-By: Max Brunsfeld <maxbrunsfeld@gmail.com>

Change summary

server/src/db.rs   |   2 
server/src/rpc.rs  |  21 +++--
zed/src/channel.rs | 160 ++++++++++++++++++++++++++++++++++++++---------
zed/src/lib.rs     |   1 
zed/src/main.rs    |   4 
zed/src/test.rs    |   4 
zed/src/user.rs    |  59 +++++++++++++++++
7 files changed, 207 insertions(+), 44 deletions(-)

Detailed changes

server/src/db.rs 🔗

@@ -162,7 +162,7 @@ impl Db {
                 FROM
                     users, channel_memberships
                 WHERE
-                    users.id IN $1 AND
+                    users.id = ANY ($1) AND
                     channel_memberships.user_id = users.id AND
                     channel_memberships.channel_id IN (
                         SELECT channel_id

server/src/rpc.rs 🔗

@@ -939,6 +939,7 @@ mod tests {
         language::LanguageRegistry,
         rpc::Client,
         settings, test,
+        user::UserStore,
         worktree::Worktree,
     };
     use zrpc::Peer;
@@ -1425,7 +1426,8 @@ mod tests {
         .await
         .unwrap();
 
-        let channels_a = cx_a.add_model(|cx| ChannelList::new(client_a, cx));
+        let user_store_a = Arc::new(UserStore::new(client_a.clone()));
+        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
         channels_a
             .condition(&mut cx_a, |list, _| list.available_channels().is_some())
             .await;
@@ -1445,11 +1447,12 @@ mod tests {
         channel_a
             .condition(&cx_a, |channel, _| {
                 channel_messages(channel)
-                    == [(user_id_b.to_proto(), "hello A, it's B.".to_string())]
+                    == [("user_b".to_string(), "hello A, it's B.".to_string())]
             })
             .await;
 
-        let channels_b = cx_b.add_model(|cx| ChannelList::new(client_b, cx));
+        let user_store_b = Arc::new(UserStore::new(client_b.clone()));
+        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
         channels_b
             .condition(&mut cx_b, |list, _| list.available_channels().is_some())
             .await;
@@ -1470,7 +1473,7 @@ mod tests {
         channel_b
             .condition(&cx_b, |channel, _| {
                 channel_messages(channel)
-                    == [(user_id_b.to_proto(), "hello A, it's B.".to_string())]
+                    == [("user_b".to_string(), "hello A, it's B.".to_string())]
             })
             .await;
 
@@ -1494,9 +1497,9 @@ mod tests {
             .condition(&cx_b, |channel, _| {
                 channel_messages(channel)
                     == [
-                        (user_id_b.to_proto(), "hello A, it's B.".to_string()),
-                        (user_id_a.to_proto(), "oh, hi B.".to_string()),
-                        (user_id_a.to_proto(), "sup".to_string()),
+                        ("user_b".to_string(), "hello A, it's B.".to_string()),
+                        ("user_a".to_string(), "oh, hi B.".to_string()),
+                        ("user_a".to_string(), "sup".to_string()),
                     ]
             })
             .await;
@@ -1517,11 +1520,11 @@ mod tests {
             .condition(|state| !state.channels.contains_key(&channel_id))
             .await;
 
-        fn channel_messages(channel: &Channel) -> Vec<(u64, String)> {
+        fn channel_messages(channel: &Channel) -> Vec<(String, String)> {
             channel
                 .messages()
                 .cursor::<(), ()>()
-                .map(|m| (m.sender_id, m.body.clone()))
+                .map(|m| (m.sender.github_login.clone(), m.body.clone()))
                 .collect()
         }
     }

zed/src/channel.rs 🔗

@@ -1,5 +1,6 @@
 use crate::{
     rpc::{self, Client},
+    user::{User, UserStore},
     util::TryFutureExt,
 };
 use anyhow::{anyhow, Context, Result};
@@ -9,7 +10,7 @@ use gpui::{
 };
 use postage::prelude::Stream;
 use std::{
-    collections::{hash_map, HashMap},
+    collections::{hash_map, HashMap, HashSet},
     ops::Range,
     sync::Arc,
 };
@@ -22,6 +23,7 @@ pub struct ChannelList {
     available_channels: Option<Vec<ChannelDetails>>,
     channels: HashMap<u64, WeakModelHandle<Channel>>,
     rpc: Arc<Client>,
+    user_store: Arc<UserStore>,
     _task: Task<Option<()>>,
 }
 
@@ -36,6 +38,7 @@ pub struct Channel {
     messages: SumTree<ChannelMessage>,
     pending_messages: Vec<PendingChannelMessage>,
     next_local_message_id: u64,
+    user_store: Arc<UserStore>,
     rpc: Arc<Client>,
     _subscription: rpc::Subscription,
 }
@@ -43,8 +46,8 @@ pub struct Channel {
 #[derive(Clone, Debug, PartialEq)]
 pub struct ChannelMessage {
     pub id: u64,
-    pub sender_id: u64,
     pub body: String,
+    pub sender: Arc<User>,
 }
 
 pub struct PendingChannelMessage {
@@ -76,7 +79,11 @@ impl Entity for ChannelList {
 }
 
 impl ChannelList {
-    pub fn new(rpc: Arc<rpc::Client>, cx: &mut ModelContext<Self>) -> Self {
+    pub fn new(
+        user_store: Arc<UserStore>,
+        rpc: Arc<rpc::Client>,
+        cx: &mut ModelContext<Self>,
+    ) -> Self {
         let _task = cx.spawn(|this, mut cx| {
             let rpc = rpc.clone();
             async move {
@@ -114,6 +121,7 @@ impl ChannelList {
         Self {
             available_channels: None,
             channels: Default::default(),
+            user_store,
             rpc,
             _task,
         }
@@ -136,8 +144,10 @@ impl ChannelList {
                     .as_ref()
                     .and_then(|channels| channels.iter().find(|details| details.id == id))
                 {
+                    let user_store = self.user_store.clone();
                     let rpc = self.rpc.clone();
-                    let channel = cx.add_model(|cx| Channel::new(details.clone(), rpc, cx));
+                    let channel =
+                        cx.add_model(|cx| Channel::new(details.clone(), user_store, rpc, cx));
                     entry.insert(channel.downgrade());
                     Some(channel)
                 } else {
@@ -165,34 +175,58 @@ impl Entity for Channel {
 }
 
 impl Channel {
-    pub fn new(details: ChannelDetails, rpc: Arc<Client>, cx: &mut ModelContext<Self>) -> Self {
+    pub fn new(
+        details: ChannelDetails,
+        user_store: Arc<UserStore>,
+        rpc: Arc<Client>,
+        cx: &mut ModelContext<Self>,
+    ) -> Self {
         let _subscription = rpc.subscribe_from_model(details.id, cx, Self::handle_message_sent);
 
         {
+            let user_store = user_store.clone();
             let rpc = rpc.clone();
             let channel_id = details.id;
-            cx.spawn(|channel, mut cx| async move {
-                match rpc.request(proto::JoinChannel { channel_id }).await {
-                    Ok(response) => channel.update(&mut cx, |channel, cx| {
+            cx.spawn(|channel, mut cx| {
+                async move {
+                    let response = rpc.request(proto::JoinChannel { channel_id }).await?;
+
+                    let unique_user_ids = response
+                        .messages
+                        .iter()
+                        .map(|m| m.sender_id)
+                        .collect::<HashSet<_>>()
+                        .into_iter()
+                        .collect();
+                    user_store.load_users(unique_user_ids).await?;
+
+                    let mut messages = Vec::with_capacity(response.messages.len());
+                    for message in response.messages {
+                        messages.push(ChannelMessage::from_proto(message, &user_store).await?);
+                    }
+
+                    channel.update(&mut cx, |channel, cx| {
                         let old_count = channel.messages.summary().count.0;
-                        let new_count = response.messages.len();
+                        let new_count = messages.len();
+
                         channel.messages = SumTree::new();
-                        channel
-                            .messages
-                            .extend(response.messages.into_iter().map(Into::into), &());
+                        channel.messages.extend(messages, &());
                         cx.emit(ChannelEvent::Message {
                             old_range: 0..old_count,
                             new_count,
                         });
-                    }),
-                    Err(error) => log::error!("error joining channel: {}", error),
+                    });
+
+                    Ok(())
                 }
+                .log_err()
             })
             .detach();
         }
 
         Self {
             details,
+            user_store,
             rpc,
             messages: Default::default(),
             pending_messages: Default::default(),
@@ -210,11 +244,14 @@ impl Channel {
             local_id,
             body: body.clone(),
         });
+        let user_store = self.user_store.clone();
         let rpc = self.rpc.clone();
         cx.spawn(|this, mut cx| {
             async move {
                 let request = rpc.request(proto::SendChannelMessage { channel_id, body });
                 let response = request.await?;
+                let sender = user_store.get_user(current_user_id).await?;
+
                 this.update(&mut cx, |this, cx| {
                     if let Ok(i) = this
                         .pending_messages
@@ -224,8 +261,8 @@ impl Channel {
                         this.insert_message(
                             ChannelMessage {
                                 id: response.message_id,
-                                sender_id: current_user_id,
                                 body,
+                                sender,
                             },
                             cx,
                         );
@@ -267,11 +304,21 @@ impl Channel {
         _: Arc<rpc::Client>,
         cx: &mut ModelContext<Self>,
     ) -> Result<()> {
+        let user_store = self.user_store.clone();
         let message = message
             .payload
             .message
             .ok_or_else(|| anyhow!("empty message"))?;
-        self.insert_message(message.into(), cx);
+
+        cx.spawn(|this, mut cx| {
+            async move {
+                let message = ChannelMessage::from_proto(message, &user_store).await?;
+                this.update(&mut cx, |this, cx| this.insert_message(message, cx));
+                Ok(())
+            }
+            .log_err()
+        })
+        .detach();
         Ok(())
     }
 
@@ -307,13 +354,17 @@ impl From<proto::Channel> for ChannelDetails {
     }
 }
 
-impl From<proto::ChannelMessage> for ChannelMessage {
-    fn from(message: proto::ChannelMessage) -> Self {
-        ChannelMessage {
+impl ChannelMessage {
+    pub async fn from_proto(
+        message: proto::ChannelMessage,
+        user_store: &UserStore,
+    ) -> Result<Self> {
+        let sender = user_store.get_user(message.sender_id).await?;
+        Ok(ChannelMessage {
             id: message.id,
-            sender_id: message.sender_id,
             body: message.body,
-        }
+            sender,
+        })
     }
 }
 
@@ -368,15 +419,16 @@ mod tests {
         let user_id = 5;
         let client = Client::new();
         let mut server = FakeServer::for_client(user_id, &client, &cx).await;
+        let user_store = Arc::new(UserStore::new(client.clone()));
 
-        let channel_list = cx.add_model(|cx| ChannelList::new(client.clone(), cx));
+        let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx));
         channel_list.read_with(&cx, |list, _| assert_eq!(list.available_channels(), None));
 
         // Get the available channels.
-        let message = server.receive::<proto::GetChannels>().await;
+        let get_channels = server.receive::<proto::GetChannels>().await;
         server
             .respond(
-                message.receipt(),
+                get_channels.receipt(),
                 proto::GetChannelsResponse {
                     channels: vec![proto::Channel {
                         id: 5,
@@ -404,10 +456,10 @@ mod tests {
             })
             .unwrap();
         channel.read_with(&cx, |channel, _| assert!(channel.messages().is_empty()));
-        let message = server.receive::<proto::JoinChannel>().await;
+        let join_channel = server.receive::<proto::JoinChannel>().await;
         server
             .respond(
-                message.receipt(),
+                join_channel.receipt(),
                 proto::JoinChannelResponse {
                     messages: vec![
                         proto::ChannelMessage {
@@ -420,12 +472,36 @@ mod tests {
                             id: 11,
                             body: "b".into(),
                             timestamp: 1001,
-                            sender_id: 5,
+                            sender_id: 6,
+                        },
+                    ],
+                },
+            )
+            .await;
+        // Client requests all users for the received messages
+        let mut get_users = server.receive::<proto::GetUsers>().await;
+        get_users.payload.user_ids.sort();
+        assert_eq!(get_users.payload.user_ids, vec![5, 6]);
+        server
+            .respond(
+                get_users.receipt(),
+                proto::GetUsersResponse {
+                    users: vec![
+                        proto::User {
+                            id: 5,
+                            github_login: "nathansobo".into(),
+                            avatar_url: "http://avatar.com/nathansobo".into(),
+                        },
+                        proto::User {
+                            id: 6,
+                            github_login: "maxbrunsfeld".into(),
+                            avatar_url: "http://avatar.com/maxbrunsfeld".into(),
                         },
                     ],
                 },
             )
             .await;
+
         assert_eq!(
             channel.next_event(&cx).await,
             ChannelEvent::Message {
@@ -437,9 +513,12 @@ mod tests {
             assert_eq!(
                 channel
                     .messages_in_range(0..2)
-                    .map(|message| &message.body)
+                    .map(|message| (message.sender.github_login.clone(), message.body.clone()))
                     .collect::<Vec<_>>(),
-                &["a", "b"]
+                &[
+                    ("nathansobo".into(), "a".into()),
+                    ("maxbrunsfeld".into(), "b".into())
+                ]
             );
         });
 
@@ -451,10 +530,27 @@ mod tests {
                     id: 12,
                     body: "c".into(),
                     timestamp: 1002,
-                    sender_id: 5,
+                    sender_id: 7,
                 }),
             })
             .await;
+
+        // Client requests user for message since they haven't seen them yet
+        let get_users = server.receive::<proto::GetUsers>().await;
+        assert_eq!(get_users.payload.user_ids, vec![7]);
+        server
+            .respond(
+                get_users.receipt(),
+                proto::GetUsersResponse {
+                    users: vec![proto::User {
+                        id: 7,
+                        github_login: "as-cii".into(),
+                        avatar_url: "http://avatar.com/as-cii".into(),
+                    }],
+                },
+            )
+            .await;
+
         assert_eq!(
             channel.next_event(&cx).await,
             ChannelEvent::Message {
@@ -466,9 +562,9 @@ mod tests {
             assert_eq!(
                 channel
                     .messages_in_range(2..3)
-                    .map(|message| &message.body)
+                    .map(|message| (message.sender.github_login.clone(), message.body.clone()))
                     .collect::<Vec<_>>(),
-                &["c"]
+                &[("as-cii".into(), "c".into())]
             )
         })
     }

zed/src/lib.rs 🔗

@@ -15,6 +15,7 @@ pub mod test;
 pub mod theme;
 pub mod theme_selector;
 mod time;
+pub mod user;
 mod util;
 pub mod workspace;
 pub mod worktree;

zed/src/main.rs 🔗

@@ -12,6 +12,7 @@ use zed::{
     chat_panel, editor, file_finder,
     fs::RealFs,
     language, menus, rpc, settings, theme_selector,
+    user::UserStore,
     workspace::{self, OpenParams, OpenPaths},
     AppState,
 };
@@ -29,12 +30,13 @@ fn main() {
 
     app.run(move |cx| {
         let rpc = rpc::Client::new();
+        let user_store = Arc::new(UserStore::new(rpc.clone()));
         let app_state = Arc::new(AppState {
             languages: languages.clone(),
             settings_tx: Arc::new(Mutex::new(settings_tx)),
             settings,
             themes,
-            channel_list: cx.add_model(|cx| ChannelList::new(rpc.clone(), cx)),
+            channel_list: cx.add_model(|cx| ChannelList::new(user_store, rpc.clone(), cx)),
             rpc,
             fs: Arc::new(RealFs),
         });

zed/src/test.rs 🔗

@@ -5,6 +5,7 @@ use crate::{
     rpc,
     settings::{self, ThemeRegistry},
     time::ReplicaId,
+    user::UserStore,
     AppState, Settings,
 };
 use gpui::{AppContext, Entity, ModelHandle, MutableAppContext};
@@ -164,12 +165,13 @@ pub fn build_app_state(cx: &mut MutableAppContext) -> Arc<AppState> {
     let languages = Arc::new(LanguageRegistry::new());
     let themes = ThemeRegistry::new(());
     let rpc = rpc::Client::new();
+    let user_store = Arc::new(UserStore::new(rpc.clone()));
     Arc::new(AppState {
         settings_tx: Arc::new(Mutex::new(settings_tx)),
         settings,
         themes,
         languages: languages.clone(),
-        channel_list: cx.add_model(|cx| ChannelList::new(rpc.clone(), cx)),
+        channel_list: cx.add_model(|cx| ChannelList::new(user_store, rpc.clone(), cx)),
         rpc,
         fs: Arc::new(RealFs),
     })

zed/src/user.rs 🔗

@@ -0,0 +1,59 @@
+use crate::rpc::Client;
+use anyhow::{anyhow, Result};
+use parking_lot::Mutex;
+use std::{collections::HashMap, sync::Arc};
+use zrpc::proto;
+
+pub use proto::User;
+
+pub struct UserStore {
+    users: Mutex<HashMap<u64, Arc<User>>>,
+    rpc: Arc<Client>,
+}
+
+impl UserStore {
+    pub fn new(rpc: Arc<Client>) -> Self {
+        Self {
+            users: Default::default(),
+            rpc,
+        }
+    }
+
+    pub async fn load_users(&self, mut user_ids: Vec<u64>) -> Result<()> {
+        {
+            let users = self.users.lock();
+            user_ids.retain(|id| !users.contains_key(id));
+        }
+
+        if !user_ids.is_empty() {
+            let response = self.rpc.request(proto::GetUsers { user_ids }).await?;
+            let mut users = self.users.lock();
+            for user in response.users {
+                users.insert(user.id, Arc::new(user));
+            }
+        }
+
+        Ok(())
+    }
+
+    pub async fn get_user(&self, user_id: u64) -> Result<Arc<User>> {
+        if let Some(user) = self.users.lock().get(&user_id).cloned() {
+            return Ok(user);
+        }
+
+        let response = self
+            .rpc
+            .request(proto::GetUsers {
+                user_ids: vec![user_id],
+            })
+            .await?;
+
+        if let Some(user) = response.users.into_iter().next() {
+            let user = Arc::new(user);
+            self.users.lock().insert(user_id, user.clone());
+            Ok(user)
+        } else {
+            Err(anyhow!("server responded with no users"))
+        }
+    }
+}