Implement initial RPC endpoints for chat

Max Brunsfeld and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

Cargo.lock           |   2 
server/Cargo.toml    |   3 
server/src/auth.rs   |   6 -
server/src/db.rs     | 125 ++++++++++++++++++++++-----------
server/src/rpc.rs    | 165 +++++++++++++++++++++++++++++++++++++++------
server/src/tests.rs  |  30 ++++++-
zrpc/proto/zed.proto |   8 +-
7 files changed, 258 insertions(+), 81 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -4699,6 +4699,7 @@ dependencies = [
  "sqlx-rt 0.5.5",
  "stringprep",
  "thiserror",
+ "time 0.2.25",
  "url",
  "webpki",
  "webpki-roots",
@@ -5866,6 +5867,7 @@ dependencies = [
  "surf",
  "tide",
  "tide-compress",
+ "time 0.2.25",
  "toml 0.5.8",
  "zed",
  "zrpc",

server/Cargo.toml 🔗

@@ -31,6 +31,7 @@ sha-1 = "0.9"
 surf = "2.2.0"
 tide = "0.16.0"
 tide-compress = "0.9.0"
+time = "0.2"
 toml = "0.5.8"
 zrpc = { path = "../zrpc" }
 
@@ -41,7 +42,7 @@ default-features = false
 
 [dependencies.sqlx]
 version = "0.5.2"
-features = ["runtime-async-std-rustls", "postgres"]
+features = ["runtime-async-std-rustls", "postgres", "time"]
 
 [dev-dependencies]
 gpui = { path = "../gpui" }

server/src/auth.rs 🔗

@@ -257,7 +257,7 @@ async fn get_auth_callback(mut request: Request) -> tide::Result {
     // When signing in from the native app, generate a new access token for the current user. Return
     // a redirect so that the user's browser sends this access token to the locally-running app.
     if let Some((user, app_sign_in_params)) = user.zip(query.native_app_sign_in_params) {
-        let access_token = create_access_token(request.db(), user.id()).await?;
+        let access_token = create_access_token(request.db(), user.id).await?;
         let native_app_public_key =
             zed_auth::PublicKey::try_from(app_sign_in_params.native_app_public_key.clone())
                 .context("failed to parse app public key")?;
@@ -267,9 +267,7 @@ async fn get_auth_callback(mut request: Request) -> tide::Result {
 
         return Ok(tide::Redirect::new(&format!(
             "http://127.0.0.1:{}?user_id={}&access_token={}",
-            app_sign_in_params.native_app_port,
-            user.id().0,
-            encrypted_access_token,
+            app_sign_in_params.native_app_port, user.id.0, encrypted_access_token,
         ))
         .into());
     }

server/src/db.rs 🔗

@@ -1,5 +1,6 @@
 use serde::Serialize;
 use sqlx::{FromRow, Result};
+use time::OffsetDateTime;
 
 pub use async_sqlx_session::PostgresSessionStore as SessionStore;
 pub use sqlx::postgres::PgPoolOptions as DbOptions;
@@ -8,14 +9,14 @@ pub struct Db(pub sqlx::PgPool);
 
 #[derive(Debug, FromRow, Serialize)]
 pub struct User {
-    id: i32,
+    pub id: UserId,
     pub github_login: String,
     pub admin: bool,
 }
 
 #[derive(Debug, FromRow, Serialize)]
 pub struct Signup {
-    id: i32,
+    pub id: SignupId,
     pub github_login: String,
     pub email_address: String,
     pub about: String,
@@ -23,33 +24,18 @@ pub struct Signup {
 
 #[derive(Debug, FromRow, Serialize)]
 pub struct Channel {
-    id: i32,
+    pub id: ChannelId,
     pub name: String,
 }
 
 #[derive(Debug, FromRow)]
 pub struct ChannelMessage {
-    id: i32,
-    sender_id: i32,
-    body: String,
-    sent_at: i64,
+    pub id: MessageId,
+    pub sender_id: UserId,
+    pub body: String,
+    pub sent_at: OffsetDateTime,
 }
 
-#[derive(Clone, Copy)]
-pub struct UserId(pub i32);
-
-#[derive(Clone, Copy)]
-pub struct OrgId(pub i32);
-
-#[derive(Clone, Copy)]
-pub struct ChannelId(pub i32);
-
-#[derive(Clone, Copy)]
-pub struct SignupId(pub i32);
-
-#[derive(Clone, Copy)]
-pub struct MessageId(pub i32);
-
 impl Db {
     // signups
 
@@ -108,6 +94,33 @@ impl Db {
         sqlx::query_as(query).fetch_all(&self.0).await
     }
 
+    pub async fn get_users_by_ids(
+        &self,
+        requester_id: UserId,
+        ids: impl Iterator<Item = UserId>,
+    ) -> Result<Vec<User>> {
+        // Only return users that are in a common channel with the requesting user.
+        let query = "
+            SELECT users.*
+            FROM
+                users, channel_memberships
+            WHERE
+                users.id IN $1 AND
+                channel_memberships.user_id = users.id AND
+                channel_memberships.channel_id IN (
+                    SELECT channel_id
+                    FROM channel_memberships
+                    WHERE channel_memberships.user_id = $2
+                )
+        ";
+
+        sqlx::query_as(query)
+            .bind(&ids.map(|id| id.0).collect::<Vec<_>>())
+            .bind(requester_id)
+            .fetch_all(&self.0)
+            .await
+    }
+
     pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
         let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
         sqlx::query_as(query)
@@ -147,7 +160,7 @@ impl Db {
             VALUES ($1, $2)
         ";
         sqlx::query(query)
-            .bind(user_id.0 as i32)
+            .bind(user_id.0)
             .bind(access_token_hash)
             .execute(&self.0)
             .await
@@ -156,8 +169,8 @@ impl Db {
 
     pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
         let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
-        sqlx::query_scalar::<_, String>(query)
-            .bind(user_id.0 as i32)
+        sqlx::query_scalar(query)
+            .bind(user_id.0)
             .fetch_all(&self.0)
             .await
     }
@@ -180,14 +193,20 @@ impl Db {
     }
 
     #[cfg(test)]
-    pub async fn add_org_member(&self, org_id: OrgId, user_id: UserId) -> Result<()> {
+    pub async fn add_org_member(
+        &self,
+        org_id: OrgId,
+        user_id: UserId,
+        is_admin: bool,
+    ) -> Result<()> {
         let query = "
-            INSERT INTO org_memberships (org_id, user_id)
-            VALUES ($1, $2)
+            INSERT INTO org_memberships (org_id, user_id, admin)
+            VALUES ($1, $2, $3)
         ";
         sqlx::query(query)
             .bind(org_id.0)
             .bind(user_id.0)
+            .bind(is_admin)
             .execute(&self.0)
             .await
             .map(drop)
@@ -272,16 +291,18 @@ impl Db {
         channel_id: ChannelId,
         sender_id: UserId,
         body: &str,
+        timestamp: OffsetDateTime,
     ) -> Result<MessageId> {
         let query = "
             INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
-            VALUES ($1, $2, $3, NOW()::timestamp)
+            VALUES ($1, $2, $3, $4)
             RETURNING id
         ";
         sqlx::query_scalar(query)
             .bind(channel_id.0)
             .bind(sender_id.0)
             .bind(body)
+            .bind(timestamp)
             .fetch_one(&self.0)
             .await
             .map(MessageId)
@@ -292,12 +313,15 @@ impl Db {
         channel_id: ChannelId,
         count: usize,
     ) -> Result<Vec<ChannelMessage>> {
-        let query = "
-            SELECT id, sender_id, body, sent_at
-            FROM channel_messages
-            WHERE channel_id = $1
+        let query = r#"
+            SELECT
+                id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at
+            FROM
+                channel_messages
+            WHERE
+                channel_id = $1
             LIMIT $2
-        ";
+        "#;
         sqlx::query_as(query)
             .bind(channel_id.0)
             .bind(count as i64)
@@ -314,14 +338,29 @@ impl std::ops::Deref for Db {
     }
 }
 
-impl Channel {
-    pub fn id(&self) -> ChannelId {
-        ChannelId(self.id)
-    }
+macro_rules! id_type {
+    ($name:ident) => {
+        #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)]
+        #[sqlx(transparent)]
+        #[serde(transparent)]
+        pub struct $name(pub i32);
+
+        impl $name {
+            #[allow(unused)]
+            pub fn from_proto(value: u64) -> Self {
+                Self(value as i32)
+            }
+
+            #[allow(unused)]
+            pub fn to_proto(&self) -> u64 {
+                self.0 as u64
+            }
+        }
+    };
 }
 
-impl User {
-    pub fn id(&self) -> UserId {
-        UserId(self.id)
-    }
-}
+id_type!(UserId);
+id_type!(OrgId);
+id_type!(ChannelId);
+id_type!(SignupId);
+id_type!(MessageId);

server/src/rpc.rs 🔗

@@ -23,6 +23,7 @@ use tide::{
     http::headers::{HeaderName, CONNECTION, UPGRADE},
     Request, Response,
 };
+use time::OffsetDateTime;
 use zrpc::{
     auth::random_token,
     proto::{self, EnvelopedMessage},
@@ -33,17 +34,19 @@ type ReplicaId = u16;
 
 #[derive(Default)]
 pub struct State {
-    connections: HashMap<ConnectionId, ConnectionState>,
-    pub worktrees: HashMap<u64, WorktreeState>,
+    connections: HashMap<ConnectionId, Connection>,
+    pub worktrees: HashMap<u64, Worktree>,
+    channels: HashMap<ChannelId, Channel>,
     next_worktree_id: u64,
 }
 
-struct ConnectionState {
+struct Connection {
     user_id: UserId,
     worktrees: HashSet<u64>,
+    channels: HashSet<ChannelId>,
 }
 
-pub struct WorktreeState {
+pub struct Worktree {
     host_connection_id: Option<ConnectionId>,
     guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
     active_replica_ids: HashSet<ReplicaId>,
@@ -52,7 +55,12 @@ pub struct WorktreeState {
     entries: HashMap<u64, proto::Entry>,
 }
 
-impl WorktreeState {
+#[derive(Default)]
+struct Channel {
+    connection_ids: HashSet<ConnectionId>,
+}
+
+impl Worktree {
     pub fn connection_ids(&self) -> Vec<ConnectionId> {
         self.guest_connection_ids
             .keys()
@@ -68,14 +76,21 @@ impl WorktreeState {
     }
 }
 
+impl Channel {
+    fn connection_ids(&self) -> Vec<ConnectionId> {
+        self.connection_ids.iter().copied().collect()
+    }
+}
+
 impl State {
     // Add a new connection associated with a given user.
     pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
         self.connections.insert(
             connection_id,
-            ConnectionState {
+            Connection {
                 user_id,
                 worktrees: Default::default(),
+                channels: Default::default(),
             },
         );
     }
@@ -83,8 +98,13 @@ impl State {
     // Remove the given connection and its association with any worktrees.
     pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Vec<u64> {
         let mut worktree_ids = Vec::new();
-        if let Some(connection_state) = self.connections.remove(&connection_id) {
-            for worktree_id in connection_state.worktrees {
+        if let Some(connection) = self.connections.remove(&connection_id) {
+            for channel_id in connection.channels {
+                if let Some(channel) = self.channels.get_mut(&channel_id) {
+                    channel.connection_ids.remove(&connection_id);
+                }
+            }
+            for worktree_id in connection.worktrees {
                 if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
                     if worktree.host_connection_id == Some(connection_id) {
                         worktree_ids.push(worktree_id);
@@ -100,28 +120,39 @@ impl State {
         worktree_ids
     }
 
+    fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
+        if let Some(connection) = self.connections.get_mut(&connection_id) {
+            connection.channels.insert(channel_id);
+            self.channels
+                .entry(channel_id)
+                .or_default()
+                .connection_ids
+                .insert(connection_id);
+        }
+    }
+
     // Add the given connection as a guest of the given worktree
     pub fn join_worktree(
         &mut self,
         connection_id: ConnectionId,
         worktree_id: u64,
         access_token: &str,
-    ) -> Option<(ReplicaId, &WorktreeState)> {
-        if let Some(worktree_state) = self.worktrees.get_mut(&worktree_id) {
-            if access_token == worktree_state.access_token {
-                if let Some(connection_state) = self.connections.get_mut(&connection_id) {
-                    connection_state.worktrees.insert(worktree_id);
+    ) -> Option<(ReplicaId, &Worktree)> {
+        if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
+            if access_token == worktree.access_token {
+                if let Some(connection) = self.connections.get_mut(&connection_id) {
+                    connection.worktrees.insert(worktree_id);
                 }
 
                 let mut replica_id = 1;
-                while worktree_state.active_replica_ids.contains(&replica_id) {
+                while worktree.active_replica_ids.contains(&replica_id) {
                     replica_id += 1;
                 }
-                worktree_state.active_replica_ids.insert(replica_id);
-                worktree_state
+                worktree.active_replica_ids.insert(replica_id);
+                worktree
                     .guest_connection_ids
                     .insert(connection_id, replica_id);
-                Some((replica_id, worktree_state))
+                Some((replica_id, worktree))
             } else {
                 None
             }
@@ -142,7 +173,7 @@ impl State {
         &self,
         worktree_id: u64,
         connection_id: ConnectionId,
-    ) -> tide::Result<&WorktreeState> {
+    ) -> tide::Result<&Worktree> {
         let worktree = self
             .worktrees
             .get(&worktree_id)
@@ -165,7 +196,7 @@ impl State {
         &mut self,
         worktree_id: u64,
         connection_id: ConnectionId,
-    ) -> tide::Result<&mut WorktreeState> {
+    ) -> tide::Result<&mut Worktree> {
         let worktree = self
             .worktrees
             .get_mut(&worktree_id)
@@ -263,7 +294,9 @@ pub fn add_rpc_routes(router: &mut Router, state: &Arc<AppState>, rpc: &Arc<Peer
     on_message(router, rpc, state, buffer_saved);
     on_message(router, rpc, state, save_buffer);
     on_message(router, rpc, state, get_channels);
+    on_message(router, rpc, state, get_users);
     on_message(router, rpc, state, join_channel);
+    on_message(router, rpc, state, send_channel_message);
 }
 
 pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
@@ -373,7 +406,7 @@ async fn share_worktree(
         .collect();
     state.worktrees.insert(
         worktree_id,
-        WorktreeState {
+        Worktree {
             host_connection_id: Some(request.sender_id),
             guest_connection_ids: Default::default(),
             active_replica_ids: Default::default(),
@@ -627,7 +660,7 @@ async fn get_channels(
             channels: channels
                 .into_iter()
                 .map(|chan| proto::Channel {
-                    id: chan.id().0 as u64,
+                    id: chan.id.to_proto(),
                     name: chan.name,
                 })
                 .collect(),
@@ -637,6 +670,34 @@ async fn get_channels(
     Ok(())
 }
 
+async fn get_users(
+    request: TypedEnvelope<proto::GetUsers>,
+    rpc: &Arc<Peer>,
+    state: &Arc<AppState>,
+) -> tide::Result<()> {
+    let user_id = state
+        .rpc
+        .read()
+        .await
+        .user_id_for_connection(request.sender_id)?;
+    let receipt = request.receipt();
+    let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
+    let users = state
+        .db
+        .get_users_by_ids(user_id, user_ids)
+        .await?
+        .into_iter()
+        .map(|user| proto::User {
+            id: user.id.to_proto(),
+            github_login: user.github_login,
+            avatar_url: String::new(),
+        })
+        .collect();
+    rpc.respond(receipt, proto::GetUsersResponse { users })
+        .await?;
+    Ok(())
+}
+
 async fn join_channel(
     request: TypedEnvelope<proto::JoinChannel>,
     rpc: &Arc<Peer>,
@@ -647,14 +708,74 @@ async fn join_channel(
         .read()
         .await
         .user_id_for_connection(request.sender_id)?;
+    let channel_id = ChannelId::from_proto(request.payload.channel_id);
     if !state
         .db
-        .can_user_access_channel(user_id, ChannelId(request.payload.channel_id as i32))
+        .can_user_access_channel(user_id, channel_id)
         .await?
     {
         Err(anyhow!("access denied"))?;
     }
 
+    state
+        .rpc
+        .write()
+        .await
+        .join_channel(request.sender_id, channel_id);
+    let messages = state
+        .db
+        .get_recent_channel_messages(channel_id, 50)
+        .await?
+        .into_iter()
+        .map(|msg| proto::ChannelMessage {
+            id: msg.id.to_proto(),
+            body: msg.body,
+            timestamp: msg.sent_at.unix_timestamp() as u64,
+            sender_id: msg.sender_id.to_proto(),
+        })
+        .collect();
+    rpc.respond(request.receipt(), proto::JoinChannelResponse { messages })
+        .await?;
+    Ok(())
+}
+
+async fn send_channel_message(
+    request: TypedEnvelope<proto::SendChannelMessage>,
+    peer: &Arc<Peer>,
+    app: &Arc<AppState>,
+) -> tide::Result<()> {
+    let channel_id = ChannelId::from_proto(request.payload.channel_id);
+    let user_id;
+    let connection_ids;
+    {
+        let state = app.rpc.read().await;
+        user_id = state.user_id_for_connection(request.sender_id)?;
+        if let Some(channel) = state.channels.get(&channel_id) {
+            connection_ids = channel.connection_ids();
+        } else {
+            return Ok(());
+        }
+    }
+
+    let timestamp = OffsetDateTime::now_utc();
+    let message_id = app
+        .db
+        .create_channel_message(channel_id, user_id, &request.payload.body, timestamp)
+        .await?;
+    let message = proto::ChannelMessageSent {
+        channel_id: channel_id.to_proto(),
+        message: Some(proto::ChannelMessage {
+            sender_id: user_id.to_proto(),
+            id: message_id.to_proto(),
+            body: request.payload.body,
+            timestamp: timestamp.unix_timestamp() as u64,
+        }),
+    };
+    broadcast(request.sender_id, connection_ids, |conn_id| {
+        peer.send(conn_id, message.clone())
+    })
+    .await?;
+
     Ok(())
 }
 

server/src/tests.rs 🔗

@@ -11,9 +11,10 @@ use rand::prelude::*;
 use serde_json::json;
 use sqlx::{
     migrate::{MigrateDatabase, Migrator},
+    types::time::OffsetDateTime,
     Executor as _, Postgres,
 };
-use std::{path::Path, sync::Arc};
+use std::{path::Path, sync::Arc, time::SystemTime};
 use zed::{
     editor::Editor,
     fs::{FakeFs, Fs as _},
@@ -485,10 +486,15 @@ async fn test_basic_chat(mut cx_a: TestAppContext, cx_b: TestAppContext) {
     let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;
     let (user_id_b, client_b) = server.create_client(&mut cx_a, "user_b").await;
 
-    // Create a channel that includes these 2 users and 1 other user.
+    // Create an org that includes these 2 users and 1 other user.
     let db = &server.app_state.db;
     let user_id_c = db.create_user("user_c", false).await.unwrap();
     let org_id = db.create_org("Test Org", "test-org").await.unwrap();
+    db.add_org_member(org_id, user_id_a, false).await.unwrap();
+    db.add_org_member(org_id, user_id_b, false).await.unwrap();
+    db.add_org_member(org_id, user_id_c, false).await.unwrap();
+
+    // Create a channel that includes all the users.
     let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
     db.add_channel_member(channel_id, user_id_a, false)
         .await
@@ -499,11 +505,21 @@ async fn test_basic_chat(mut cx_a: TestAppContext, cx_b: TestAppContext) {
     db.add_channel_member(channel_id, user_id_c, false)
         .await
         .unwrap();
-    db.create_channel_message(channel_id, user_id_c, "first message!")
-        .await
-        .unwrap();
-
-    // let chatroom_a = ChatRoom::
+    db.create_channel_message(
+        channel_id,
+        user_id_c,
+        "first message!",
+        OffsetDateTime::now_utc(),
+    )
+    .await
+    .unwrap();
+    assert_eq!(
+        db.get_recent_channel_messages(channel_id, 50)
+            .await
+            .unwrap()[0]
+            .body,
+        "first message!"
+    );
 }
 
 struct TestServer {

zrpc/proto/zed.proto 🔗

@@ -24,10 +24,10 @@ message Envelope {
         RemovePeer remove_peer = 19;
         GetChannels get_channels = 20;
         GetChannelsResponse get_channels_response = 21;
-        JoinChannel join_channel = 22;
-        JoinChannelResponse join_channel_response = 23;
-        GetUsers get_users = 24;
-        GetUsersResponse get_users_response = 25;
+        GetUsers get_users = 22;
+        GetUsersResponse get_users_response = 23;
+        JoinChannel join_channel = 24;
+        JoinChannelResponse join_channel_response = 25;
         SendChannelMessage send_channel_message = 26;
         ChannelMessageSent channel_message_sent = 27;
     }