Detailed changes
@@ -4699,6 +4699,7 @@ dependencies = [
"sqlx-rt 0.5.5",
"stringprep",
"thiserror",
+ "time 0.2.25",
"url",
"webpki",
"webpki-roots",
@@ -5866,6 +5867,7 @@ dependencies = [
"surf",
"tide",
"tide-compress",
+ "time 0.2.25",
"toml 0.5.8",
"zed",
"zrpc",
@@ -31,6 +31,7 @@ sha-1 = "0.9"
surf = "2.2.0"
tide = "0.16.0"
tide-compress = "0.9.0"
+time = "0.2"
toml = "0.5.8"
zrpc = { path = "../zrpc" }
@@ -41,7 +42,7 @@ default-features = false
[dependencies.sqlx]
version = "0.5.2"
-features = ["runtime-async-std-rustls", "postgres"]
+features = ["runtime-async-std-rustls", "postgres", "time"]
[dev-dependencies]
gpui = { path = "../gpui" }
@@ -257,7 +257,7 @@ async fn get_auth_callback(mut request: Request) -> tide::Result {
// When signing in from the native app, generate a new access token for the current user. Return
// a redirect so that the user's browser sends this access token to the locally-running app.
if let Some((user, app_sign_in_params)) = user.zip(query.native_app_sign_in_params) {
- let access_token = create_access_token(request.db(), user.id()).await?;
+ let access_token = create_access_token(request.db(), user.id).await?;
let native_app_public_key =
zed_auth::PublicKey::try_from(app_sign_in_params.native_app_public_key.clone())
.context("failed to parse app public key")?;
@@ -267,9 +267,7 @@ async fn get_auth_callback(mut request: Request) -> tide::Result {
return Ok(tide::Redirect::new(&format!(
"http://127.0.0.1:{}?user_id={}&access_token={}",
- app_sign_in_params.native_app_port,
- user.id().0,
- encrypted_access_token,
+ app_sign_in_params.native_app_port, user.id.0, encrypted_access_token,
))
.into());
}
@@ -1,5 +1,6 @@
use serde::Serialize;
use sqlx::{FromRow, Result};
+use time::OffsetDateTime;
pub use async_sqlx_session::PostgresSessionStore as SessionStore;
pub use sqlx::postgres::PgPoolOptions as DbOptions;
@@ -8,14 +9,14 @@ pub struct Db(pub sqlx::PgPool);
#[derive(Debug, FromRow, Serialize)]
pub struct User {
- id: i32,
+ pub id: UserId,
pub github_login: String,
pub admin: bool,
}
#[derive(Debug, FromRow, Serialize)]
pub struct Signup {
- id: i32,
+ pub id: SignupId,
pub github_login: String,
pub email_address: String,
pub about: String,
@@ -23,33 +24,18 @@ pub struct Signup {
#[derive(Debug, FromRow, Serialize)]
pub struct Channel {
- id: i32,
+ pub id: ChannelId,
pub name: String,
}
#[derive(Debug, FromRow)]
pub struct ChannelMessage {
- id: i32,
- sender_id: i32,
- body: String,
- sent_at: i64,
+ pub id: MessageId,
+ pub sender_id: UserId,
+ pub body: String,
+ pub sent_at: OffsetDateTime,
}
-#[derive(Clone, Copy)]
-pub struct UserId(pub i32);
-
-#[derive(Clone, Copy)]
-pub struct OrgId(pub i32);
-
-#[derive(Clone, Copy)]
-pub struct ChannelId(pub i32);
-
-#[derive(Clone, Copy)]
-pub struct SignupId(pub i32);
-
-#[derive(Clone, Copy)]
-pub struct MessageId(pub i32);
-
impl Db {
// signups
@@ -108,6 +94,33 @@ impl Db {
sqlx::query_as(query).fetch_all(&self.0).await
}
+ pub async fn get_users_by_ids(
+ &self,
+ requester_id: UserId,
+ ids: impl Iterator<Item = UserId>,
+ ) -> Result<Vec<User>> {
+ // Only return users that are in a common channel with the requesting user.
+ let query = "
+ SELECT users.*
+ FROM
+ users, channel_memberships
+ WHERE
+ users.id IN $1 AND
+ channel_memberships.user_id = users.id AND
+ channel_memberships.channel_id IN (
+ SELECT channel_id
+ FROM channel_memberships
+ WHERE channel_memberships.user_id = $2
+ )
+ ";
+
+ sqlx::query_as(query)
+ .bind(&ids.map(|id| id.0).collect::<Vec<_>>())
+ .bind(requester_id)
+ .fetch_all(&self.0)
+ .await
+ }
+
pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
sqlx::query_as(query)
@@ -147,7 +160,7 @@ impl Db {
VALUES ($1, $2)
";
sqlx::query(query)
- .bind(user_id.0 as i32)
+ .bind(user_id.0)
.bind(access_token_hash)
.execute(&self.0)
.await
@@ -156,8 +169,8 @@ impl Db {
pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
- sqlx::query_scalar::<_, String>(query)
- .bind(user_id.0 as i32)
+ sqlx::query_scalar(query)
+ .bind(user_id.0)
.fetch_all(&self.0)
.await
}
@@ -180,14 +193,20 @@ impl Db {
}
#[cfg(test)]
- pub async fn add_org_member(&self, org_id: OrgId, user_id: UserId) -> Result<()> {
+ pub async fn add_org_member(
+ &self,
+ org_id: OrgId,
+ user_id: UserId,
+ is_admin: bool,
+ ) -> Result<()> {
let query = "
- INSERT INTO org_memberships (org_id, user_id)
- VALUES ($1, $2)
+ INSERT INTO org_memberships (org_id, user_id, admin)
+ VALUES ($1, $2, $3)
";
sqlx::query(query)
.bind(org_id.0)
.bind(user_id.0)
+ .bind(is_admin)
.execute(&self.0)
.await
.map(drop)
@@ -272,16 +291,18 @@ impl Db {
channel_id: ChannelId,
sender_id: UserId,
body: &str,
+ timestamp: OffsetDateTime,
) -> Result<MessageId> {
let query = "
INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
- VALUES ($1, $2, $3, NOW()::timestamp)
+ VALUES ($1, $2, $3, $4)
RETURNING id
";
sqlx::query_scalar(query)
.bind(channel_id.0)
.bind(sender_id.0)
.bind(body)
+ .bind(timestamp)
.fetch_one(&self.0)
.await
.map(MessageId)
@@ -292,12 +313,15 @@ impl Db {
channel_id: ChannelId,
count: usize,
) -> Result<Vec<ChannelMessage>> {
- let query = "
- SELECT id, sender_id, body, sent_at
- FROM channel_messages
- WHERE channel_id = $1
+ let query = r#"
+ SELECT
+ id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at
+ FROM
+ channel_messages
+ WHERE
+ channel_id = $1
LIMIT $2
- ";
+ "#;
sqlx::query_as(query)
.bind(channel_id.0)
.bind(count as i64)
@@ -314,14 +338,29 @@ impl std::ops::Deref for Db {
}
}
-impl Channel {
- pub fn id(&self) -> ChannelId {
- ChannelId(self.id)
- }
+macro_rules! id_type {
+ ($name:ident) => {
+ #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)]
+ #[sqlx(transparent)]
+ #[serde(transparent)]
+ pub struct $name(pub i32);
+
+ impl $name {
+ #[allow(unused)]
+ pub fn from_proto(value: u64) -> Self {
+ Self(value as i32)
+ }
+
+ #[allow(unused)]
+ pub fn to_proto(&self) -> u64 {
+ self.0 as u64
+ }
+ }
+ };
}
-impl User {
- pub fn id(&self) -> UserId {
- UserId(self.id)
- }
-}
+id_type!(UserId);
+id_type!(OrgId);
+id_type!(ChannelId);
+id_type!(SignupId);
+id_type!(MessageId);
@@ -23,6 +23,7 @@ use tide::{
http::headers::{HeaderName, CONNECTION, UPGRADE},
Request, Response,
};
+use time::OffsetDateTime;
use zrpc::{
auth::random_token,
proto::{self, EnvelopedMessage},
@@ -33,17 +34,19 @@ type ReplicaId = u16;
#[derive(Default)]
pub struct State {
- connections: HashMap<ConnectionId, ConnectionState>,
- pub worktrees: HashMap<u64, WorktreeState>,
+ connections: HashMap<ConnectionId, Connection>,
+ pub worktrees: HashMap<u64, Worktree>,
+ channels: HashMap<ChannelId, Channel>,
next_worktree_id: u64,
}
-struct ConnectionState {
+struct Connection {
user_id: UserId,
worktrees: HashSet<u64>,
+ channels: HashSet<ChannelId>,
}
-pub struct WorktreeState {
+pub struct Worktree {
host_connection_id: Option<ConnectionId>,
guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
active_replica_ids: HashSet<ReplicaId>,
@@ -52,7 +55,12 @@ pub struct WorktreeState {
entries: HashMap<u64, proto::Entry>,
}
-impl WorktreeState {
+#[derive(Default)]
+struct Channel {
+ connection_ids: HashSet<ConnectionId>,
+}
+
+impl Worktree {
pub fn connection_ids(&self) -> Vec<ConnectionId> {
self.guest_connection_ids
.keys()
@@ -68,14 +76,21 @@ impl WorktreeState {
}
}
+impl Channel {
+ fn connection_ids(&self) -> Vec<ConnectionId> {
+ self.connection_ids.iter().copied().collect()
+ }
+}
+
impl State {
// Add a new connection associated with a given user.
pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
self.connections.insert(
connection_id,
- ConnectionState {
+ Connection {
user_id,
worktrees: Default::default(),
+ channels: Default::default(),
},
);
}
@@ -83,8 +98,13 @@ impl State {
// Remove the given connection and its association with any worktrees.
pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Vec<u64> {
let mut worktree_ids = Vec::new();
- if let Some(connection_state) = self.connections.remove(&connection_id) {
- for worktree_id in connection_state.worktrees {
+ if let Some(connection) = self.connections.remove(&connection_id) {
+ for channel_id in connection.channels {
+ if let Some(channel) = self.channels.get_mut(&channel_id) {
+ channel.connection_ids.remove(&connection_id);
+ }
+ }
+ for worktree_id in connection.worktrees {
if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
if worktree.host_connection_id == Some(connection_id) {
worktree_ids.push(worktree_id);
@@ -100,28 +120,39 @@ impl State {
worktree_ids
}
+ fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
+ if let Some(connection) = self.connections.get_mut(&connection_id) {
+ connection.channels.insert(channel_id);
+ self.channels
+ .entry(channel_id)
+ .or_default()
+ .connection_ids
+ .insert(connection_id);
+ }
+ }
+
// Add the given connection as a guest of the given worktree
pub fn join_worktree(
&mut self,
connection_id: ConnectionId,
worktree_id: u64,
access_token: &str,
- ) -> Option<(ReplicaId, &WorktreeState)> {
- if let Some(worktree_state) = self.worktrees.get_mut(&worktree_id) {
- if access_token == worktree_state.access_token {
- if let Some(connection_state) = self.connections.get_mut(&connection_id) {
- connection_state.worktrees.insert(worktree_id);
+ ) -> Option<(ReplicaId, &Worktree)> {
+ if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
+ if access_token == worktree.access_token {
+ if let Some(connection) = self.connections.get_mut(&connection_id) {
+ connection.worktrees.insert(worktree_id);
}
let mut replica_id = 1;
- while worktree_state.active_replica_ids.contains(&replica_id) {
+ while worktree.active_replica_ids.contains(&replica_id) {
replica_id += 1;
}
- worktree_state.active_replica_ids.insert(replica_id);
- worktree_state
+ worktree.active_replica_ids.insert(replica_id);
+ worktree
.guest_connection_ids
.insert(connection_id, replica_id);
- Some((replica_id, worktree_state))
+ Some((replica_id, worktree))
} else {
None
}
@@ -142,7 +173,7 @@ impl State {
&self,
worktree_id: u64,
connection_id: ConnectionId,
- ) -> tide::Result<&WorktreeState> {
+ ) -> tide::Result<&Worktree> {
let worktree = self
.worktrees
.get(&worktree_id)
@@ -165,7 +196,7 @@ impl State {
&mut self,
worktree_id: u64,
connection_id: ConnectionId,
- ) -> tide::Result<&mut WorktreeState> {
+ ) -> tide::Result<&mut Worktree> {
let worktree = self
.worktrees
.get_mut(&worktree_id)
@@ -263,7 +294,9 @@ pub fn add_rpc_routes(router: &mut Router, state: &Arc<AppState>, rpc: &Arc<Peer
on_message(router, rpc, state, buffer_saved);
on_message(router, rpc, state, save_buffer);
on_message(router, rpc, state, get_channels);
+ on_message(router, rpc, state, get_users);
on_message(router, rpc, state, join_channel);
+ on_message(router, rpc, state, send_channel_message);
}
pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
@@ -373,7 +406,7 @@ async fn share_worktree(
.collect();
state.worktrees.insert(
worktree_id,
- WorktreeState {
+ Worktree {
host_connection_id: Some(request.sender_id),
guest_connection_ids: Default::default(),
active_replica_ids: Default::default(),
@@ -627,7 +660,7 @@ async fn get_channels(
channels: channels
.into_iter()
.map(|chan| proto::Channel {
- id: chan.id().0 as u64,
+ id: chan.id.to_proto(),
name: chan.name,
})
.collect(),
@@ -637,6 +670,34 @@ async fn get_channels(
Ok(())
}
+async fn get_users(
+ request: TypedEnvelope<proto::GetUsers>,
+ rpc: &Arc<Peer>,
+ state: &Arc<AppState>,
+) -> tide::Result<()> {
+ let user_id = state
+ .rpc
+ .read()
+ .await
+ .user_id_for_connection(request.sender_id)?;
+ let receipt = request.receipt();
+ let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
+ let users = state
+ .db
+ .get_users_by_ids(user_id, user_ids)
+ .await?
+ .into_iter()
+ .map(|user| proto::User {
+ id: user.id.to_proto(),
+ github_login: user.github_login,
+ avatar_url: String::new(),
+ })
+ .collect();
+ rpc.respond(receipt, proto::GetUsersResponse { users })
+ .await?;
+ Ok(())
+}
+
async fn join_channel(
request: TypedEnvelope<proto::JoinChannel>,
rpc: &Arc<Peer>,
@@ -647,14 +708,74 @@ async fn join_channel(
.read()
.await
.user_id_for_connection(request.sender_id)?;
+ let channel_id = ChannelId::from_proto(request.payload.channel_id);
if !state
.db
- .can_user_access_channel(user_id, ChannelId(request.payload.channel_id as i32))
+ .can_user_access_channel(user_id, channel_id)
.await?
{
Err(anyhow!("access denied"))?;
}
+ state
+ .rpc
+ .write()
+ .await
+ .join_channel(request.sender_id, channel_id);
+ let messages = state
+ .db
+ .get_recent_channel_messages(channel_id, 50)
+ .await?
+ .into_iter()
+ .map(|msg| proto::ChannelMessage {
+ id: msg.id.to_proto(),
+ body: msg.body,
+ timestamp: msg.sent_at.unix_timestamp() as u64,
+ sender_id: msg.sender_id.to_proto(),
+ })
+ .collect();
+ rpc.respond(request.receipt(), proto::JoinChannelResponse { messages })
+ .await?;
+ Ok(())
+}
+
+async fn send_channel_message(
+ request: TypedEnvelope<proto::SendChannelMessage>,
+ peer: &Arc<Peer>,
+ app: &Arc<AppState>,
+) -> tide::Result<()> {
+ let channel_id = ChannelId::from_proto(request.payload.channel_id);
+ let user_id;
+ let connection_ids;
+ {
+ let state = app.rpc.read().await;
+ user_id = state.user_id_for_connection(request.sender_id)?;
+ if let Some(channel) = state.channels.get(&channel_id) {
+ connection_ids = channel.connection_ids();
+ } else {
+ return Ok(());
+ }
+ }
+
+ let timestamp = OffsetDateTime::now_utc();
+ let message_id = app
+ .db
+ .create_channel_message(channel_id, user_id, &request.payload.body, timestamp)
+ .await?;
+ let message = proto::ChannelMessageSent {
+ channel_id: channel_id.to_proto(),
+ message: Some(proto::ChannelMessage {
+ sender_id: user_id.to_proto(),
+ id: message_id.to_proto(),
+ body: request.payload.body,
+ timestamp: timestamp.unix_timestamp() as u64,
+ }),
+ };
+ broadcast(request.sender_id, connection_ids, |conn_id| {
+ peer.send(conn_id, message.clone())
+ })
+ .await?;
+
Ok(())
}
@@ -11,9 +11,10 @@ use rand::prelude::*;
use serde_json::json;
use sqlx::{
migrate::{MigrateDatabase, Migrator},
+ types::time::OffsetDateTime,
Executor as _, Postgres,
};
-use std::{path::Path, sync::Arc};
+use std::{path::Path, sync::Arc, time::SystemTime};
use zed::{
editor::Editor,
fs::{FakeFs, Fs as _},
@@ -485,10 +486,15 @@ async fn test_basic_chat(mut cx_a: TestAppContext, cx_b: TestAppContext) {
let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;
let (user_id_b, client_b) = server.create_client(&mut cx_a, "user_b").await;
- // Create a channel that includes these 2 users and 1 other user.
+ // Create an org that includes these 2 users and 1 other user.
let db = &server.app_state.db;
let user_id_c = db.create_user("user_c", false).await.unwrap();
let org_id = db.create_org("Test Org", "test-org").await.unwrap();
+ db.add_org_member(org_id, user_id_a, false).await.unwrap();
+ db.add_org_member(org_id, user_id_b, false).await.unwrap();
+ db.add_org_member(org_id, user_id_c, false).await.unwrap();
+
+ // Create a channel that includes all the users.
let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
db.add_channel_member(channel_id, user_id_a, false)
.await
@@ -499,11 +505,21 @@ async fn test_basic_chat(mut cx_a: TestAppContext, cx_b: TestAppContext) {
db.add_channel_member(channel_id, user_id_c, false)
.await
.unwrap();
- db.create_channel_message(channel_id, user_id_c, "first message!")
- .await
- .unwrap();
-
- // let chatroom_a = ChatRoom::
+ db.create_channel_message(
+ channel_id,
+ user_id_c,
+ "first message!",
+ OffsetDateTime::now_utc(),
+ )
+ .await
+ .unwrap();
+ assert_eq!(
+ db.get_recent_channel_messages(channel_id, 50)
+ .await
+ .unwrap()[0]
+ .body,
+ "first message!"
+ );
}
struct TestServer {
@@ -24,10 +24,10 @@ message Envelope {
RemovePeer remove_peer = 19;
GetChannels get_channels = 20;
GetChannelsResponse get_channels_response = 21;
- JoinChannel join_channel = 22;
- JoinChannelResponse join_channel_response = 23;
- GetUsers get_users = 24;
- GetUsersResponse get_users_response = 25;
+ GetUsers get_users = 22;
+ GetUsersResponse get_users_response = 23;
+ JoinChannel join_channel = 24;
+ JoinChannelResponse join_channel_response = 25;
SendChannelMessage send_channel_message = 26;
ChannelMessageSent channel_message_sent = 27;
}