diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1dff0c84e927dd277d68d0b3f6e9750ba0fa4bec..13dcf4fef1d357b76ad2d11ebfc5eb589852be29 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -45,7 +45,10 @@ jobs: - name: Run tests run: cargo test --workspace --no-fail-fast - - name: Build collab binaries + - name: Build collab + run: cargo build -p collab + + - name: Build other binaries run: cargo build --bins --all-features bundle: diff --git a/Cargo.lock b/Cargo.lock index 84b416ad8adb07123afc1082508c14611180c324..137b59a4ff156ba2d02606faa48ee9b0e54228e9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1953,6 +1953,18 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7bad48618fdb549078c333a7a8528acb57af271d0433bdecd523eb620628364e" +[[package]] +name = "flume" +version = "0.10.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1657b4441c3403d9f7b3409e47575237dac27b1b5726df654a6ecbf92f0f7577" +dependencies = [ + "futures-core", + "futures-sink", + "pin-project", + "spin 0.9.4", +] + [[package]] name = "fnv" version = "1.0.7" @@ -3022,7 +3034,7 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" dependencies = [ - "spin", + "spin 0.5.2", ] [[package]] @@ -4725,7 +4737,7 @@ dependencies = [ "cc", "libc", "once_cell", - "spin", + "spin 0.5.2", "untrusted", "web-sys", "winapi 0.3.9", @@ -5563,6 +5575,15 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" +[[package]] +name = "spin" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f6002a767bff9e83f8eeecf883ecb8011875a21ae8da43bffb817a57e78cc09" +dependencies = [ + "lock_api", +] + [[package]] name = "spsc-buffer" version = "0.1.1" @@ -5583,8 +5604,7 @@ dependencies = [ [[package]] name = "sqlx" version = "0.6.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9249290c05928352f71c077cc44a464d880c63f26f7534728cca008e135c0428" +source = "git+https://github.com/launchbadge/sqlx?rev=4b7053807c705df312bcb9b6281e184bf7534eb3#4b7053807c705df312bcb9b6281e184bf7534eb3" dependencies = [ "sqlx-core", "sqlx-macros", @@ -5593,8 +5613,7 @@ dependencies = [ [[package]] name = "sqlx-core" version = "0.6.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcbc16ddba161afc99e14d1713a453747a2b07fc097d2009f4c300ec99286105" +source = "git+https://github.com/launchbadge/sqlx?rev=4b7053807c705df312bcb9b6281e184bf7534eb3#4b7053807c705df312bcb9b6281e184bf7534eb3" dependencies = [ "ahash", "atoi", @@ -5608,8 +5627,10 @@ dependencies = [ "dotenvy", "either", "event-listener", + "flume", "futures-channel", "futures-core", + "futures-executor", "futures-intrusive", "futures-util", "hashlink", @@ -5619,6 +5640,7 @@ dependencies = [ "indexmap", "itoa", "libc", + "libsqlite3-sys", "log", "md-5", "memchr", @@ -5648,8 +5670,7 @@ dependencies = [ [[package]] name = "sqlx-macros" version = "0.6.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b850fa514dc11f2ee85be9d055c512aa866746adfacd1cb42d867d68e6a5b0d9" +source = "git+https://github.com/launchbadge/sqlx?rev=4b7053807c705df312bcb9b6281e184bf7534eb3#4b7053807c705df312bcb9b6281e184bf7534eb3" dependencies = [ "dotenvy", "either", @@ -5657,6 +5678,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", + "serde_json", "sha2 0.10.6", "sqlx-core", "sqlx-rt", @@ -5667,8 +5689,7 @@ dependencies = [ [[package]] name = "sqlx-rt" version = "0.6.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24c5b2d25fa654cc5f841750b8e1cdedbe21189bf9a9382ee90bfa9dd3562396" +source = "git+https://github.com/launchbadge/sqlx?rev=4b7053807c705df312bcb9b6281e184bf7534eb3#4b7053807c705df312bcb9b6281e184bf7534eb3" dependencies = [ "once_cell", "tokio", diff --git a/crates/client/src/channel.rs b/crates/client/src/channel.rs deleted file mode 100644 index 7b4f6073ceb34a653130ba64c571734e563687fd..0000000000000000000000000000000000000000 --- a/crates/client/src/channel.rs +++ /dev/null @@ -1,820 +0,0 @@ -use super::{ - proto, - user::{User, UserStore}, - Client, Status, Subscription, TypedEnvelope, -}; -use anyhow::{anyhow, Context, Result}; -use futures::lock::Mutex; -use gpui::{ - AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task, WeakModelHandle, -}; -use postage::prelude::Stream; -use rand::prelude::*; -use std::{ - collections::{HashMap, HashSet}, - mem, - ops::Range, - sync::Arc, -}; -use sum_tree::{Bias, SumTree}; -use time::OffsetDateTime; -use util::{post_inc, ResultExt as _, TryFutureExt}; - -pub struct ChannelList { - available_channels: Option>, - channels: HashMap>, - client: Arc, - user_store: ModelHandle, - _task: Task>, -} - -#[derive(Clone, Debug, PartialEq)] -pub struct ChannelDetails { - pub id: u64, - pub name: String, -} - -pub struct Channel { - details: ChannelDetails, - messages: SumTree, - loaded_all_messages: bool, - next_pending_message_id: usize, - user_store: ModelHandle, - rpc: Arc, - outgoing_messages_lock: Arc>, - rng: StdRng, - _subscription: Subscription, -} - -#[derive(Clone, Debug)] -pub struct ChannelMessage { - pub id: ChannelMessageId, - pub body: String, - pub timestamp: OffsetDateTime, - pub sender: Arc, - pub nonce: u128, -} - -#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub enum ChannelMessageId { - Saved(u64), - Pending(usize), -} - -#[derive(Clone, Debug, Default)] -pub struct ChannelMessageSummary { - max_id: ChannelMessageId, - count: usize, -} - -#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] -struct Count(usize); - -pub enum ChannelListEvent {} - -#[derive(Clone, Debug, PartialEq)] -pub enum ChannelEvent { - MessagesUpdated { - old_range: Range, - new_count: usize, - }, -} - -impl Entity for ChannelList { - type Event = ChannelListEvent; -} - -impl ChannelList { - pub fn new( - user_store: ModelHandle, - rpc: Arc, - cx: &mut ModelContext, - ) -> Self { - let _task = cx.spawn_weak(|this, mut cx| { - let rpc = rpc.clone(); - async move { - let mut status = rpc.status(); - while let Some((status, this)) = status.recv().await.zip(this.upgrade(&cx)) { - match status { - Status::Connected { .. } => { - let response = rpc - .request(proto::GetChannels {}) - .await - .context("failed to fetch available channels")?; - this.update(&mut cx, |this, cx| { - this.available_channels = - Some(response.channels.into_iter().map(Into::into).collect()); - - let mut to_remove = Vec::new(); - for (channel_id, channel) in &this.channels { - if let Some(channel) = channel.upgrade(cx) { - channel.update(cx, |channel, cx| channel.rejoin(cx)) - } else { - to_remove.push(*channel_id); - } - } - - for channel_id in to_remove { - this.channels.remove(&channel_id); - } - cx.notify(); - }); - } - Status::SignedOut { .. } => { - this.update(&mut cx, |this, cx| { - this.available_channels = None; - this.channels.clear(); - cx.notify(); - }); - } - _ => {} - } - } - Ok(()) - } - .log_err() - }); - - Self { - available_channels: None, - channels: Default::default(), - user_store, - client: rpc, - _task, - } - } - - pub fn available_channels(&self) -> Option<&[ChannelDetails]> { - self.available_channels.as_deref() - } - - pub fn get_channel( - &mut self, - id: u64, - cx: &mut MutableAppContext, - ) -> Option> { - if let Some(channel) = self.channels.get(&id).and_then(|c| c.upgrade(cx)) { - return Some(channel); - } - - let channels = self.available_channels.as_ref()?; - let details = channels.iter().find(|details| details.id == id)?.clone(); - let channel = cx.add_model(|cx| { - Channel::new(details, self.user_store.clone(), self.client.clone(), cx) - }); - self.channels.insert(id, channel.downgrade()); - Some(channel) - } -} - -impl Entity for Channel { - type Event = ChannelEvent; - - fn release(&mut self, _: &mut MutableAppContext) { - self.rpc - .send(proto::LeaveChannel { - channel_id: self.details.id, - }) - .log_err(); - } -} - -impl Channel { - pub fn init(rpc: &Arc) { - rpc.add_model_message_handler(Self::handle_message_sent); - } - - pub fn new( - details: ChannelDetails, - user_store: ModelHandle, - rpc: Arc, - cx: &mut ModelContext, - ) -> Self { - let _subscription = rpc.add_model_for_remote_entity(details.id, cx); - - { - let user_store = user_store.clone(); - let rpc = rpc.clone(); - let channel_id = details.id; - cx.spawn(|channel, mut cx| { - async move { - let response = rpc.request(proto::JoinChannel { channel_id }).await?; - let messages = - messages_from_proto(response.messages, &user_store, &mut cx).await?; - let loaded_all_messages = response.done; - - channel.update(&mut cx, |channel, cx| { - channel.insert_messages(messages, cx); - channel.loaded_all_messages = loaded_all_messages; - }); - - Ok(()) - } - .log_err() - }) - .detach(); - } - - Self { - details, - user_store, - rpc, - outgoing_messages_lock: Default::default(), - messages: Default::default(), - loaded_all_messages: false, - next_pending_message_id: 0, - rng: StdRng::from_entropy(), - _subscription, - } - } - - pub fn name(&self) -> &str { - &self.details.name - } - - pub fn send_message( - &mut self, - body: String, - cx: &mut ModelContext, - ) -> Result>> { - if body.is_empty() { - Err(anyhow!("message body can't be empty"))?; - } - - let current_user = self - .user_store - .read(cx) - .current_user() - .ok_or_else(|| anyhow!("current_user is not present"))?; - - let channel_id = self.details.id; - let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id)); - let nonce = self.rng.gen(); - self.insert_messages( - SumTree::from_item( - ChannelMessage { - id: pending_id, - body: body.clone(), - sender: current_user, - timestamp: OffsetDateTime::now_utc(), - nonce, - }, - &(), - ), - cx, - ); - let user_store = self.user_store.clone(); - let rpc = self.rpc.clone(); - let outgoing_messages_lock = self.outgoing_messages_lock.clone(); - Ok(cx.spawn(|this, mut cx| async move { - let outgoing_message_guard = outgoing_messages_lock.lock().await; - let request = rpc.request(proto::SendChannelMessage { - channel_id, - body, - nonce: Some(nonce.into()), - }); - let response = request.await?; - drop(outgoing_message_guard); - let message = ChannelMessage::from_proto( - response.message.ok_or_else(|| anyhow!("invalid message"))?, - &user_store, - &mut cx, - ) - .await?; - this.update(&mut cx, |this, cx| { - this.insert_messages(SumTree::from_item(message, &()), cx); - Ok(()) - }) - })) - } - - pub fn load_more_messages(&mut self, cx: &mut ModelContext) -> bool { - if !self.loaded_all_messages { - let rpc = self.rpc.clone(); - let user_store = self.user_store.clone(); - let channel_id = self.details.id; - if let Some(before_message_id) = - self.messages.first().and_then(|message| match message.id { - ChannelMessageId::Saved(id) => Some(id), - ChannelMessageId::Pending(_) => None, - }) - { - cx.spawn(|this, mut cx| { - async move { - let response = rpc - .request(proto::GetChannelMessages { - channel_id, - before_message_id, - }) - .await?; - let loaded_all_messages = response.done; - let messages = - messages_from_proto(response.messages, &user_store, &mut cx).await?; - this.update(&mut cx, |this, cx| { - this.loaded_all_messages = loaded_all_messages; - this.insert_messages(messages, cx); - }); - Ok(()) - } - .log_err() - }) - .detach(); - return true; - } - } - false - } - - pub fn rejoin(&mut self, cx: &mut ModelContext) { - let user_store = self.user_store.clone(); - let rpc = self.rpc.clone(); - let channel_id = self.details.id; - cx.spawn(|this, mut cx| { - async move { - let response = rpc.request(proto::JoinChannel { channel_id }).await?; - let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?; - let loaded_all_messages = response.done; - - let pending_messages = this.update(&mut cx, |this, cx| { - if let Some((first_new_message, last_old_message)) = - messages.first().zip(this.messages.last()) - { - if first_new_message.id > last_old_message.id { - let old_messages = mem::take(&mut this.messages); - cx.emit(ChannelEvent::MessagesUpdated { - old_range: 0..old_messages.summary().count, - new_count: 0, - }); - this.loaded_all_messages = loaded_all_messages; - } - } - - this.insert_messages(messages, cx); - if loaded_all_messages { - this.loaded_all_messages = loaded_all_messages; - } - - this.pending_messages().cloned().collect::>() - }); - - for pending_message in pending_messages { - let request = rpc.request(proto::SendChannelMessage { - channel_id, - body: pending_message.body, - nonce: Some(pending_message.nonce.into()), - }); - let response = request.await?; - let message = ChannelMessage::from_proto( - response.message.ok_or_else(|| anyhow!("invalid message"))?, - &user_store, - &mut cx, - ) - .await?; - this.update(&mut cx, |this, cx| { - this.insert_messages(SumTree::from_item(message, &()), cx); - }); - } - - Ok(()) - } - .log_err() - }) - .detach(); - } - - pub fn message_count(&self) -> usize { - self.messages.summary().count - } - - pub fn messages(&self) -> &SumTree { - &self.messages - } - - pub fn message(&self, ix: usize) -> &ChannelMessage { - let mut cursor = self.messages.cursor::(); - cursor.seek(&Count(ix), Bias::Right, &()); - cursor.item().unwrap() - } - - pub fn messages_in_range(&self, range: Range) -> impl Iterator { - let mut cursor = self.messages.cursor::(); - cursor.seek(&Count(range.start), Bias::Right, &()); - cursor.take(range.len()) - } - - pub fn pending_messages(&self) -> impl Iterator { - let mut cursor = self.messages.cursor::(); - cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &()); - cursor - } - - async fn handle_message_sent( - this: ModelHandle, - message: TypedEnvelope, - _: Arc, - mut cx: AsyncAppContext, - ) -> Result<()> { - let user_store = this.read_with(&cx, |this, _| this.user_store.clone()); - let message = message - .payload - .message - .ok_or_else(|| anyhow!("empty message"))?; - - let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?; - this.update(&mut cx, |this, cx| { - this.insert_messages(SumTree::from_item(message, &()), cx) - }); - - Ok(()) - } - - fn insert_messages(&mut self, messages: SumTree, cx: &mut ModelContext) { - if let Some((first_message, last_message)) = messages.first().zip(messages.last()) { - let nonces = messages - .cursor::<()>() - .map(|m| m.nonce) - .collect::>(); - - let mut old_cursor = self.messages.cursor::<(ChannelMessageId, Count)>(); - let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &()); - let start_ix = old_cursor.start().1 .0; - let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &()); - let removed_count = removed_messages.summary().count; - let new_count = messages.summary().count; - let end_ix = start_ix + removed_count; - - new_messages.push_tree(messages, &()); - - let mut ranges = Vec::>::new(); - if new_messages.last().unwrap().is_pending() { - new_messages.push_tree(old_cursor.suffix(&()), &()); - } else { - new_messages.push_tree( - old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()), - &(), - ); - - while let Some(message) = old_cursor.item() { - let message_ix = old_cursor.start().1 .0; - if nonces.contains(&message.nonce) { - if ranges.last().map_or(false, |r| r.end == message_ix) { - ranges.last_mut().unwrap().end += 1; - } else { - ranges.push(message_ix..message_ix + 1); - } - } else { - new_messages.push(message.clone(), &()); - } - old_cursor.next(&()); - } - } - - drop(old_cursor); - self.messages = new_messages; - - for range in ranges.into_iter().rev() { - cx.emit(ChannelEvent::MessagesUpdated { - old_range: range, - new_count: 0, - }); - } - cx.emit(ChannelEvent::MessagesUpdated { - old_range: start_ix..end_ix, - new_count, - }); - cx.notify(); - } - } -} - -async fn messages_from_proto( - proto_messages: Vec, - user_store: &ModelHandle, - cx: &mut AsyncAppContext, -) -> Result> { - let unique_user_ids = proto_messages - .iter() - .map(|m| m.sender_id) - .collect::>() - .into_iter() - .collect(); - user_store - .update(cx, |user_store, cx| { - user_store.get_users(unique_user_ids, cx) - }) - .await?; - - let mut messages = Vec::with_capacity(proto_messages.len()); - for message in proto_messages { - messages.push(ChannelMessage::from_proto(message, user_store, cx).await?); - } - let mut result = SumTree::new(); - result.extend(messages, &()); - Ok(result) -} - -impl From for ChannelDetails { - fn from(message: proto::Channel) -> Self { - Self { - id: message.id, - name: message.name, - } - } -} - -impl ChannelMessage { - pub async fn from_proto( - message: proto::ChannelMessage, - user_store: &ModelHandle, - cx: &mut AsyncAppContext, - ) -> Result { - let sender = user_store - .update(cx, |user_store, cx| { - user_store.get_user(message.sender_id, cx) - }) - .await?; - Ok(ChannelMessage { - id: ChannelMessageId::Saved(message.id), - body: message.body, - timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?, - sender, - nonce: message - .nonce - .ok_or_else(|| anyhow!("nonce is required"))? - .into(), - }) - } - - pub fn is_pending(&self) -> bool { - matches!(self.id, ChannelMessageId::Pending(_)) - } -} - -impl sum_tree::Item for ChannelMessage { - type Summary = ChannelMessageSummary; - - fn summary(&self) -> Self::Summary { - ChannelMessageSummary { - max_id: self.id, - count: 1, - } - } -} - -impl Default for ChannelMessageId { - fn default() -> Self { - Self::Saved(0) - } -} - -impl sum_tree::Summary for ChannelMessageSummary { - type Context = (); - - fn add_summary(&mut self, summary: &Self, _: &()) { - self.max_id = summary.max_id; - self.count += summary.count; - } -} - -impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for ChannelMessageId { - fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) { - debug_assert!(summary.max_id > *self); - *self = summary.max_id; - } -} - -impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count { - fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) { - self.0 += summary.count; - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::test::{FakeHttpClient, FakeServer}; - use gpui::TestAppContext; - - #[gpui::test] - async fn test_channel_messages(cx: &mut TestAppContext) { - cx.foreground().forbid_parking(); - - let user_id = 5; - let http_client = FakeHttpClient::with_404_response(); - let client = cx.update(|cx| Client::new(http_client.clone(), cx)); - let server = FakeServer::for_client(user_id, &client, cx).await; - - Channel::init(&client); - let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http_client, cx)); - - let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx)); - channel_list.read_with(cx, |list, _| assert_eq!(list.available_channels(), None)); - - // Get the available channels. - let get_channels = server.receive::().await.unwrap(); - server - .respond( - get_channels.receipt(), - proto::GetChannelsResponse { - channels: vec![proto::Channel { - id: 5, - name: "the-channel".to_string(), - }], - }, - ) - .await; - channel_list.next_notification(cx).await; - channel_list.read_with(cx, |list, _| { - assert_eq!( - list.available_channels().unwrap(), - &[ChannelDetails { - id: 5, - name: "the-channel".into(), - }] - ) - }); - - let get_users = server.receive::().await.unwrap(); - assert_eq!(get_users.payload.user_ids, vec![5]); - server - .respond( - get_users.receipt(), - proto::UsersResponse { - users: vec![proto::User { - id: 5, - github_login: "nathansobo".into(), - avatar_url: "http://avatar.com/nathansobo".into(), - }], - }, - ) - .await; - - // Join a channel and populate its existing messages. - let channel = channel_list - .update(cx, |list, cx| { - let channel_id = list.available_channels().unwrap()[0].id; - list.get_channel(channel_id, cx) - }) - .unwrap(); - channel.read_with(cx, |channel, _| assert!(channel.messages().is_empty())); - let join_channel = server.receive::().await.unwrap(); - server - .respond( - join_channel.receipt(), - proto::JoinChannelResponse { - messages: vec![ - proto::ChannelMessage { - id: 10, - body: "a".into(), - timestamp: 1000, - sender_id: 5, - nonce: Some(1.into()), - }, - proto::ChannelMessage { - id: 11, - body: "b".into(), - timestamp: 1001, - sender_id: 6, - nonce: Some(2.into()), - }, - ], - done: false, - }, - ) - .await; - - // Client requests all users for the received messages - let mut get_users = server.receive::().await.unwrap(); - get_users.payload.user_ids.sort(); - assert_eq!(get_users.payload.user_ids, vec![6]); - server - .respond( - get_users.receipt(), - proto::UsersResponse { - users: vec![proto::User { - id: 6, - github_login: "maxbrunsfeld".into(), - avatar_url: "http://avatar.com/maxbrunsfeld".into(), - }], - }, - ) - .await; - - assert_eq!( - channel.next_event(cx).await, - ChannelEvent::MessagesUpdated { - old_range: 0..0, - new_count: 2, - } - ); - channel.read_with(cx, |channel, _| { - assert_eq!( - channel - .messages_in_range(0..2) - .map(|message| (message.sender.github_login.clone(), message.body.clone())) - .collect::>(), - &[ - ("nathansobo".into(), "a".into()), - ("maxbrunsfeld".into(), "b".into()) - ] - ); - }); - - // Receive a new message. - server.send(proto::ChannelMessageSent { - channel_id: channel.read_with(cx, |channel, _| channel.details.id), - message: Some(proto::ChannelMessage { - id: 12, - body: "c".into(), - timestamp: 1002, - sender_id: 7, - nonce: Some(3.into()), - }), - }); - - // Client requests user for message since they haven't seen them yet - let get_users = server.receive::().await.unwrap(); - assert_eq!(get_users.payload.user_ids, vec![7]); - server - .respond( - get_users.receipt(), - proto::UsersResponse { - users: vec![proto::User { - id: 7, - github_login: "as-cii".into(), - avatar_url: "http://avatar.com/as-cii".into(), - }], - }, - ) - .await; - - assert_eq!( - channel.next_event(cx).await, - ChannelEvent::MessagesUpdated { - old_range: 2..2, - new_count: 1, - } - ); - channel.read_with(cx, |channel, _| { - assert_eq!( - channel - .messages_in_range(2..3) - .map(|message| (message.sender.github_login.clone(), message.body.clone())) - .collect::>(), - &[("as-cii".into(), "c".into())] - ) - }); - - // Scroll up to view older messages. - channel.update(cx, |channel, cx| { - assert!(channel.load_more_messages(cx)); - }); - let get_messages = server.receive::().await.unwrap(); - assert_eq!(get_messages.payload.channel_id, 5); - assert_eq!(get_messages.payload.before_message_id, 10); - server - .respond( - get_messages.receipt(), - proto::GetChannelMessagesResponse { - done: true, - messages: vec![ - proto::ChannelMessage { - id: 8, - body: "y".into(), - timestamp: 998, - sender_id: 5, - nonce: Some(4.into()), - }, - proto::ChannelMessage { - id: 9, - body: "z".into(), - timestamp: 999, - sender_id: 6, - nonce: Some(5.into()), - }, - ], - }, - ) - .await; - - assert_eq!( - channel.next_event(cx).await, - ChannelEvent::MessagesUpdated { - old_range: 0..0, - new_count: 2, - } - ); - channel.read_with(cx, |channel, _| { - assert_eq!( - channel - .messages_in_range(0..2) - .map(|message| (message.sender.github_login.clone(), message.body.clone())) - .collect::>(), - &[ - ("nathansobo".into(), "y".into()), - ("maxbrunsfeld".into(), "z".into()) - ] - ); - }); - } -} diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 587961f2a723970479e521fee77f5903af04025d..c943b274172c8264ee311270d4575973f945e6cc 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -1,7 +1,6 @@ #[cfg(any(test, feature = "test-support"))] pub mod test; -pub mod channel; pub mod http; pub mod telemetry; pub mod user; @@ -44,7 +43,6 @@ use thiserror::Error; use url::Url; use util::{ResultExt, TryFutureExt}; -pub use channel::*; pub use rpc::*; pub use user::*; diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 90b9d9a19a859651ef52dd247b0b21be61c172a2..7456cb5598f64bd497fd2b73252ac40219e439b6 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -50,8 +50,9 @@ tracing-log = "0.1.3" tracing-subscriber = { version = "0.3.11", features = ["env-filter", "json"] } [dependencies.sqlx] -version = "0.6" -features = ["runtime-tokio-rustls", "postgres", "time", "uuid"] +git = "https://github.com/launchbadge/sqlx" +rev = "4b7053807c705df312bcb9b6281e184bf7534eb3" +features = ["runtime-tokio-rustls", "postgres", "json", "time", "uuid"] [dev-dependencies] collections = { path = "../collections", features = ["test-support"] } @@ -78,5 +79,10 @@ lazy_static = "1.4" serde_json = { version = "1.0", features = ["preserve_order"] } unindent = "0.1" +[dev-dependencies.sqlx] +git = "https://github.com/launchbadge/sqlx" +rev = "4b7053807c705df312bcb9b6281e184bf7534eb3" +features = ["sqlite"] + [features] seed-support = ["clap", "lipsum", "reqwest"] diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql new file mode 100644 index 0000000000000000000000000000000000000000..63d2661de5d5d2262b371de651b434f6fe1a6c38 --- /dev/null +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -0,0 +1,41 @@ +CREATE TABLE IF NOT EXISTS "users" ( + "id" INTEGER PRIMARY KEY, + "github_login" VARCHAR, + "admin" BOOLEAN, + "email_address" VARCHAR(255) DEFAULT NULL, + "invite_code" VARCHAR(64), + "invite_count" INTEGER NOT NULL DEFAULT 0, + "inviter_id" INTEGER REFERENCES users (id), + "connected_once" BOOLEAN NOT NULL DEFAULT false, + "created_at" TIMESTAMP NOT NULL DEFAULT now, + "metrics_id" VARCHAR(255), + "github_user_id" INTEGER +); +CREATE UNIQUE INDEX "index_users_github_login" ON "users" ("github_login"); +CREATE UNIQUE INDEX "index_invite_code_users" ON "users" ("invite_code"); +CREATE INDEX "index_users_on_email_address" ON "users" ("email_address"); +CREATE INDEX "index_users_on_github_user_id" ON "users" ("github_user_id"); + +CREATE TABLE IF NOT EXISTS "access_tokens" ( + "id" INTEGER PRIMARY KEY, + "user_id" INTEGER REFERENCES users (id), + "hash" VARCHAR(128) +); +CREATE INDEX "index_access_tokens_user_id" ON "access_tokens" ("user_id"); + +CREATE TABLE IF NOT EXISTS "contacts" ( + "id" INTEGER PRIMARY KEY, + "user_id_a" INTEGER REFERENCES users (id) NOT NULL, + "user_id_b" INTEGER REFERENCES users (id) NOT NULL, + "a_to_b" BOOLEAN NOT NULL, + "should_notify" BOOLEAN NOT NULL, + "accepted" BOOLEAN NOT NULL +); +CREATE UNIQUE INDEX "index_contacts_user_ids" ON "contacts" ("user_id_a", "user_id_b"); +CREATE INDEX "index_contacts_user_id_b" ON "contacts" ("user_id_b"); + +CREATE TABLE IF NOT EXISTS "projects" ( + "id" INTEGER PRIMARY KEY, + "host_user_id" INTEGER REFERENCES users (id) NOT NULL, + "unregistered" BOOLEAN NOT NULL DEFAULT false +); diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index fbf45a379925edc8cf285f7d65e439886dd73b81..5fcdc5fcfdf59a983d3d4c04d98242eb3d97fa41 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -1,6 +1,6 @@ use crate::{ auth, - db::{Invite, NewUserParams, ProjectId, Signup, User, UserId, WaitlistSummary}, + db::{Invite, NewUserParams, Signup, User, UserId, WaitlistSummary}, rpc::{self, ResultExt}, AppState, Error, Result, }; @@ -16,9 +16,7 @@ use axum::{ }; use axum_extra::response::ErasedJson; use serde::{Deserialize, Serialize}; -use serde_json::json; -use std::{sync::Arc, time::Duration}; -use time::OffsetDateTime; +use std::sync::Arc; use tower::ServiceBuilder; use tracing::instrument; @@ -32,16 +30,6 @@ pub fn routes(rpc_server: Arc, state: Arc) -> Router, - Extension(app): Extension>, -) -> Result { - let summary = app - .db - .get_top_users_activity_summary(params.start..params.end, 100) - .await?; - Ok(ErasedJson::pretty(summary)) -} - -async fn get_user_activity_timeline( - Path(user_id): Path, - Query(params): Query, - Extension(app): Extension>, -) -> Result { - let summary = app - .db - .get_user_activity_timeline(params.start..params.end, UserId(user_id)) - .await?; - Ok(ErasedJson::pretty(summary)) -} - -#[derive(Deserialize)] -struct ActiveUserCountParams { - #[serde(flatten)] - period: TimePeriodParams, - durations_in_minutes: String, - #[serde(default)] - only_collaborative: bool, -} - -#[derive(Serialize)] -struct ActiveUserSet { - active_time_in_minutes: u64, - user_count: usize, -} - -async fn get_active_user_counts( - Query(params): Query, - Extension(app): Extension>, -) -> Result { - let durations_in_minutes = params.durations_in_minutes.split(','); - let mut user_sets = Vec::new(); - for duration in durations_in_minutes { - let duration = duration - .parse() - .map_err(|_| anyhow!("invalid duration: {duration}"))?; - user_sets.push(ActiveUserSet { - active_time_in_minutes: duration, - user_count: app - .db - .get_active_user_count( - params.period.start..params.period.end, - Duration::from_secs(duration * 60), - params.only_collaborative, - ) - .await?, - }) - } - Ok(ErasedJson::pretty(user_sets)) -} - -#[derive(Deserialize)] -struct GetProjectMetadataParams { - project_id: u64, -} - -async fn get_project_metadata( - Query(params): Query, - Extension(app): Extension>, -) -> Result { - let extensions = app - .db - .get_project_extensions(ProjectId::from_proto(params.project_id)) - .await?; - Ok(ErasedJson::pretty(json!({ "extensions": extensions }))) -} - #[derive(Deserialize)] struct CreateAccessTokenQueryParams { public_key: String, diff --git a/crates/collab/src/auth.rs b/crates/collab/src/auth.rs index 9081fe1f1e793bab5e7825941ce198f8c0a14a67..63f032f7e65d17f454d603b26c6206c81eacdf65 100644 --- a/crates/collab/src/auth.rs +++ b/crates/collab/src/auth.rs @@ -75,7 +75,7 @@ pub async fn validate_header(mut req: Request, next: Next) -> impl Into const MAX_ACCESS_TOKENS_TO_STORE: usize = 8; -pub async fn create_access_token(db: &dyn db::Db, user_id: UserId) -> Result { +pub async fn create_access_token(db: &db::DefaultDb, user_id: UserId) -> Result { let access_token = rpc::auth::random_token(); let access_token_hash = hash_access_token(&access_token).context("failed to hash access token")?; diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 35fad50e12681a06d6e4fb38b34097503b2c6e98..10da609d57b9b7cfe04927b681b378c07e099b4b 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1,1609 +1,1192 @@ use crate::{Error, Result}; -use anyhow::{anyhow, Context}; -use async_trait::async_trait; +use anyhow::anyhow; use axum::http::StatusCode; use collections::HashMap; use futures::StreamExt; use serde::{Deserialize, Serialize}; -pub use sqlx::postgres::PgPoolOptions as DbOptions; use sqlx::{ migrate::{Migrate as _, Migration, MigrationSource}, types::Uuid, - FromRow, QueryBuilder, + FromRow, }; -use std::{cmp, ops::Range, path::Path, time::Duration}; +use std::{path::Path, time::Duration}; use time::{OffsetDateTime, PrimitiveDateTime}; -#[async_trait] -pub trait Db: Send + Sync { - async fn create_user( - &self, - email_address: &str, - admin: bool, - params: NewUserParams, - ) -> Result; - async fn get_all_users(&self, page: u32, limit: u32) -> Result>; - async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result>; - async fn get_user_by_id(&self, id: UserId) -> Result>; - async fn get_user_metrics_id(&self, id: UserId) -> Result; - async fn get_users_by_ids(&self, ids: Vec) -> Result>; - async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result>; - async fn get_user_by_github_account( - &self, - github_login: &str, - github_user_id: Option, - ) -> Result>; - async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>; - async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>; - async fn destroy_user(&self, id: UserId) -> Result<()>; - - async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()>; - async fn get_invite_code_for_user(&self, id: UserId) -> Result>; - async fn get_user_for_invite_code(&self, code: &str) -> Result; - async fn create_invite_from_code( - &self, - code: &str, - email_address: &str, - device_id: Option<&str>, - ) -> Result; - - async fn create_signup(&self, signup: Signup) -> Result<()>; - async fn get_waitlist_summary(&self) -> Result; - async fn get_unsent_invites(&self, count: usize) -> Result>; - async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()>; - async fn create_user_from_invite( - &self, - invite: &Invite, - user: NewUserParams, - ) -> Result>; - - /// Registers a new project for the given user. - async fn register_project(&self, host_user_id: UserId) -> Result; - - /// Unregisters a project for the given project id. - async fn unregister_project(&self, project_id: ProjectId) -> Result<()>; - - /// Update file counts by extension for the given project and worktree. - async fn update_worktree_extensions( - &self, - project_id: ProjectId, - worktree_id: u64, - extensions: HashMap, - ) -> Result<()>; - - /// Get the file counts on the given project keyed by their worktree and extension. - async fn get_project_extensions( - &self, - project_id: ProjectId, - ) -> Result>>; - - /// Record which users have been active in which projects during - /// a given period of time. - async fn record_user_activity( - &self, - time_period: Range, - active_projects: &[(UserId, ProjectId)], - ) -> Result<()>; +#[cfg(test)] +pub type DefaultDb = Db; - /// Get the number of users who have been active in the given - /// time period for at least the given time duration. - async fn get_active_user_count( - &self, - time_period: Range, - min_duration: Duration, - only_collaborative: bool, - ) -> Result; - - /// Get the users that have been most active during the given time period, - /// along with the amount of time they have been active in each project. - async fn get_top_users_activity_summary( - &self, - time_period: Range, - max_user_count: usize, - ) -> Result>; +#[cfg(not(test))] +pub type DefaultDb = Db; - /// Get the project activity for the given user and time period. - async fn get_user_activity_timeline( - &self, - time_period: Range, - user_id: UserId, - ) -> Result>; +pub struct Db { + pool: sqlx::Pool, + #[cfg(test)] + background: Option>, + #[cfg(test)] + runtime: Option, +} - async fn get_contacts(&self, id: UserId) -> Result>; - async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result; - async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>; - async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>; - async fn dismiss_contact_notification( - &self, - responder_id: UserId, - requester_id: UserId, - ) -> Result<()>; - async fn respond_to_contact_request( - &self, - responder_id: UserId, - requester_id: UserId, - accept: bool, - ) -> Result<()>; +macro_rules! test_support { + ($self:ident, { $($token:tt)* }) => {{ + let body = async { + $($token)* + }; - async fn create_access_token_hash( - &self, - user_id: UserId, - access_token_hash: &str, - max_access_token_count: usize, - ) -> Result<()>; - async fn get_access_token_hashes(&self, user_id: UserId) -> Result>; - - #[cfg(any(test, feature = "seed-support"))] - async fn find_org_by_slug(&self, slug: &str) -> Result>; - #[cfg(any(test, feature = "seed-support"))] - async fn create_org(&self, name: &str, slug: &str) -> Result; - #[cfg(any(test, feature = "seed-support"))] - async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>; - #[cfg(any(test, feature = "seed-support"))] - async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result; - #[cfg(any(test, feature = "seed-support"))] - - async fn get_org_channels(&self, org_id: OrgId) -> Result>; - async fn get_accessible_channels(&self, user_id: UserId) -> Result>; - async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId) - -> Result; - - #[cfg(any(test, feature = "seed-support"))] - async fn add_channel_member( - &self, - channel_id: ChannelId, - user_id: UserId, - is_admin: bool, - ) -> Result<()>; - async fn create_channel_message( - &self, - channel_id: ChannelId, - sender_id: UserId, - body: &str, - timestamp: OffsetDateTime, - nonce: u128, - ) -> Result; - async fn get_channel_messages( - &self, - channel_id: ChannelId, - count: usize, - before_id: Option, - ) -> Result>; + if cfg!(test) { + #[cfg(not(test))] + unreachable!(); - #[cfg(test)] - async fn teardown(&self, url: &str); + #[cfg(test)] + if let Some(background) = $self.background.as_ref() { + background.simulate_random_delay().await; + } - #[cfg(test)] - fn as_fake(&self) -> Option<&FakeDb>; + #[cfg(test)] + $self.runtime.as_ref().unwrap().block_on(body) + } else { + body.await + } + }}; } -#[cfg(any(test, debug_assertions))] -pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> = - Some(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")); +pub trait RowsAffected { + fn rows_affected(&self) -> u64; +} -#[cfg(not(any(test, debug_assertions)))] -pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> = None; +#[cfg(test)] +impl RowsAffected for sqlx::sqlite::SqliteQueryResult { + fn rows_affected(&self) -> u64 { + self.rows_affected() + } +} -pub struct PostgresDb { - pool: sqlx::PgPool, +impl RowsAffected for sqlx::postgres::PgQueryResult { + fn rows_affected(&self) -> u64 { + self.rows_affected() + } } -impl PostgresDb { +#[cfg(test)] +impl Db { pub async fn new(url: &str, max_connections: u32) -> Result { - let pool = DbOptions::new() + use std::str::FromStr as _; + let options = sqlx::sqlite::SqliteConnectOptions::from_str(url) + .unwrap() + .create_if_missing(true) + .shared_cache(true); + let pool = sqlx::sqlite::SqlitePoolOptions::new() + .min_connections(2) .max_connections(max_connections) - .connect(url) - .await - .context("failed to connect to postgres database")?; - Ok(Self { pool }) + .connect_with(options) + .await?; + Ok(Self { + pool, + background: None, + runtime: None, + }) } - pub async fn migrate( - &self, - migrations_path: &Path, - ignore_checksum_mismatch: bool, - ) -> anyhow::Result> { - let migrations = MigrationSource::resolve(migrations_path) - .await - .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?; - - let mut conn = self.pool.acquire().await?; - - conn.ensure_migrations_table().await?; - let applied_migrations: HashMap<_, _> = conn - .list_applied_migrations() - .await? - .into_iter() - .map(|m| (m.version, m)) - .collect(); - - let mut new_migrations = Vec::new(); - for migration in migrations { - match applied_migrations.get(&migration.version) { - Some(applied_migration) => { - if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch - { - Err(anyhow!( - "checksum mismatch for applied migration {}", - migration.description - ))?; - } - } - None => { - let elapsed = conn.apply(&migration).await?; - new_migrations.push((migration, elapsed)); - } - } - } - - Ok(new_migrations) + pub async fn get_users_by_ids(&self, ids: Vec) -> Result> { + test_support!(self, { + let query = " + SELECT users.* + FROM users + WHERE users.id IN (SELECT value from json_each($1)) + "; + Ok(sqlx::query_as(query) + .bind(&serde_json::json!(ids)) + .fetch_all(&self.pool) + .await?) + }) } - pub fn fuzzy_like_string(string: &str) -> String { - let mut result = String::with_capacity(string.len() * 2 + 1); - for c in string.chars() { - if c.is_alphanumeric() { - result.push('%'); - result.push(c); - } - } - result.push('%'); - result + pub async fn get_user_metrics_id(&self, id: UserId) -> Result { + test_support!(self, { + let query = " + SELECT metrics_id + FROM users + WHERE id = $1 + "; + Ok(sqlx::query_scalar(query) + .bind(id) + .fetch_one(&self.pool) + .await?) + }) } -} - -#[async_trait] -impl Db for PostgresDb { - // users - async fn create_user( + pub async fn create_user( &self, email_address: &str, admin: bool, params: NewUserParams, ) -> Result { - let query = " - INSERT INTO users (email_address, github_login, github_user_id, admin) - VALUES ($1, $2, $3, $4) - ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login - RETURNING id, metrics_id::text - "; - let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query) - .bind(email_address) - .bind(params.github_login) - .bind(params.github_user_id) - .bind(admin) - .fetch_one(&self.pool) - .await?; - Ok(NewUserResult { - user_id, - metrics_id, - signup_device_id: None, - inviting_user_id: None, - }) - } + test_support!(self, { + let query = " + INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login + RETURNING id, metrics_id + "; - async fn get_all_users(&self, page: u32, limit: u32) -> Result> { - let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2"; - Ok(sqlx::query_as(query) - .bind(limit as i32) - .bind((page * limit) as i32) - .fetch_all(&self.pool) - .await?) + let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query) + .bind(email_address) + .bind(params.github_login) + .bind(params.github_user_id) + .bind(admin) + .bind(Uuid::new_v4().to_string()) + .fetch_one(&self.pool) + .await?; + Ok(NewUserResult { + user_id, + metrics_id, + signup_device_id: None, + inviting_user_id: None, + }) + }) } - async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result> { - let like_string = Self::fuzzy_like_string(name_query); - let query = " - SELECT users.* - FROM users - WHERE github_login ILIKE $1 - ORDER BY github_login <-> $2 - LIMIT $3 - "; - Ok(sqlx::query_as(query) - .bind(like_string) - .bind(name_query) - .bind(limit as i32) - .fetch_all(&self.pool) - .await?) + pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result> { + unimplemented!() } - async fn get_user_by_id(&self, id: UserId) -> Result> { - let users = self.get_users_by_ids(vec![id]).await?; - Ok(users.into_iter().next()) + pub async fn create_user_from_invite( + &self, + _invite: &Invite, + _user: NewUserParams, + ) -> Result> { + unimplemented!() } - async fn get_user_metrics_id(&self, id: UserId) -> Result { - let query = " - SELECT metrics_id::text - FROM users - WHERE id = $1 - "; - Ok(sqlx::query_scalar(query) - .bind(id) - .fetch_one(&self.pool) - .await?) + pub async fn create_signup(&self, _signup: Signup) -> Result<()> { + unimplemented!() } - async fn get_users_by_ids(&self, ids: Vec) -> Result> { - let ids = ids.into_iter().map(|id| id.0).collect::>(); - let query = " - SELECT users.* - FROM users - WHERE users.id = ANY ($1) - "; - Ok(sqlx::query_as(query) - .bind(&ids) - .fetch_all(&self.pool) - .await?) + pub async fn create_invite_from_code( + &self, + _code: &str, + _email_address: &str, + _device_id: Option<&str>, + ) -> Result { + unimplemented!() } - async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result> { - let query = format!( - " - SELECT users.* - FROM users - WHERE invite_count = 0 - AND inviter_id IS{} NULL - ", - if invited_by_another_user { " NOT" } else { "" } - ); - - Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?) + pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> { + unimplemented!() } +} - async fn get_user_by_github_account( - &self, - github_login: &str, - github_user_id: Option, - ) -> Result> { - if let Some(github_user_id) = github_user_id { - let mut user = sqlx::query_as::<_, User>( - " - UPDATE users - SET github_login = $1 - WHERE github_user_id = $2 - RETURNING * - ", - ) - .bind(github_login) - .bind(github_user_id) - .fetch_optional(&self.pool) +impl Db { + pub async fn new(url: &str, max_connections: u32) -> Result { + let pool = sqlx::postgres::PgPoolOptions::new() + .max_connections(max_connections) + .connect(url) .await?; - - if user.is_none() { - user = sqlx::query_as::<_, User>( - " - UPDATE users - SET github_user_id = $1 - WHERE github_login = $2 - RETURNING * - ", - ) - .bind(github_user_id) - .bind(github_login) - .fetch_optional(&self.pool) - .await?; - } - - Ok(user) - } else { - Ok(sqlx::query_as( - " - SELECT * FROM users - WHERE github_login = $1 - LIMIT 1 - ", - ) - .bind(github_login) - .fetch_optional(&self.pool) - .await?) - } - } - - async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> { - let query = "UPDATE users SET admin = $1 WHERE id = $2"; - Ok(sqlx::query(query) - .bind(is_admin) - .bind(id.0) - .execute(&self.pool) - .await - .map(drop)?) + Ok(Self { + pool, + #[cfg(test)] + background: None, + #[cfg(test)] + runtime: None, + }) } - async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> { - let query = "UPDATE users SET connected_once = $1 WHERE id = $2"; - Ok(sqlx::query(query) - .bind(connected_once) - .bind(id.0) - .execute(&self.pool) - .await - .map(drop)?) + #[cfg(test)] + pub fn teardown(&self, url: &str) { + self.runtime.as_ref().unwrap().block_on(async { + use util::ResultExt; + let query = " + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid(); + "; + sqlx::query(query).execute(&self.pool).await.log_err(); + self.pool.close().await; + ::drop_database(url) + .await + .log_err(); + }) } - async fn destroy_user(&self, id: UserId) -> Result<()> { - let query = "DELETE FROM access_tokens WHERE user_id = $1;"; - sqlx::query(query) - .bind(id.0) - .execute(&self.pool) - .await - .map(drop)?; - let query = "DELETE FROM users WHERE id = $1;"; - Ok(sqlx::query(query) - .bind(id.0) - .execute(&self.pool) - .await - .map(drop)?) + pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result> { + test_support!(self, { + let like_string = Self::fuzzy_like_string(name_query); + let query = " + SELECT users.* + FROM users + WHERE github_login ILIKE $1 + ORDER BY github_login <-> $2 + LIMIT $3 + "; + Ok(sqlx::query_as(query) + .bind(like_string) + .bind(name_query) + .bind(limit as i32) + .fetch_all(&self.pool) + .await?) + }) } - // signups - - async fn create_signup(&self, signup: Signup) -> Result<()> { - sqlx::query( - " - INSERT INTO signups - ( - email_address, - email_confirmation_code, - email_confirmation_sent, - platform_linux, - platform_mac, - platform_windows, - platform_unknown, - editor_features, - programming_languages, - device_id - ) - VALUES - ($1, $2, 'f', $3, $4, $5, 'f', $6, $7, $8) - RETURNING id - ", - ) - .bind(&signup.email_address) - .bind(&random_email_confirmation_code()) - .bind(&signup.platform_linux) - .bind(&signup.platform_mac) - .bind(&signup.platform_windows) - .bind(&signup.editor_features) - .bind(&signup.programming_languages) - .bind(&signup.device_id) - .execute(&self.pool) - .await?; - Ok(()) + pub async fn get_users_by_ids(&self, ids: Vec) -> Result> { + test_support!(self, { + let query = " + SELECT users.* + FROM users + WHERE users.id = ANY ($1) + "; + Ok(sqlx::query_as(query) + .bind(&ids.into_iter().map(|id| id.0).collect::>()) + .fetch_all(&self.pool) + .await?) + }) } - async fn get_waitlist_summary(&self) -> Result { - Ok(sqlx::query_as( - " - SELECT - COUNT(*) as count, - COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count, - COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count, - COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count, - COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count - FROM ( - SELECT * - FROM signups - WHERE - NOT email_confirmation_sent - ) AS unsent - ", - ) - .fetch_one(&self.pool) - .await?) + pub async fn get_user_metrics_id(&self, id: UserId) -> Result { + test_support!(self, { + let query = " + SELECT metrics_id::text + FROM users + WHERE id = $1 + "; + Ok(sqlx::query_scalar(query) + .bind(id) + .fetch_one(&self.pool) + .await?) + }) } - async fn get_unsent_invites(&self, count: usize) -> Result> { - Ok(sqlx::query_as( - " - SELECT - email_address, email_confirmation_code - FROM signups - WHERE - NOT email_confirmation_sent AND - (platform_mac OR platform_unknown) - LIMIT $1 - ", - ) - .bind(count as i32) - .fetch_all(&self.pool) - .await?) - } + pub async fn create_user( + &self, + email_address: &str, + admin: bool, + params: NewUserParams, + ) -> Result { + test_support!(self, { + let query = " + INSERT INTO users (email_address, github_login, github_user_id, admin) + VALUES ($1, $2, $3, $4) + ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login + RETURNING id, metrics_id::text + "; - async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> { - sqlx::query( - " - UPDATE signups - SET email_confirmation_sent = 't' - WHERE email_address = ANY ($1) - ", - ) - .bind( - &invites - .iter() - .map(|s| s.email_address.as_str()) - .collect::>(), - ) - .execute(&self.pool) - .await?; - Ok(()) + let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query) + .bind(email_address) + .bind(params.github_login) + .bind(params.github_user_id) + .bind(admin) + .fetch_one(&self.pool) + .await?; + Ok(NewUserResult { + user_id, + metrics_id, + signup_device_id: None, + inviting_user_id: None, + }) + }) } - async fn create_user_from_invite( + pub async fn create_user_from_invite( &self, invite: &Invite, user: NewUserParams, ) -> Result> { - let mut tx = self.pool.begin().await?; - - let (signup_id, existing_user_id, inviting_user_id, signup_device_id): ( - i32, - Option, - Option, - Option, - ) = sqlx::query_as( - " - SELECT id, user_id, inviting_user_id, device_id - FROM signups - WHERE - email_address = $1 AND - email_confirmation_code = $2 - ", - ) - .bind(&invite.email_address) - .bind(&invite.email_confirmation_code) - .fetch_optional(&mut tx) - .await? - .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?; - - if existing_user_id.is_some() { - return Ok(None); - } + test_support!(self, { + let mut tx = self.pool.begin().await?; - let (user_id, metrics_id): (UserId, String) = sqlx::query_as( - " - INSERT INTO users - (email_address, github_login, github_user_id, admin, invite_count, invite_code) - VALUES - ($1, $2, $3, 'f', $4, $5) - ON CONFLICT (github_login) DO UPDATE SET - email_address = excluded.email_address, - github_user_id = excluded.github_user_id, - admin = excluded.admin - RETURNING id, metrics_id::text - ", - ) - .bind(&invite.email_address) - .bind(&user.github_login) - .bind(&user.github_user_id) - .bind(&user.invite_count) - .bind(random_invite_code()) - .fetch_one(&mut tx) - .await?; - - sqlx::query( - " - UPDATE signups - SET user_id = $1 - WHERE id = $2 - ", - ) - .bind(&user_id) - .bind(&signup_id) - .execute(&mut tx) - .await?; - - if let Some(inviting_user_id) = inviting_user_id { - let id: Option = sqlx::query_scalar( + let (signup_id, existing_user_id, inviting_user_id, signup_device_id): ( + i32, + Option, + Option, + Option, + ) = sqlx::query_as( " - UPDATE users - SET invite_count = invite_count - 1 - WHERE id = $1 AND invite_count > 0 - RETURNING id + SELECT id, user_id, inviting_user_id, device_id + FROM signups + WHERE + email_address = $1 AND + email_confirmation_code = $2 ", ) - .bind(&inviting_user_id) + .bind(&invite.email_address) + .bind(&invite.email_confirmation_code) .fetch_optional(&mut tx) - .await?; + .await? + .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?; - if id.is_none() { - Err(Error::Http( - StatusCode::UNAUTHORIZED, - "no invites remaining".to_string(), - ))?; + if existing_user_id.is_some() { + return Ok(None); } - sqlx::query( + let (user_id, metrics_id): (UserId, String) = sqlx::query_as( " - INSERT INTO contacts - (user_id_a, user_id_b, a_to_b, should_notify, accepted) + INSERT INTO users + (email_address, github_login, github_user_id, admin, invite_count, invite_code) VALUES - ($1, $2, 't', 't', 't') - ON CONFLICT DO NOTHING + ($1, $2, $3, FALSE, $4, $5) + ON CONFLICT (github_login) DO UPDATE SET + email_address = excluded.email_address, + github_user_id = excluded.github_user_id, + admin = excluded.admin + RETURNING id, metrics_id::text ", ) - .bind(inviting_user_id) - .bind(user_id) - .execute(&mut tx) + .bind(&invite.email_address) + .bind(&user.github_login) + .bind(&user.github_user_id) + .bind(&user.invite_count) + .bind(random_invite_code()) + .fetch_one(&mut tx) .await?; - } - - tx.commit().await?; - Ok(Some(NewUserResult { - user_id, - metrics_id, - inviting_user_id, - signup_device_id, - })) - } - // invite codes - - async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> { - let mut tx = self.pool.begin().await?; - if count > 0 { sqlx::query( " - UPDATE users - SET invite_code = $1 - WHERE id = $2 AND invite_code IS NULL - ", + UPDATE signups + SET user_id = $1 + WHERE id = $2 + ", ) - .bind(random_invite_code()) - .bind(id) + .bind(&user_id) + .bind(&signup_id) .execute(&mut tx) .await?; - } - sqlx::query( - " - UPDATE users - SET invite_count = $1 - WHERE id = $2 - ", - ) - .bind(count as i32) - .bind(id) - .execute(&mut tx) - .await?; - tx.commit().await?; - Ok(()) - } + if let Some(inviting_user_id) = inviting_user_id { + let id: Option = sqlx::query_scalar( + " + UPDATE users + SET invite_count = invite_count - 1 + WHERE id = $1 AND invite_count > 0 + RETURNING id + ", + ) + .bind(&inviting_user_id) + .fetch_optional(&mut tx) + .await?; - async fn get_invite_code_for_user(&self, id: UserId) -> Result> { - let result: Option<(String, i32)> = sqlx::query_as( - " - SELECT invite_code, invite_count - FROM users - WHERE id = $1 AND invite_code IS NOT NULL - ", - ) - .bind(id) - .fetch_optional(&self.pool) - .await?; - if let Some((code, count)) = result { - Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?))) - } else { - Ok(None) - } + if id.is_none() { + Err(Error::Http( + StatusCode::UNAUTHORIZED, + "no invites remaining".to_string(), + ))?; + } + + sqlx::query( + " + INSERT INTO contacts + (user_id_a, user_id_b, a_to_b, should_notify, accepted) + VALUES + ($1, $2, TRUE, TRUE, TRUE) + ON CONFLICT DO NOTHING + ", + ) + .bind(inviting_user_id) + .bind(user_id) + .execute(&mut tx) + .await?; + } + + tx.commit().await?; + Ok(Some(NewUserResult { + user_id, + metrics_id, + inviting_user_id, + signup_device_id, + })) + }) } - async fn get_user_for_invite_code(&self, code: &str) -> Result { - sqlx::query_as( - " - SELECT * - FROM users - WHERE invite_code = $1 - ", - ) - .bind(code) - .fetch_optional(&self.pool) - .await? - .ok_or_else(|| { - Error::Http( - StatusCode::NOT_FOUND, - "that invite code does not exist".to_string(), + pub async fn create_signup(&self, signup: Signup) -> Result<()> { + test_support!(self, { + sqlx::query( + " + INSERT INTO signups + ( + email_address, + email_confirmation_code, + email_confirmation_sent, + platform_linux, + platform_mac, + platform_windows, + platform_unknown, + editor_features, + programming_languages, + device_id + ) + VALUES + ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8) + RETURNING id + ", ) + .bind(&signup.email_address) + .bind(&random_email_confirmation_code()) + .bind(&signup.platform_linux) + .bind(&signup.platform_mac) + .bind(&signup.platform_windows) + .bind(&signup.editor_features) + .bind(&signup.programming_languages) + .bind(&signup.device_id) + .execute(&self.pool) + .await?; + Ok(()) }) } - async fn create_invite_from_code( + pub async fn create_invite_from_code( &self, code: &str, email_address: &str, device_id: Option<&str>, ) -> Result { - let mut tx = self.pool.begin().await?; - - let existing_user: Option = sqlx::query_scalar( - " - SELECT id - FROM users - WHERE email_address = $1 - ", - ) - .bind(email_address) - .fetch_optional(&mut tx) - .await?; - if existing_user.is_some() { - Err(anyhow!("email address is already in use"))?; - } + test_support!(self, { + let mut tx = self.pool.begin().await?; - let row: Option<(UserId, i32)> = sqlx::query_as( - " - SELECT id, invite_count - FROM users - WHERE invite_code = $1 - ", - ) - .bind(code) - .fetch_optional(&mut tx) - .await?; - - let (inviter_id, invite_count) = match row { - Some(row) => row, - None => Err(Error::Http( - StatusCode::NOT_FOUND, - "invite code not found".to_string(), - ))?, - }; + let existing_user: Option = sqlx::query_scalar( + " + SELECT id + FROM users + WHERE email_address = $1 + ", + ) + .bind(email_address) + .fetch_optional(&mut tx) + .await?; + if existing_user.is_some() { + Err(anyhow!("email address is already in use"))?; + } - if invite_count == 0 { - Err(Error::Http( - StatusCode::UNAUTHORIZED, - "no invites remaining".to_string(), - ))?; - } + let row: Option<(UserId, i32)> = sqlx::query_as( + " + SELECT id, invite_count + FROM users + WHERE invite_code = $1 + ", + ) + .bind(code) + .fetch_optional(&mut tx) + .await?; - let email_confirmation_code: String = sqlx::query_scalar( - " - INSERT INTO signups - ( - email_address, - email_confirmation_code, - email_confirmation_sent, - inviting_user_id, - platform_linux, - platform_mac, - platform_windows, - platform_unknown, - device_id + let (inviter_id, invite_count) = match row { + Some(row) => row, + None => Err(Error::Http( + StatusCode::NOT_FOUND, + "invite code not found".to_string(), + ))?, + }; + + if invite_count == 0 { + Err(Error::Http( + StatusCode::UNAUTHORIZED, + "no invites remaining".to_string(), + ))?; + } + + let email_confirmation_code: String = sqlx::query_scalar( + " + INSERT INTO signups + ( + email_address, + email_confirmation_code, + email_confirmation_sent, + inviting_user_id, + platform_linux, + platform_mac, + platform_windows, + platform_unknown, + device_id + ) + VALUES + ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4) + ON CONFLICT (email_address) + DO UPDATE SET + inviting_user_id = excluded.inviting_user_id + RETURNING email_confirmation_code + ", ) - VALUES - ($1, $2, 'f', $3, 'f', 'f', 'f', 't', $4) - ON CONFLICT (email_address) - DO UPDATE SET - inviting_user_id = excluded.inviting_user_id - RETURNING email_confirmation_code - ", - ) - .bind(&email_address) - .bind(&random_email_confirmation_code()) - .bind(&inviter_id) - .bind(&device_id) - .fetch_one(&mut tx) - .await?; - - tx.commit().await?; - - Ok(Invite { - email_address: email_address.into(), - email_confirmation_code, - }) - } + .bind(&email_address) + .bind(&random_email_confirmation_code()) + .bind(&inviter_id) + .bind(&device_id) + .fetch_one(&mut tx) + .await?; - // projects + tx.commit().await?; - async fn register_project(&self, host_user_id: UserId) -> Result { - Ok(sqlx::query_scalar( - " - INSERT INTO projects(host_user_id) - VALUES ($1) - RETURNING id - ", - ) - .bind(host_user_id) - .fetch_one(&self.pool) - .await - .map(ProjectId)?) + Ok(Invite { + email_address: email_address.into(), + email_confirmation_code, + }) + }) } - async fn unregister_project(&self, project_id: ProjectId) -> Result<()> { - sqlx::query( - " - UPDATE projects - SET unregistered = 't' - WHERE id = $1 - ", - ) - .bind(project_id) - .execute(&self.pool) - .await?; - Ok(()) + pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> { + test_support!(self, { + let emails = invites + .iter() + .map(|s| s.email_address.as_str()) + .collect::>(); + sqlx::query( + " + UPDATE signups + SET email_confirmation_sent = TRUE + WHERE email_address = ANY ($1) + ", + ) + .bind(&emails) + .execute(&self.pool) + .await?; + Ok(()) + }) } +} - async fn update_worktree_extensions( +impl Db +where + D: sqlx::Database + sqlx::migrate::MigrateDatabase, + D::Connection: sqlx::migrate::Migrate, + for<'a> >::Arguments: sqlx::IntoArguments<'a, D>, + for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>, + for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>, + D::QueryResult: RowsAffected, + String: sqlx::Type, + i32: sqlx::Type, + i64: sqlx::Type, + bool: sqlx::Type, + str: sqlx::Type, + Uuid: sqlx::Type, + sqlx::types::Json: sqlx::Type, + OffsetDateTime: sqlx::Type, + PrimitiveDateTime: sqlx::Type, + usize: sqlx::ColumnIndex, + for<'a> &'a str: sqlx::ColumnIndex, + for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, + for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, + for<'a> Option: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, + for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, + for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, + for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, + for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, + for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, + for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, + for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, + for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>, +{ + pub async fn migrate( &self, - project_id: ProjectId, - worktree_id: u64, - extensions: HashMap, - ) -> Result<()> { - if extensions.is_empty() { - return Ok(()); - } + migrations_path: &Path, + ignore_checksum_mismatch: bool, + ) -> anyhow::Result> { + let migrations = MigrationSource::resolve(migrations_path) + .await + .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?; - let mut query = QueryBuilder::new( - "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)", - ); - query.push_values(extensions, |mut query, (extension, count)| { - query - .push_bind(project_id) - .push_bind(worktree_id as i32) - .push_bind(extension) - .push_bind(count as i32); - }); - query.push( - " - ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET - count = excluded.count - ", - ); - query.build().execute(&self.pool).await?; - - Ok(()) - } + let mut conn = self.pool.acquire().await?; - async fn get_project_extensions( - &self, - project_id: ProjectId, - ) -> Result>> { - #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)] - struct WorktreeExtension { - worktree_id: i32, - extension: String, - count: i32, + conn.ensure_migrations_table().await?; + let applied_migrations: HashMap<_, _> = conn + .list_applied_migrations() + .await? + .into_iter() + .map(|m| (m.version, m)) + .collect(); + + let mut new_migrations = Vec::new(); + for migration in migrations { + match applied_migrations.get(&migration.version) { + Some(applied_migration) => { + if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch + { + Err(anyhow!( + "checksum mismatch for applied migration {}", + migration.description + ))?; + } + } + None => { + let elapsed = conn.apply(&migration).await?; + new_migrations.push((migration, elapsed)); + } + } } - let query = " - SELECT worktree_id, extension, count - FROM worktree_extensions - WHERE project_id = $1 - "; - let counts = sqlx::query_as::<_, WorktreeExtension>(query) - .bind(&project_id) - .fetch_all(&self.pool) - .await?; + Ok(new_migrations) + } - let mut extension_counts = HashMap::default(); - for count in counts { - extension_counts - .entry(count.worktree_id as u64) - .or_insert_with(HashMap::default) - .insert(count.extension, count.count as usize); + pub fn fuzzy_like_string(string: &str) -> String { + let mut result = String::with_capacity(string.len() * 2 + 1); + for c in string.chars() { + if c.is_alphanumeric() { + result.push('%'); + result.push(c); + } } - Ok(extension_counts) + result.push('%'); + result } - async fn record_user_activity( - &self, - time_period: Range, - projects: &[(UserId, ProjectId)], - ) -> Result<()> { - let query = " - INSERT INTO project_activity_periods - (ended_at, duration_millis, user_id, project_id) - VALUES - ($1, $2, $3, $4); - "; - - let mut tx = self.pool.begin().await?; - let duration_millis = - ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32; - for (user_id, project_id) in projects { - sqlx::query(query) - .bind(time_period.end) - .bind(duration_millis) - .bind(user_id) - .bind(project_id) - .execute(&mut tx) - .await?; - } - tx.commit().await?; + // users - Ok(()) + pub async fn get_all_users(&self, page: u32, limit: u32) -> Result> { + test_support!(self, { + let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2"; + Ok(sqlx::query_as(query) + .bind(limit as i32) + .bind((page * limit) as i32) + .fetch_all(&self.pool) + .await?) + }) } - async fn get_active_user_count( + pub async fn get_user_by_id(&self, id: UserId) -> Result> { + test_support!(self, { + let query = " + SELECT users.* + FROM users + WHERE id = $1 + LIMIT 1 + "; + Ok(sqlx::query_as(query) + .bind(&id) + .fetch_optional(&self.pool) + .await?) + }) + } + + pub async fn get_users_with_no_invites( &self, - time_period: Range, - min_duration: Duration, - only_collaborative: bool, - ) -> Result { - let mut with_clause = String::new(); - with_clause.push_str("WITH\n"); - with_clause.push_str( - " - project_durations AS ( - SELECT user_id, project_id, SUM(duration_millis) AS project_duration - FROM project_activity_periods - WHERE $1 < ended_at AND ended_at <= $2 - GROUP BY user_id, project_id - ), - ", - ); - with_clause.push_str( - " - project_collaborators as ( - SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators - FROM project_durations - GROUP BY project_id - ), - ", - ); - - if only_collaborative { - with_clause.push_str( - " - user_durations AS ( - SELECT user_id, SUM(project_duration) as total_duration - FROM project_durations, project_collaborators - WHERE - project_durations.project_id = project_collaborators.project_id AND - max_collaborators > 1 - GROUP BY user_id - ORDER BY total_duration DESC - LIMIT $3 - ) - ", - ); - } else { - with_clause.push_str( + invited_by_another_user: bool, + ) -> Result> { + test_support!(self, { + let query = format!( " - user_durations AS ( - SELECT user_id, SUM(project_duration) as total_duration - FROM project_durations - GROUP BY user_id - ORDER BY total_duration DESC - LIMIT $3 - ) + SELECT users.* + FROM users + WHERE invite_count = 0 + AND inviter_id IS{} NULL ", + if invited_by_another_user { " NOT" } else { "" } ); - } - let query = format!( - " - {with_clause} - SELECT count(user_durations.user_id) - FROM user_durations - WHERE user_durations.total_duration >= $3 - " - ); - - let count: i64 = sqlx::query_scalar(&query) - .bind(time_period.start) - .bind(time_period.end) - .bind(min_duration.as_millis() as i64) - .fetch_one(&self.pool) - .await?; - Ok(count as usize) + Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?) + }) } - async fn get_top_users_activity_summary( + pub async fn get_user_by_github_account( &self, - time_period: Range, - max_user_count: usize, - ) -> Result> { - let query = " - WITH - project_durations AS ( - SELECT user_id, project_id, SUM(duration_millis) AS project_duration - FROM project_activity_periods - WHERE $1 < ended_at AND ended_at <= $2 - GROUP BY user_id, project_id - ), - user_durations AS ( - SELECT user_id, SUM(project_duration) as total_duration - FROM project_durations - GROUP BY user_id - ORDER BY total_duration DESC - LIMIT $3 - ), - project_collaborators as ( - SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators - FROM project_durations - GROUP BY project_id + github_login: &str, + github_user_id: Option, + ) -> Result> { + test_support!(self, { + if let Some(github_user_id) = github_user_id { + let mut user = sqlx::query_as::<_, User>( + " + UPDATE users + SET github_login = $1 + WHERE github_user_id = $2 + RETURNING * + ", ) - SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, max_collaborators - FROM user_durations, project_durations, project_collaborators, users - WHERE - user_durations.user_id = project_durations.user_id AND - user_durations.user_id = users.id AND - project_durations.project_id = project_collaborators.project_id - ORDER BY total_duration DESC, user_id ASC, project_id ASC - "; - - let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query) - .bind(time_period.start) - .bind(time_period.end) - .bind(max_user_count as i32) - .fetch(&self.pool); - - let mut result = Vec::::new(); - while let Some(row) = rows.next().await { - let (user_id, github_login, project_id, duration_millis, project_collaborators) = row?; - let project_id = project_id; - let duration = Duration::from_millis(duration_millis as u64); - let project_activity = ProjectActivitySummary { - id: project_id, - duration, - max_collaborators: project_collaborators as usize, - }; - if let Some(last_summary) = result.last_mut() { - if last_summary.id == user_id { - last_summary.project_activity.push(project_activity); - continue; - } - } - result.push(UserActivitySummary { - id: user_id, - project_activity: vec![project_activity], - github_login, - }); - } - - Ok(result) - } + .bind(github_login) + .bind(github_user_id) + .fetch_optional(&self.pool) + .await?; - async fn get_user_activity_timeline( - &self, - time_period: Range, - user_id: UserId, - ) -> Result> { - const COALESCE_THRESHOLD: Duration = Duration::from_secs(30); - - let query = " - SELECT - project_activity_periods.ended_at, - project_activity_periods.duration_millis, - project_activity_periods.project_id, - worktree_extensions.extension, - worktree_extensions.count - FROM project_activity_periods - LEFT OUTER JOIN - worktree_extensions - ON - project_activity_periods.project_id = worktree_extensions.project_id - WHERE - project_activity_periods.user_id = $1 AND - $2 < project_activity_periods.ended_at AND - project_activity_periods.ended_at <= $3 - ORDER BY project_activity_periods.id ASC - "; - - let mut rows = sqlx::query_as::< - _, - ( - PrimitiveDateTime, - i32, - ProjectId, - Option, - Option, - ), - >(query) - .bind(user_id) - .bind(time_period.start) - .bind(time_period.end) - .fetch(&self.pool); - - let mut time_periods: HashMap> = Default::default(); - while let Some(row) = rows.next().await { - let (ended_at, duration_millis, project_id, extension, extension_count) = row?; - let ended_at = ended_at.assume_utc(); - let duration = Duration::from_millis(duration_millis as u64); - let started_at = ended_at - duration; - let project_time_periods = time_periods.entry(project_id).or_default(); - - if let Some(prev_duration) = project_time_periods.last_mut() { - if started_at <= prev_duration.end + COALESCE_THRESHOLD - && ended_at >= prev_duration.start - { - prev_duration.end = cmp::max(prev_duration.end, ended_at); - } else { - project_time_periods.push(UserActivityPeriod { - project_id, - start: started_at, - end: ended_at, - extensions: Default::default(), - }); + if user.is_none() { + user = sqlx::query_as::<_, User>( + " + UPDATE users + SET github_user_id = $1 + WHERE github_login = $2 + RETURNING * + ", + ) + .bind(github_user_id) + .bind(github_login) + .fetch_optional(&self.pool) + .await?; } + + Ok(user) } else { - project_time_periods.push(UserActivityPeriod { - project_id, - start: started_at, - end: ended_at, - extensions: Default::default(), - }); + let user = sqlx::query_as( + " + SELECT * FROM users + WHERE github_login = $1 + LIMIT 1 + ", + ) + .bind(github_login) + .fetch_optional(&self.pool) + .await?; + Ok(user) } + }) + } - if let Some((extension, extension_count)) = extension.zip(extension_count) { - project_time_periods - .last_mut() - .unwrap() - .extensions - .insert(extension, extension_count as usize); - } - } + pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> { + test_support!(self, { + let query = "UPDATE users SET admin = $1 WHERE id = $2"; + Ok(sqlx::query(query) + .bind(is_admin) + .bind(id.0) + .execute(&self.pool) + .await + .map(drop)?) + }) + } - let mut durations = time_periods.into_values().flatten().collect::>(); - durations.sort_unstable_by_key(|duration| duration.start); - Ok(durations) + pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> { + test_support!(self, { + let query = "UPDATE users SET connected_once = $1 WHERE id = $2"; + Ok(sqlx::query(query) + .bind(connected_once) + .bind(id.0) + .execute(&self.pool) + .await + .map(drop)?) + }) } - // contacts + pub async fn destroy_user(&self, id: UserId) -> Result<()> { + test_support!(self, { + let query = "DELETE FROM access_tokens WHERE user_id = $1;"; + sqlx::query(query) + .bind(id.0) + .execute(&self.pool) + .await + .map(drop)?; + let query = "DELETE FROM users WHERE id = $1;"; + Ok(sqlx::query(query) + .bind(id.0) + .execute(&self.pool) + .await + .map(drop)?) + }) + } + + // signups - async fn get_contacts(&self, user_id: UserId) -> Result> { - let query = " - SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify - FROM contacts - WHERE user_id_a = $1 OR user_id_b = $1; - "; + pub async fn get_waitlist_summary(&self) -> Result { + test_support!(self, { + Ok(sqlx::query_as( + " + SELECT + COUNT(*) as count, + COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count, + COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count, + COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count, + COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count + FROM ( + SELECT * + FROM signups + WHERE + NOT email_confirmation_sent + ) AS unsent + ", + ) + .fetch_one(&self.pool) + .await?) + }) + } - let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query) - .bind(user_id) - .fetch(&self.pool); + pub async fn get_unsent_invites(&self, count: usize) -> Result> { + test_support!(self, { + Ok(sqlx::query_as( + " + SELECT + email_address, email_confirmation_code + FROM signups + WHERE + NOT email_confirmation_sent AND + (platform_mac OR platform_unknown) + LIMIT $1 + ", + ) + .bind(count as i32) + .fetch_all(&self.pool) + .await?) + }) + } - let mut contacts = Vec::new(); - while let Some(row) = rows.next().await { - let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?; + // invite codes - if user_id_a == user_id { - if accepted { - contacts.push(Contact::Accepted { - user_id: user_id_b, - should_notify: should_notify && a_to_b, - }); - } else if a_to_b { - contacts.push(Contact::Outgoing { user_id: user_id_b }) - } else { - contacts.push(Contact::Incoming { - user_id: user_id_b, - should_notify, - }); - } - } else if accepted { - contacts.push(Contact::Accepted { - user_id: user_id_a, - should_notify: should_notify && !a_to_b, - }); - } else if a_to_b { - contacts.push(Contact::Incoming { - user_id: user_id_a, - should_notify, - }); - } else { - contacts.push(Contact::Outgoing { user_id: user_id_a }); + pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> { + test_support!(self, { + let mut tx = self.pool.begin().await?; + if count > 0 { + sqlx::query( + " + UPDATE users + SET invite_code = $1 + WHERE id = $2 AND invite_code IS NULL + ", + ) + .bind(random_invite_code()) + .bind(id) + .execute(&mut tx) + .await?; } - } - - contacts.sort_unstable_by_key(|contact| contact.user_id()); - Ok(contacts) + sqlx::query( + " + UPDATE users + SET invite_count = $1 + WHERE id = $2 + ", + ) + .bind(count as i32) + .bind(id) + .execute(&mut tx) + .await?; + tx.commit().await?; + Ok(()) + }) } - async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result { - let (id_a, id_b) = if user_id_1 < user_id_2 { - (user_id_1, user_id_2) - } else { - (user_id_2, user_id_1) - }; + pub async fn get_invite_code_for_user(&self, id: UserId) -> Result> { + test_support!(self, { + let result: Option<(String, i32)> = sqlx::query_as( + " + SELECT invite_code, invite_count + FROM users + WHERE id = $1 AND invite_code IS NOT NULL + ", + ) + .bind(id) + .fetch_optional(&self.pool) + .await?; + if let Some((code, count)) = result { + Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?))) + } else { + Ok(None) + } + }) + } - let query = " - SELECT 1 FROM contacts - WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't' - LIMIT 1 - "; - Ok(sqlx::query_scalar::<_, i32>(query) - .bind(id_a.0) - .bind(id_b.0) + pub async fn get_user_for_invite_code(&self, code: &str) -> Result { + test_support!(self, { + sqlx::query_as( + " + SELECT * + FROM users + WHERE invite_code = $1 + ", + ) + .bind(code) .fetch_optional(&self.pool) .await? - .is_some()) + .ok_or_else(|| { + Error::Http( + StatusCode::NOT_FOUND, + "that invite code does not exist".to_string(), + ) + }) + }) } - async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> { - let (id_a, id_b, a_to_b) = if sender_id < receiver_id { - (sender_id, receiver_id, true) - } else { - (receiver_id, sender_id, false) - }; - let query = " - INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify) - VALUES ($1, $2, $3, 'f', 't') - ON CONFLICT (user_id_a, user_id_b) DO UPDATE - SET - accepted = 't', - should_notify = 'f' - WHERE - NOT contacts.accepted AND - ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR - (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a)); - "; - let result = sqlx::query(query) - .bind(id_a.0) - .bind(id_b.0) - .bind(a_to_b) - .execute(&self.pool) - .await?; + // projects - if result.rows_affected() == 1 { - Ok(()) - } else { - Err(anyhow!("contact already requested"))? - } + /// Registers a new project for the given user. + pub async fn register_project(&self, host_user_id: UserId) -> Result { + test_support!(self, { + Ok(sqlx::query_scalar( + " + INSERT INTO projects(host_user_id) + VALUES ($1) + RETURNING id + ", + ) + .bind(host_user_id) + .fetch_one(&self.pool) + .await + .map(ProjectId)?) + }) } - async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> { - let (id_a, id_b) = if responder_id < requester_id { - (responder_id, requester_id) - } else { - (requester_id, responder_id) - }; - let query = " - DELETE FROM contacts - WHERE user_id_a = $1 AND user_id_b = $2; - "; - let result = sqlx::query(query) - .bind(id_a.0) - .bind(id_b.0) + /// Unregisters a project for the given project id. + pub async fn unregister_project(&self, project_id: ProjectId) -> Result<()> { + test_support!(self, { + sqlx::query( + " + UPDATE projects + SET unregistered = TRUE + WHERE id = $1 + ", + ) + .bind(project_id) .execute(&self.pool) .await?; - - if result.rows_affected() == 1 { Ok(()) - } else { - Err(anyhow!("no such contact"))? - } + }) } - async fn dismiss_contact_notification( - &self, - user_id: UserId, - contact_user_id: UserId, - ) -> Result<()> { - let (id_a, id_b, a_to_b) = if user_id < contact_user_id { - (user_id, contact_user_id, true) - } else { - (contact_user_id, user_id, false) - }; + // contacts - let query = " - UPDATE contacts - SET should_notify = 'f' - WHERE - user_id_a = $1 AND user_id_b = $2 AND - ( - (a_to_b = $3 AND accepted) OR - (a_to_b != $3 AND NOT accepted) - ); - "; - - let result = sqlx::query(query) - .bind(id_a.0) - .bind(id_b.0) - .bind(a_to_b) - .execute(&self.pool) - .await?; + pub async fn get_contacts(&self, user_id: UserId) -> Result> { + test_support!(self, { + let query = " + SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify + FROM contacts + WHERE user_id_a = $1 OR user_id_b = $1; + "; - if result.rows_affected() == 0 { - Err(anyhow!("no such contact request"))?; - } + let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query) + .bind(user_id) + .fetch(&self.pool); + + let mut contacts = Vec::new(); + while let Some(row) = rows.next().await { + let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?; + + if user_id_a == user_id { + if accepted { + contacts.push(Contact::Accepted { + user_id: user_id_b, + should_notify: should_notify && a_to_b, + }); + } else if a_to_b { + contacts.push(Contact::Outgoing { user_id: user_id_b }) + } else { + contacts.push(Contact::Incoming { + user_id: user_id_b, + should_notify, + }); + } + } else if accepted { + contacts.push(Contact::Accepted { + user_id: user_id_a, + should_notify: should_notify && !a_to_b, + }); + } else if a_to_b { + contacts.push(Contact::Incoming { + user_id: user_id_a, + should_notify, + }); + } else { + contacts.push(Contact::Outgoing { user_id: user_id_a }); + } + } + + contacts.sort_unstable_by_key(|contact| contact.user_id()); - Ok(()) + Ok(contacts) + }) } - async fn respond_to_contact_request( - &self, - responder_id: UserId, - requester_id: UserId, - accept: bool, - ) -> Result<()> { - let (id_a, id_b, a_to_b) = if responder_id < requester_id { - (responder_id, requester_id, false) - } else { - (requester_id, responder_id, true) - }; - let result = if accept { + pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result { + test_support!(self, { + let (id_a, id_b) = if user_id_1 < user_id_2 { + (user_id_1, user_id_2) + } else { + (user_id_2, user_id_1) + }; + let query = " - UPDATE contacts - SET accepted = 't', should_notify = 't' - WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3; + SELECT 1 FROM contacts + WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE + LIMIT 1 "; - sqlx::query(query) + Ok(sqlx::query_scalar::<_, i32>(query) .bind(id_a.0) .bind(id_b.0) - .bind(a_to_b) - .execute(&self.pool) + .fetch_optional(&self.pool) .await? - } else { + .is_some()) + }) + } + + pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> { + test_support!(self, { + let (id_a, id_b, a_to_b) = if sender_id < receiver_id { + (sender_id, receiver_id, true) + } else { + (receiver_id, sender_id, false) + }; let query = " - DELETE FROM contacts - WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted; + INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify) + VALUES ($1, $2, $3, FALSE, TRUE) + ON CONFLICT (user_id_a, user_id_b) DO UPDATE + SET + accepted = TRUE, + should_notify = FALSE + WHERE + NOT contacts.accepted AND + ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR + (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a)); "; - sqlx::query(query) + let result = sqlx::query(query) .bind(id_a.0) .bind(id_b.0) .bind(a_to_b) .execute(&self.pool) - .await? - }; - if result.rows_affected() == 1 { - Ok(()) - } else { - Err(anyhow!("no such contact request"))? - } + .await?; + + if result.rows_affected() == 1 { + Ok(()) + } else { + Err(anyhow!("contact already requested"))? + } + }) } - // access tokens + pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> { + test_support!(self, { + let (id_a, id_b) = if responder_id < requester_id { + (responder_id, requester_id) + } else { + (requester_id, responder_id) + }; + let query = " + DELETE FROM contacts + WHERE user_id_a = $1 AND user_id_b = $2; + "; + let result = sqlx::query(query) + .bind(id_a.0) + .bind(id_b.0) + .execute(&self.pool) + .await?; + + if result.rows_affected() == 1 { + Ok(()) + } else { + Err(anyhow!("no such contact"))? + } + }) + } - async fn create_access_token_hash( + pub async fn dismiss_contact_notification( &self, user_id: UserId, - access_token_hash: &str, - max_access_token_count: usize, + contact_user_id: UserId, ) -> Result<()> { - let insert_query = " - INSERT INTO access_tokens (user_id, hash) - VALUES ($1, $2); - "; - let cleanup_query = " - DELETE FROM access_tokens - WHERE id IN ( - SELECT id from access_tokens - WHERE user_id = $1 - ORDER BY id DESC - OFFSET $3 - ) - "; - - let mut tx = self.pool.begin().await?; - sqlx::query(insert_query) - .bind(user_id.0) - .bind(access_token_hash) - .execute(&mut tx) - .await?; - sqlx::query(cleanup_query) - .bind(user_id.0) - .bind(access_token_hash) - .bind(max_access_token_count as i32) - .execute(&mut tx) - .await?; - Ok(tx.commit().await?) - } - - async fn get_access_token_hashes(&self, user_id: UserId) -> Result> { - let query = " - SELECT hash - FROM access_tokens - WHERE user_id = $1 - ORDER BY id DESC - "; - Ok(sqlx::query_scalar(query) - .bind(user_id.0) - .fetch_all(&self.pool) - .await?) - } - - // orgs - - #[allow(unused)] // Help rust-analyzer - #[cfg(any(test, feature = "seed-support"))] - async fn find_org_by_slug(&self, slug: &str) -> Result> { - let query = " - SELECT * - FROM orgs - WHERE slug = $1 - "; - Ok(sqlx::query_as(query) - .bind(slug) - .fetch_optional(&self.pool) - .await?) - } + test_support!(self, { + let (id_a, id_b, a_to_b) = if user_id < contact_user_id { + (user_id, contact_user_id, true) + } else { + (contact_user_id, user_id, false) + }; - #[cfg(any(test, feature = "seed-support"))] - async fn create_org(&self, name: &str, slug: &str) -> Result { - let query = " - INSERT INTO orgs (name, slug) - VALUES ($1, $2) - RETURNING id - "; - Ok(sqlx::query_scalar(query) - .bind(name) - .bind(slug) - .fetch_one(&self.pool) - .await - .map(OrgId)?) - } + let query = " + UPDATE contacts + SET should_notify = FALSE + WHERE + user_id_a = $1 AND user_id_b = $2 AND + ( + (a_to_b = $3 AND accepted) OR + (a_to_b != $3 AND NOT accepted) + ); + "; - #[cfg(any(test, feature = "seed-support"))] - async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> { - let query = " - INSERT INTO org_memberships (org_id, user_id, admin) - VALUES ($1, $2, $3) - ON CONFLICT DO NOTHING - "; - Ok(sqlx::query(query) - .bind(org_id.0) - .bind(user_id.0) - .bind(is_admin) - .execute(&self.pool) - .await - .map(drop)?) - } + let result = sqlx::query(query) + .bind(id_a.0) + .bind(id_b.0) + .bind(a_to_b) + .execute(&self.pool) + .await?; - // channels - - #[cfg(any(test, feature = "seed-support"))] - async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result { - let query = " - INSERT INTO channels (owner_id, owner_is_user, name) - VALUES ($1, false, $2) - RETURNING id - "; - Ok(sqlx::query_scalar(query) - .bind(org_id.0) - .bind(name) - .fetch_one(&self.pool) - .await - .map(ChannelId)?) - } + if result.rows_affected() == 0 { + Err(anyhow!("no such contact request"))?; + } - #[allow(unused)] // Help rust-analyzer - #[cfg(any(test, feature = "seed-support"))] - async fn get_org_channels(&self, org_id: OrgId) -> Result> { - let query = " - SELECT * - FROM channels - WHERE - channels.owner_is_user = false AND - channels.owner_id = $1 - "; - Ok(sqlx::query_as(query) - .bind(org_id.0) - .fetch_all(&self.pool) - .await?) + Ok(()) + }) } - async fn get_accessible_channels(&self, user_id: UserId) -> Result> { - let query = " - SELECT - channels.* - FROM - channel_memberships, channels - WHERE - channel_memberships.user_id = $1 AND - channel_memberships.channel_id = channels.id - "; - Ok(sqlx::query_as(query) - .bind(user_id.0) - .fetch_all(&self.pool) - .await?) + pub async fn respond_to_contact_request( + &self, + responder_id: UserId, + requester_id: UserId, + accept: bool, + ) -> Result<()> { + test_support!(self, { + let (id_a, id_b, a_to_b) = if responder_id < requester_id { + (responder_id, requester_id, false) + } else { + (requester_id, responder_id, true) + }; + let result = if accept { + let query = " + UPDATE contacts + SET accepted = TRUE, should_notify = TRUE + WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3; + "; + sqlx::query(query) + .bind(id_a.0) + .bind(id_b.0) + .bind(a_to_b) + .execute(&self.pool) + .await? + } else { + let query = " + DELETE FROM contacts + WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted; + "; + sqlx::query(query) + .bind(id_a.0) + .bind(id_b.0) + .bind(a_to_b) + .execute(&self.pool) + .await? + }; + if result.rows_affected() == 1 { + Ok(()) + } else { + Err(anyhow!("no such contact request"))? + } + }) } - async fn can_user_access_channel( - &self, - user_id: UserId, - channel_id: ChannelId, - ) -> Result { - let query = " - SELECT id - FROM channel_memberships - WHERE user_id = $1 AND channel_id = $2 - LIMIT 1 - "; - Ok(sqlx::query_scalar::<_, i32>(query) - .bind(user_id.0) - .bind(channel_id.0) - .fetch_optional(&self.pool) - .await - .map(|e| e.is_some())?) - } + // access tokens - #[cfg(any(test, feature = "seed-support"))] - async fn add_channel_member( + pub async fn create_access_token_hash( &self, - channel_id: ChannelId, user_id: UserId, - is_admin: bool, + access_token_hash: &str, + max_access_token_count: usize, ) -> Result<()> { - let query = " - INSERT INTO channel_memberships (channel_id, user_id, admin) - VALUES ($1, $2, $3) - ON CONFLICT DO NOTHING - "; - Ok(sqlx::query(query) - .bind(channel_id.0) - .bind(user_id.0) - .bind(is_admin) - .execute(&self.pool) - .await - .map(drop)?) - } - - // messages + test_support!(self, { + let insert_query = " + INSERT INTO access_tokens (user_id, hash) + VALUES ($1, $2); + "; + let cleanup_query = " + DELETE FROM access_tokens + WHERE id IN ( + SELECT id from access_tokens + WHERE user_id = $1 + ORDER BY id DESC + LIMIT 10000 + OFFSET $3 + ) + "; - async fn create_channel_message( - &self, - channel_id: ChannelId, - sender_id: UserId, - body: &str, - timestamp: OffsetDateTime, - nonce: u128, - ) -> Result { - let query = " - INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce) - VALUES ($1, $2, $3, $4, $5) - ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce - RETURNING id - "; - Ok(sqlx::query_scalar(query) - .bind(channel_id.0) - .bind(sender_id.0) - .bind(body) - .bind(timestamp) - .bind(Uuid::from_u128(nonce)) - .fetch_one(&self.pool) - .await - .map(MessageId)?) + let mut tx = self.pool.begin().await?; + sqlx::query(insert_query) + .bind(user_id.0) + .bind(access_token_hash) + .execute(&mut tx) + .await?; + sqlx::query(cleanup_query) + .bind(user_id.0) + .bind(access_token_hash) + .bind(max_access_token_count as i32) + .execute(&mut tx) + .await?; + Ok(tx.commit().await?) + }) } - async fn get_channel_messages( - &self, - channel_id: ChannelId, - count: usize, - before_id: Option, - ) -> Result> { - let query = r#" - SELECT * FROM ( - SELECT - id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce - FROM - channel_messages - WHERE - channel_id = $1 AND - id < $2 + pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result> { + test_support!(self, { + let query = " + SELECT hash + FROM access_tokens + WHERE user_id = $1 ORDER BY id DESC - LIMIT $3 - ) as recent_messages - ORDER BY id ASC - "#; - Ok(sqlx::query_as(query) - .bind(channel_id.0) - .bind(before_id.unwrap_or(MessageId::MAX)) - .bind(count as i64) - .fetch_all(&self.pool) - .await?) - } - - #[cfg(test)] - async fn teardown(&self, url: &str) { - use util::ResultExt; - - let query = " - SELECT pg_terminate_backend(pg_stat_activity.pid) - FROM pg_stat_activity - WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid(); - "; - sqlx::query(query).execute(&self.pool).await.log_err(); - self.pool.close().await; - ::drop_database(url) - .await - .log_err(); - } - - #[cfg(test)] - fn as_fake(&self) -> Option<&FakeDb> { - None + "; + Ok(sqlx::query_scalar(query) + .bind(user_id.0) + .fetch_all(&self.pool) + .await?) + }) } } @@ -1671,58 +1254,6 @@ pub struct Project { pub unregistered: bool, } -#[derive(Clone, Debug, PartialEq, Serialize)] -pub struct UserActivitySummary { - pub id: UserId, - pub github_login: String, - pub project_activity: Vec, -} - -#[derive(Clone, Debug, PartialEq, Serialize)] -pub struct ProjectActivitySummary { - pub id: ProjectId, - pub duration: Duration, - pub max_collaborators: usize, -} - -#[derive(Clone, Debug, PartialEq, Serialize)] -pub struct UserActivityPeriod { - pub project_id: ProjectId, - #[serde(with = "time::serde::iso8601")] - pub start: OffsetDateTime, - #[serde(with = "time::serde::iso8601")] - pub end: OffsetDateTime, - pub extensions: HashMap, -} - -id_type!(OrgId); -#[derive(FromRow)] -pub struct Org { - pub id: OrgId, - pub name: String, - pub slug: String, -} - -id_type!(ChannelId); -#[derive(Clone, Debug, FromRow, Serialize)] -pub struct Channel { - pub id: ChannelId, - pub name: String, - pub owner_id: i32, - pub owner_is_user: bool, -} - -id_type!(MessageId); -#[derive(Clone, Debug, FromRow)] -pub struct ChannelMessage { - pub id: MessageId, - pub channel_id: ChannelId, - pub sender_id: UserId, - pub body: String, - pub sent_at: OffsetDateTime, - pub nonce: Uuid, -} - #[derive(Clone, Debug, PartialEq, Eq)] pub enum Contact { Accepted { @@ -1814,706 +1345,101 @@ pub use test::*; #[cfg(test)] mod test { use super::*; - use anyhow::anyhow; - use collections::BTreeMap; use gpui::executor::Background; use lazy_static::lazy_static; use parking_lot::Mutex; use rand::prelude::*; - use sqlx::{migrate::MigrateDatabase, Postgres}; + use sqlx::migrate::MigrateDatabase; use std::sync::Arc; - use util::post_inc; - - pub struct FakeDb { - background: Arc, - pub users: Mutex>, - pub projects: Mutex>, - pub worktree_extensions: Mutex>, - pub orgs: Mutex>, - pub org_memberships: Mutex>, - pub channels: Mutex>, - pub channel_memberships: Mutex>, - pub channel_messages: Mutex>, - pub contacts: Mutex>, - next_channel_message_id: Mutex, - next_user_id: Mutex, - next_org_id: Mutex, - next_channel_id: Mutex, - next_project_id: Mutex, - } - #[derive(Debug)] - pub struct FakeContact { - pub requester_id: UserId, - pub responder_id: UserId, - pub accepted: bool, - pub should_notify: bool, + pub struct SqliteTestDb { + pub db: Option>>, + pub conn: sqlx::sqlite::SqliteConnection, } - impl FakeDb { - pub fn new(background: Arc) -> Self { - Self { - background, - users: Default::default(), - next_user_id: Mutex::new(0), - projects: Default::default(), - worktree_extensions: Default::default(), - next_project_id: Mutex::new(1), - orgs: Default::default(), - next_org_id: Mutex::new(1), - org_memberships: Default::default(), - channels: Default::default(), - next_channel_id: Mutex::new(1), - channel_memberships: Default::default(), - channel_messages: Default::default(), - next_channel_message_id: Mutex::new(1), - contacts: Default::default(), - } - } + pub struct PostgresTestDb { + pub db: Option>>, + pub url: String, } - #[async_trait] - impl Db for FakeDb { - async fn create_user( - &self, - email_address: &str, - admin: bool, - params: NewUserParams, - ) -> Result { - self.background.simulate_random_delay().await; - - let mut users = self.users.lock(); - let user_id = if let Some(user) = users - .values() - .find(|user| user.github_login == params.github_login) - { - user.id - } else { - let id = post_inc(&mut *self.next_user_id.lock()); - let user_id = UserId(id); - users.insert( - user_id, - User { - id: user_id, - github_login: params.github_login, - github_user_id: Some(params.github_user_id), - email_address: Some(email_address.to_string()), - admin, - invite_code: None, - invite_count: 0, - connected_once: false, - }, - ); - user_id - }; - Ok(NewUserResult { - user_id, - metrics_id: "the-metrics-id".to_string(), - inviting_user_id: None, - signup_device_id: None, - }) - } - - async fn get_all_users(&self, _page: u32, _limit: u32) -> Result> { - unimplemented!() - } - - async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result> { - unimplemented!() - } - - async fn get_user_by_id(&self, id: UserId) -> Result> { - self.background.simulate_random_delay().await; - Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next()) - } - - async fn get_user_metrics_id(&self, _id: UserId) -> Result { - Ok("the-metrics-id".to_string()) - } - - async fn get_users_by_ids(&self, ids: Vec) -> Result> { - self.background.simulate_random_delay().await; - let users = self.users.lock(); - Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect()) - } - - async fn get_users_with_no_invites(&self, _: bool) -> Result> { - unimplemented!() - } - - async fn get_user_by_github_account( - &self, - github_login: &str, - github_user_id: Option, - ) -> Result> { - self.background.simulate_random_delay().await; - if let Some(github_user_id) = github_user_id { - for user in self.users.lock().values_mut() { - if user.github_user_id == Some(github_user_id) { - user.github_login = github_login.into(); - return Ok(Some(user.clone())); - } - if user.github_login == github_login { - user.github_user_id = Some(github_user_id); - return Ok(Some(user.clone())); - } - } - Ok(None) - } else { - Ok(self - .users - .lock() - .values() - .find(|user| user.github_login == github_login) - .cloned()) - } - } - - async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> { - unimplemented!() - } - - async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> { - self.background.simulate_random_delay().await; - let mut users = self.users.lock(); - let mut user = users - .get_mut(&id) - .ok_or_else(|| anyhow!("user not found"))?; - user.connected_once = connected_once; - Ok(()) - } - - async fn destroy_user(&self, _id: UserId) -> Result<()> { - unimplemented!() - } - - // signups - - async fn create_signup(&self, _signup: Signup) -> Result<()> { - unimplemented!() - } - - async fn get_waitlist_summary(&self) -> Result { - unimplemented!() - } - - async fn get_unsent_invites(&self, _count: usize) -> Result> { - unimplemented!() - } - - async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> { - unimplemented!() - } - - async fn create_user_from_invite( - &self, - _invite: &Invite, - _user: NewUserParams, - ) -> Result> { - unimplemented!() - } - - // invite codes - - async fn set_invite_count_for_user(&self, _id: UserId, _count: u32) -> Result<()> { - unimplemented!() - } - - async fn get_invite_code_for_user(&self, _id: UserId) -> Result> { - self.background.simulate_random_delay().await; - Ok(None) - } - - async fn get_user_for_invite_code(&self, _code: &str) -> Result { - unimplemented!() - } - - async fn create_invite_from_code( - &self, - _code: &str, - _email_address: &str, - _device_id: Option<&str>, - ) -> Result { - unimplemented!() - } - - // projects - - async fn register_project(&self, host_user_id: UserId) -> Result { - self.background.simulate_random_delay().await; - if !self.users.lock().contains_key(&host_user_id) { - Err(anyhow!("no such user"))?; - } - - let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock())); - self.projects.lock().insert( - project_id, - Project { - id: project_id, - host_user_id, - unregistered: false, - }, - ); - Ok(project_id) - } - - async fn unregister_project(&self, project_id: ProjectId) -> Result<()> { - self.background.simulate_random_delay().await; - self.projects - .lock() - .get_mut(&project_id) - .ok_or_else(|| anyhow!("no such project"))? - .unregistered = true; - Ok(()) - } - - async fn update_worktree_extensions( - &self, - project_id: ProjectId, - worktree_id: u64, - extensions: HashMap, - ) -> Result<()> { - self.background.simulate_random_delay().await; - if !self.projects.lock().contains_key(&project_id) { - Err(anyhow!("no such project"))?; - } - - for (extension, count) in extensions { - self.worktree_extensions - .lock() - .insert((project_id, worktree_id, extension), count); - } - - Ok(()) - } - - async fn get_project_extensions( - &self, - _project_id: ProjectId, - ) -> Result>> { - unimplemented!() - } - - async fn record_user_activity( - &self, - _time_period: Range, - _active_projects: &[(UserId, ProjectId)], - ) -> Result<()> { - unimplemented!() - } - - async fn get_active_user_count( - &self, - _time_period: Range, - _min_duration: Duration, - _only_collaborative: bool, - ) -> Result { - unimplemented!() - } - - async fn get_top_users_activity_summary( - &self, - _time_period: Range, - _limit: usize, - ) -> Result> { - unimplemented!() - } - - async fn get_user_activity_timeline( - &self, - _time_period: Range, - _user_id: UserId, - ) -> Result> { - unimplemented!() - } - - // contacts - - async fn get_contacts(&self, id: UserId) -> Result> { - self.background.simulate_random_delay().await; - let mut contacts = Vec::new(); - - for contact in self.contacts.lock().iter() { - if contact.requester_id == id { - if contact.accepted { - contacts.push(Contact::Accepted { - user_id: contact.responder_id, - should_notify: contact.should_notify, - }); - } else { - contacts.push(Contact::Outgoing { - user_id: contact.responder_id, - }); - } - } else if contact.responder_id == id { - if contact.accepted { - contacts.push(Contact::Accepted { - user_id: contact.requester_id, - should_notify: false, - }); - } else { - contacts.push(Contact::Incoming { - user_id: contact.requester_id, - should_notify: contact.should_notify, - }); - } - } - } - - contacts.sort_unstable_by_key(|contact| contact.user_id()); - Ok(contacts) - } - - async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result { - self.background.simulate_random_delay().await; - Ok(self.contacts.lock().iter().any(|contact| { - contact.accepted - && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b) - || (contact.requester_id == user_id_b && contact.responder_id == user_id_a)) - })) - } - - async fn send_contact_request( - &self, - requester_id: UserId, - responder_id: UserId, - ) -> Result<()> { - self.background.simulate_random_delay().await; - let mut contacts = self.contacts.lock(); - for contact in contacts.iter_mut() { - if contact.requester_id == requester_id && contact.responder_id == responder_id { - if contact.accepted { - Err(anyhow!("contact already exists"))?; - } else { - Err(anyhow!("contact already requested"))?; - } - } - if contact.responder_id == requester_id && contact.requester_id == responder_id { - if contact.accepted { - Err(anyhow!("contact already exists"))?; - } else { - contact.accepted = true; - contact.should_notify = false; - return Ok(()); - } - } - } - contacts.push(FakeContact { - requester_id, - responder_id, - accepted: false, - should_notify: true, - }); - Ok(()) - } + impl SqliteTestDb { + pub fn new(background: Arc) -> Self { + let mut rng = StdRng::from_entropy(); + let url = format!("file:zed-test-{}?mode=memory", rng.gen::()); + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .unwrap(); - async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> { - self.background.simulate_random_delay().await; - self.contacts.lock().retain(|contact| { - !(contact.requester_id == requester_id && contact.responder_id == responder_id) + let (mut db, conn) = runtime.block_on(async { + let db = Db::::new(&url, 5).await.unwrap(); + let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite"); + db.migrate(migrations_path.as_ref(), false).await.unwrap(); + let conn = db.pool.acquire().await.unwrap().detach(); + (db, conn) }); - Ok(()) - } - - async fn dismiss_contact_notification( - &self, - user_id: UserId, - contact_user_id: UserId, - ) -> Result<()> { - self.background.simulate_random_delay().await; - let mut contacts = self.contacts.lock(); - for contact in contacts.iter_mut() { - if contact.requester_id == contact_user_id - && contact.responder_id == user_id - && !contact.accepted - { - contact.should_notify = false; - return Ok(()); - } - if contact.requester_id == user_id - && contact.responder_id == contact_user_id - && contact.accepted - { - contact.should_notify = false; - return Ok(()); - } - } - Err(anyhow!("no such notification"))? - } - - async fn respond_to_contact_request( - &self, - responder_id: UserId, - requester_id: UserId, - accept: bool, - ) -> Result<()> { - self.background.simulate_random_delay().await; - let mut contacts = self.contacts.lock(); - for (ix, contact) in contacts.iter_mut().enumerate() { - if contact.requester_id == requester_id && contact.responder_id == responder_id { - if contact.accepted { - Err(anyhow!("contact already confirmed"))?; - } - if accept { - contact.accepted = true; - contact.should_notify = true; - } else { - contacts.remove(ix); - } - return Ok(()); - } - } - Err(anyhow!("no such contact request"))? - } - - async fn create_access_token_hash( - &self, - _user_id: UserId, - _access_token_hash: &str, - _max_access_token_count: usize, - ) -> Result<()> { - unimplemented!() - } - - async fn get_access_token_hashes(&self, _user_id: UserId) -> Result> { - unimplemented!() - } - - async fn find_org_by_slug(&self, _slug: &str) -> Result> { - unimplemented!() - } - - async fn create_org(&self, name: &str, slug: &str) -> Result { - self.background.simulate_random_delay().await; - let mut orgs = self.orgs.lock(); - if orgs.values().any(|org| org.slug == slug) { - Err(anyhow!("org already exists"))? - } else { - let org_id = OrgId(post_inc(&mut *self.next_org_id.lock())); - orgs.insert( - org_id, - Org { - id: org_id, - name: name.to_string(), - slug: slug.to_string(), - }, - ); - Ok(org_id) - } - } - - async fn add_org_member( - &self, - org_id: OrgId, - user_id: UserId, - is_admin: bool, - ) -> Result<()> { - self.background.simulate_random_delay().await; - if !self.orgs.lock().contains_key(&org_id) { - Err(anyhow!("org does not exist"))?; - } - if !self.users.lock().contains_key(&user_id) { - Err(anyhow!("user does not exist"))?; - } - - self.org_memberships - .lock() - .entry((org_id, user_id)) - .or_insert(is_admin); - Ok(()) - } - - async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result { - self.background.simulate_random_delay().await; - if !self.orgs.lock().contains_key(&org_id) { - Err(anyhow!("org does not exist"))?; - } - - let mut channels = self.channels.lock(); - let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock())); - channels.insert( - channel_id, - Channel { - id: channel_id, - name: name.to_string(), - owner_id: org_id.0, - owner_is_user: false, - }, - ); - Ok(channel_id) - } - - async fn get_org_channels(&self, org_id: OrgId) -> Result> { - self.background.simulate_random_delay().await; - Ok(self - .channels - .lock() - .values() - .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0) - .cloned() - .collect()) - } - - async fn get_accessible_channels(&self, user_id: UserId) -> Result> { - self.background.simulate_random_delay().await; - let channels = self.channels.lock(); - let memberships = self.channel_memberships.lock(); - Ok(channels - .values() - .filter(|channel| memberships.contains_key(&(channel.id, user_id))) - .cloned() - .collect()) - } - - async fn can_user_access_channel( - &self, - user_id: UserId, - channel_id: ChannelId, - ) -> Result { - self.background.simulate_random_delay().await; - Ok(self - .channel_memberships - .lock() - .contains_key(&(channel_id, user_id))) - } - async fn add_channel_member( - &self, - channel_id: ChannelId, - user_id: UserId, - is_admin: bool, - ) -> Result<()> { - self.background.simulate_random_delay().await; - if !self.channels.lock().contains_key(&channel_id) { - Err(anyhow!("channel does not exist"))?; - } - if !self.users.lock().contains_key(&user_id) { - Err(anyhow!("user does not exist"))?; - } - - self.channel_memberships - .lock() - .entry((channel_id, user_id)) - .or_insert(is_admin); - Ok(()) - } - - async fn create_channel_message( - &self, - channel_id: ChannelId, - sender_id: UserId, - body: &str, - timestamp: OffsetDateTime, - nonce: u128, - ) -> Result { - self.background.simulate_random_delay().await; - if !self.channels.lock().contains_key(&channel_id) { - Err(anyhow!("channel does not exist"))?; - } - if !self.users.lock().contains_key(&sender_id) { - Err(anyhow!("user does not exist"))?; - } + db.background = Some(background); + db.runtime = Some(runtime); - let mut messages = self.channel_messages.lock(); - if let Some(message) = messages - .values() - .find(|message| message.nonce.as_u128() == nonce) - { - Ok(message.id) - } else { - let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock())); - messages.insert( - message_id, - ChannelMessage { - id: message_id, - channel_id, - sender_id, - body: body.to_string(), - sent_at: timestamp, - nonce: Uuid::from_u128(nonce), - }, - ); - Ok(message_id) + Self { + db: Some(Arc::new(db)), + conn, } } - async fn get_channel_messages( - &self, - channel_id: ChannelId, - count: usize, - before_id: Option, - ) -> Result> { - self.background.simulate_random_delay().await; - let mut messages = self - .channel_messages - .lock() - .values() - .rev() - .filter(|message| { - message.channel_id == channel_id - && message.id < before_id.unwrap_or(MessageId::MAX) - }) - .take(count) - .cloned() - .collect::>(); - messages.sort_unstable_by_key(|message| message.id); - Ok(messages) - } - - async fn teardown(&self, _: &str) {} - - #[cfg(test)] - fn as_fake(&self) -> Option<&FakeDb> { - Some(self) + pub fn db(&self) -> &Arc> { + self.db.as_ref().unwrap() } } - pub struct TestDb { - pub db: Option>, - pub url: String, - } - - impl TestDb { - #[allow(clippy::await_holding_lock)] - pub async fn postgres() -> Self { + impl PostgresTestDb { + pub fn new(background: Arc) -> Self { lazy_static! { static ref LOCK: Mutex<()> = Mutex::new(()); } let _guard = LOCK.lock(); let mut rng = StdRng::from_entropy(); - let name = format!("zed-test-{}", rng.gen::()); - let url = format!("postgres://postgres@localhost/{}", name); - Postgres::create_database(&url) - .await - .expect("failed to create test db"); - let db = PostgresDb::new(&url, 5).await.unwrap(); - db.migrate(Path::new(DEFAULT_MIGRATIONS_PATH.unwrap()), false) - .await + let url = format!( + "postgres://postgres@localhost/zed-test-{}", + rng.gen::() + ); + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() .unwrap(); + + let mut db = runtime.block_on(async { + sqlx::Postgres::create_database(&url) + .await + .expect("failed to create test db"); + let db = Db::::new(&url, 5).await.unwrap(); + let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"); + db.migrate(Path::new(migrations_path), false).await.unwrap(); + db + }); + + db.background = Some(background); + db.runtime = Some(runtime); + Self { db: Some(Arc::new(db)), url, } } - pub fn fake(background: Arc) -> Self { - Self { - db: Some(Arc::new(FakeDb::new(background))), - url: Default::default(), - } - } - - pub fn db(&self) -> &Arc { + pub fn db(&self) -> &Arc> { self.db.as_ref().unwrap() } } - impl Drop for TestDb { + impl Drop for PostgresTestDb { fn drop(&mut self) { - if let Some(db) = self.db.take() { - futures::executor::block_on(db.teardown(&self.url)); - } + let db = self.db.take().unwrap(); + db.teardown(&self.url); } } } diff --git a/crates/collab/src/db_tests.rs b/crates/collab/src/db_tests.rs index ff5e05dd5dd286c28974e7508197c426ab37fbc9..8eda7d34e298c975e53140c9ce3a7aed1551b706 100644 --- a/crates/collab/src/db_tests.rs +++ b/crates/collab/src/db_tests.rs @@ -1,17 +1,30 @@ use super::db::*; -use collections::HashMap; use gpui::executor::{Background, Deterministic}; -use std::{sync::Arc, time::Duration}; -use time::OffsetDateTime; +use std::sync::Arc; + +macro_rules! test_both_dbs { + ($postgres_test_name:ident, $sqlite_test_name:ident, $db:ident, $body:block) => { + #[gpui::test] + async fn $postgres_test_name() { + let test_db = PostgresTestDb::new(Deterministic::new(0).build_background()); + let $db = test_db.db(); + $body + } -#[tokio::test(flavor = "multi_thread")] -async fn test_get_users_by_ids() { - for test_db in [ - TestDb::postgres().await, - TestDb::fake(build_background_executor()), - ] { - let db = test_db.db(); + #[gpui::test] + async fn $sqlite_test_name() { + let test_db = SqliteTestDb::new(Deterministic::new(0).build_background()); + let $db = test_db.db(); + $body + } + }; +} +test_both_dbs!( + test_get_users_by_ids_postgres, + test_get_users_by_ids_sqlite, + db, + { let mut user_ids = Vec::new(); for i in 1..=4 { user_ids.push( @@ -68,15 +81,13 @@ async fn test_get_users_by_ids() { ] ); } -} +); -#[tokio::test(flavor = "multi_thread")] -async fn test_get_user_by_github_account() { - for test_db in [ - TestDb::postgres().await, - TestDb::fake(build_background_executor()), - ] { - let db = test_db.db(); +test_both_dbs!( + test_get_user_by_github_account_postgres, + test_get_user_by_github_account_sqlite, + db, + { let user_id1 = db .create_user( "user1@example.com", @@ -128,87 +139,57 @@ async fn test_get_user_by_github_account() { assert_eq!(&user.github_login, "the-new-login2"); assert_eq!(user.github_user_id, Some(102)); } -} +); -#[tokio::test(flavor = "multi_thread")] -async fn test_worktree_extensions() { - let test_db = TestDb::postgres().await; - let db = test_db.db(); +test_both_dbs!( + test_create_access_tokens_postgres, + test_create_access_tokens_sqlite, + db, + { + let user = db + .create_user( + "u1@example.com", + false, + NewUserParams { + github_login: "u1".into(), + github_user_id: 1, + invite_count: 0, + }, + ) + .await + .unwrap() + .user_id; - let user = db - .create_user( - "u1@example.com", - false, - NewUserParams { - github_login: "u1".into(), - github_user_id: 0, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - let project = db.register_project(user).await.unwrap(); + db.create_access_token_hash(user, "h1", 3).await.unwrap(); + db.create_access_token_hash(user, "h2", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h2".to_string(), "h1".to_string()] + ); - db.update_worktree_extensions(project, 100, Default::default()) - .await - .unwrap(); - db.update_worktree_extensions( - project, - 100, - [("rs".to_string(), 5), ("md".to_string(), 3)] - .into_iter() - .collect(), - ) - .await - .unwrap(); - db.update_worktree_extensions( - project, - 100, - [("rs".to_string(), 6), ("md".to_string(), 5)] - .into_iter() - .collect(), - ) - .await - .unwrap(); - db.update_worktree_extensions( - project, - 101, - [("ts".to_string(), 2), ("md".to_string(), 1)] - .into_iter() - .collect(), - ) - .await - .unwrap(); + db.create_access_token_hash(user, "h3", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h3".to_string(), "h2".to_string(), "h1".to_string(),] + ); - assert_eq!( - db.get_project_extensions(project).await.unwrap(), - [ - ( - 100, - [("rs".into(), 6), ("md".into(), 5),] - .into_iter() - .collect::>() - ), - ( - 101, - [("ts".into(), 2), ("md".into(), 1),] - .into_iter() - .collect::>() - ) - ] - .into_iter() - .collect() - ); -} + db.create_access_token_hash(user, "h4", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h4".to_string(), "h3".to_string(), "h2".to_string(),] + ); -#[tokio::test(flavor = "multi_thread")] -async fn test_user_activity() { - let test_db = TestDb::postgres().await; - let db = test_db.db(); + db.create_access_token_hash(user, "h5", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h5".to_string(), "h4".to_string(), "h3".to_string()] + ); + } +); +test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, { let mut user_ids = Vec::new(); - for i in 0..=2 { + for i in 0..3 { user_ids.push( db.create_user( &format!("user{i}@example.com"), @@ -225,371 +206,198 @@ async fn test_user_activity() { ); } - let project_1 = db.register_project(user_ids[0]).await.unwrap(); - db.update_worktree_extensions( - project_1, - 1, - HashMap::from_iter([("rs".into(), 5), ("md".into(), 7)]), - ) - .await - .unwrap(); - let project_2 = db.register_project(user_ids[1]).await.unwrap(); - let t0 = OffsetDateTime::now_utc() - Duration::from_secs(60 * 60); - - // User 2 opens a project - let t1 = t0 + Duration::from_secs(10); - db.record_user_activity(t0..t1, &[(user_ids[1], project_2)]) - .await - .unwrap(); - - let t2 = t1 + Duration::from_secs(10); - db.record_user_activity(t1..t2, &[(user_ids[1], project_2)]) - .await - .unwrap(); - - // User 1 joins the project - let t3 = t2 + Duration::from_secs(10); - db.record_user_activity( - t2..t3, - &[(user_ids[1], project_2), (user_ids[0], project_2)], - ) - .await - .unwrap(); + let user_1 = user_ids[0]; + let user_2 = user_ids[1]; + let user_3 = user_ids[2]; - // User 1 opens another project - let t4 = t3 + Duration::from_secs(10); - db.record_user_activity( - t3..t4, - &[ - (user_ids[1], project_2), - (user_ids[0], project_2), - (user_ids[0], project_1), - ], - ) - .await - .unwrap(); + // User starts with no contacts + assert_eq!(db.get_contacts(user_1).await.unwrap(), &[]); - // User 3 joins that project - let t5 = t4 + Duration::from_secs(10); - db.record_user_activity( - t4..t5, - &[ - (user_ids[1], project_2), - (user_ids[0], project_2), - (user_ids[0], project_1), - (user_ids[2], project_1), - ], - ) - .await - .unwrap(); - - // User 2 leaves - let t6 = t5 + Duration::from_secs(5); - db.record_user_activity( - t5..t6, - &[(user_ids[0], project_1), (user_ids[2], project_1)], - ) - .await - .unwrap(); - - let t7 = t6 + Duration::from_secs(60); - let t8 = t7 + Duration::from_secs(10); - db.record_user_activity(t7..t8, &[(user_ids[0], project_1)]) - .await - .unwrap(); - - assert_eq!( - db.get_top_users_activity_summary(t0..t6, 10).await.unwrap(), - &[ - UserActivitySummary { - id: user_ids[0], - github_login: "user0".to_string(), - project_activity: vec![ - ProjectActivitySummary { - id: project_1, - duration: Duration::from_secs(25), - max_collaborators: 2 - }, - ProjectActivitySummary { - id: project_2, - duration: Duration::from_secs(30), - max_collaborators: 2 - } - ] - }, - UserActivitySummary { - id: user_ids[1], - github_login: "user1".to_string(), - project_activity: vec![ProjectActivitySummary { - id: project_2, - duration: Duration::from_secs(50), - max_collaborators: 2 - }] - }, - UserActivitySummary { - id: user_ids[2], - github_login: "user2".to_string(), - project_activity: vec![ProjectActivitySummary { - id: project_1, - duration: Duration::from_secs(15), - max_collaborators: 2 - }] - }, - ] - ); - - assert_eq!( - db.get_active_user_count(t0..t6, Duration::from_secs(56), false) - .await - .unwrap(), - 0 - ); - assert_eq!( - db.get_active_user_count(t0..t6, Duration::from_secs(56), true) - .await - .unwrap(), - 0 - ); - assert_eq!( - db.get_active_user_count(t0..t6, Duration::from_secs(54), false) - .await - .unwrap(), - 1 - ); + // User requests a contact. Both users see the pending request. + db.send_contact_request(user_1, user_2).await.unwrap(); + assert!(!db.has_contact(user_1, user_2).await.unwrap()); + assert!(!db.has_contact(user_2, user_1).await.unwrap()); assert_eq!( - db.get_active_user_count(t0..t6, Duration::from_secs(54), true) - .await - .unwrap(), - 1 + db.get_contacts(user_1).await.unwrap(), + &[Contact::Outgoing { user_id: user_2 }], ); assert_eq!( - db.get_active_user_count(t0..t6, Duration::from_secs(30), false) - .await - .unwrap(), - 2 + db.get_contacts(user_2).await.unwrap(), + &[Contact::Incoming { + user_id: user_1, + should_notify: true + }] ); + + // User 2 dismisses the contact request notification without accepting or rejecting. + // We shouldn't notify them again. + db.dismiss_contact_notification(user_1, user_2) + .await + .unwrap_err(); + db.dismiss_contact_notification(user_2, user_1) + .await + .unwrap(); assert_eq!( - db.get_active_user_count(t0..t6, Duration::from_secs(30), true) - .await - .unwrap(), - 2 + db.get_contacts(user_2).await.unwrap(), + &[Contact::Incoming { + user_id: user_1, + should_notify: false + }] ); + + // User can't accept their own contact request + db.respond_to_contact_request(user_1, user_2, true) + .await + .unwrap_err(); + + // User accepts a contact request. Both users see the contact. + db.respond_to_contact_request(user_2, user_1, true) + .await + .unwrap(); assert_eq!( - db.get_active_user_count(t0..t6, Duration::from_secs(10), false) - .await - .unwrap(), - 3 + db.get_contacts(user_1).await.unwrap(), + &[Contact::Accepted { + user_id: user_2, + should_notify: true + }], ); + assert!(db.has_contact(user_1, user_2).await.unwrap()); + assert!(db.has_contact(user_2, user_1).await.unwrap()); assert_eq!( - db.get_active_user_count(t0..t6, Duration::from_secs(10), true) - .await - .unwrap(), - 3 + db.get_contacts(user_2).await.unwrap(), + &[Contact::Accepted { + user_id: user_1, + should_notify: false, + }] ); + + // Users cannot re-request existing contacts. + db.send_contact_request(user_1, user_2).await.unwrap_err(); + db.send_contact_request(user_2, user_1).await.unwrap_err(); + + // Users can't dismiss notifications of them accepting other users' requests. + db.dismiss_contact_notification(user_2, user_1) + .await + .unwrap_err(); assert_eq!( - db.get_active_user_count(t0..t1, Duration::from_secs(5), false) - .await - .unwrap(), - 1 + db.get_contacts(user_1).await.unwrap(), + &[Contact::Accepted { + user_id: user_2, + should_notify: true, + }] ); + + // Users can dismiss notifications of other users accepting their requests. + db.dismiss_contact_notification(user_1, user_2) + .await + .unwrap(); assert_eq!( - db.get_active_user_count(t0..t1, Duration::from_secs(5), true) - .await - .unwrap(), - 0 + db.get_contacts(user_1).await.unwrap(), + &[Contact::Accepted { + user_id: user_2, + should_notify: false, + }] ); + // Users send each other concurrent contact requests and + // see that they are immediately accepted. + db.send_contact_request(user_1, user_3).await.unwrap(); + db.send_contact_request(user_3, user_1).await.unwrap(); assert_eq!( - db.get_user_activity_timeline(t3..t6, user_ids[0]) - .await - .unwrap(), + db.get_contacts(user_1).await.unwrap(), &[ - UserActivityPeriod { - project_id: project_1, - start: t3, - end: t6, - extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]), - }, - UserActivityPeriod { - project_id: project_2, - start: t3, - end: t5, - extensions: Default::default(), + Contact::Accepted { + user_id: user_2, + should_notify: false, }, + Contact::Accepted { + user_id: user_3, + should_notify: false + } ] ); assert_eq!( - db.get_user_activity_timeline(t0..t8, user_ids[0]) - .await - .unwrap(), - &[ - UserActivityPeriod { - project_id: project_2, - start: t2, - end: t5, - extensions: Default::default(), - }, - UserActivityPeriod { - project_id: project_1, - start: t3, - end: t6, - extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]), - }, - UserActivityPeriod { - project_id: project_1, - start: t7, - end: t8, - extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]), - }, - ] + db.get_contacts(user_3).await.unwrap(), + &[Contact::Accepted { + user_id: user_1, + should_notify: false + }], ); -} - -#[tokio::test(flavor = "multi_thread")] -async fn test_recent_channel_messages() { - for test_db in [ - TestDb::postgres().await, - TestDb::fake(build_background_executor()), - ] { - let db = test_db.db(); - let user = db - .create_user( - "u@example.com", - false, - NewUserParams { - github_login: "u".into(), - github_user_id: 1, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - let org = db.create_org("org", "org").await.unwrap(); - let channel = db.create_org_channel(org, "channel").await.unwrap(); - for i in 0..10 { - db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i) - .await - .unwrap(); - } - - let messages = db.get_channel_messages(channel, 5, None).await.unwrap(); - assert_eq!( - messages.iter().map(|m| &m.body).collect::>(), - ["5", "6", "7", "8", "9"] - ); - - let prev_messages = db - .get_channel_messages(channel, 4, Some(messages[0].id)) - .await - .unwrap(); - assert_eq!( - prev_messages.iter().map(|m| &m.body).collect::>(), - ["1", "2", "3", "4"] - ); - } -} - -#[tokio::test(flavor = "multi_thread")] -async fn test_channel_message_nonces() { - for test_db in [ - TestDb::postgres().await, - TestDb::fake(build_background_executor()), - ] { - let db = test_db.db(); - let user = db - .create_user( - "user@example.com", - false, - NewUserParams { - github_login: "user".into(), - github_user_id: 1, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - let org = db.create_org("org", "org").await.unwrap(); - let channel = db.create_org_channel(org, "channel").await.unwrap(); - - let msg1_id = db - .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1) - .await - .unwrap(); - let msg2_id = db - .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2) - .await - .unwrap(); - let msg3_id = db - .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1) - .await - .unwrap(); - let msg4_id = db - .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2) - .await - .unwrap(); - assert_ne!(msg1_id, msg2_id); - assert_eq!(msg1_id, msg3_id); - assert_eq!(msg2_id, msg4_id); - } -} + // User declines a contact request. Both users see that it is gone. + db.send_contact_request(user_2, user_3).await.unwrap(); + db.respond_to_contact_request(user_3, user_2, false) + .await + .unwrap(); + assert!(!db.has_contact(user_2, user_3).await.unwrap()); + assert!(!db.has_contact(user_3, user_2).await.unwrap()); + assert_eq!( + db.get_contacts(user_2).await.unwrap(), + &[Contact::Accepted { + user_id: user_1, + should_notify: false + }] + ); + assert_eq!( + db.get_contacts(user_3).await.unwrap(), + &[Contact::Accepted { + user_id: user_1, + should_notify: false + }], + ); +}); -#[tokio::test(flavor = "multi_thread")] -async fn test_create_access_tokens() { - let test_db = TestDb::postgres().await; - let db = test_db.db(); - let user = db +test_both_dbs!(test_metrics_id_postgres, test_metrics_id_sqlite, db, { + let NewUserResult { + user_id: user1, + metrics_id: metrics_id1, + .. + } = db .create_user( - "u1@example.com", + "person1@example.com", false, NewUserParams { - github_login: "u1".into(), - github_user_id: 1, - invite_count: 0, + github_login: "person1".into(), + github_user_id: 101, + invite_count: 5, }, ) .await - .unwrap() - .user_id; - - db.create_access_token_hash(user, "h1", 3).await.unwrap(); - db.create_access_token_hash(user, "h2", 3).await.unwrap(); - assert_eq!( - db.get_access_token_hashes(user).await.unwrap(), - &["h2".to_string(), "h1".to_string()] - ); - - db.create_access_token_hash(user, "h3", 3).await.unwrap(); - assert_eq!( - db.get_access_token_hashes(user).await.unwrap(), - &["h3".to_string(), "h2".to_string(), "h1".to_string(),] - ); - - db.create_access_token_hash(user, "h4", 3).await.unwrap(); - assert_eq!( - db.get_access_token_hashes(user).await.unwrap(), - &["h4".to_string(), "h3".to_string(), "h2".to_string(),] - ); + .unwrap(); + let NewUserResult { + user_id: user2, + metrics_id: metrics_id2, + .. + } = db + .create_user( + "person2@example.com", + false, + NewUserParams { + github_login: "person2".into(), + github_user_id: 102, + invite_count: 5, + }, + ) + .await + .unwrap(); - db.create_access_token_hash(user, "h5", 3).await.unwrap(); - assert_eq!( - db.get_access_token_hashes(user).await.unwrap(), - &["h5".to_string(), "h4".to_string(), "h3".to_string()] - ); -} + assert_eq!(db.get_user_metrics_id(user1).await.unwrap(), metrics_id1); + assert_eq!(db.get_user_metrics_id(user2).await.unwrap(), metrics_id2); + assert_eq!(metrics_id1.len(), 36); + assert_eq!(metrics_id2.len(), 36); + assert_ne!(metrics_id1, metrics_id2); +}); #[test] fn test_fuzzy_like_string() { - assert_eq!(PostgresDb::fuzzy_like_string("abcd"), "%a%b%c%d%"); - assert_eq!(PostgresDb::fuzzy_like_string("x y"), "%x%y%"); - assert_eq!(PostgresDb::fuzzy_like_string(" z "), "%z%"); + assert_eq!(DefaultDb::fuzzy_like_string("abcd"), "%a%b%c%d%"); + assert_eq!(DefaultDb::fuzzy_like_string("x y"), "%x%y%"); + assert_eq!(DefaultDb::fuzzy_like_string(" z "), "%z%"); } -#[tokio::test(flavor = "multi_thread")] +#[gpui::test] async fn test_fuzzy_search_users() { - let test_db = TestDb::postgres().await; + let test_db = PostgresTestDb::new(build_background_executor()); let db = test_db.db(); for (i, github_login) in [ "California", @@ -625,7 +433,7 @@ async fn test_fuzzy_search_users() { &["rhode-island", "colorado", "oregon"], ); - async fn fuzzy_search_user_names(db: &Arc, query: &str) -> Vec { + async fn fuzzy_search_user_names(db: &Db, query: &str) -> Vec { db.fuzzy_search_users(query, 10) .await .unwrap() @@ -635,178 +443,11 @@ async fn test_fuzzy_search_users() { } } -#[tokio::test(flavor = "multi_thread")] -async fn test_add_contacts() { - for test_db in [ - TestDb::postgres().await, - TestDb::fake(build_background_executor()), - ] { - let db = test_db.db(); - - let mut user_ids = Vec::new(); - for i in 0..3 { - user_ids.push( - db.create_user( - &format!("user{i}@example.com"), - false, - NewUserParams { - github_login: format!("user{i}"), - github_user_id: i, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id, - ); - } - - let user_1 = user_ids[0]; - let user_2 = user_ids[1]; - let user_3 = user_ids[2]; - - // User starts with no contacts - assert_eq!(db.get_contacts(user_1).await.unwrap(), &[]); - - // User requests a contact. Both users see the pending request. - db.send_contact_request(user_1, user_2).await.unwrap(); - assert!(!db.has_contact(user_1, user_2).await.unwrap()); - assert!(!db.has_contact(user_2, user_1).await.unwrap()); - assert_eq!( - db.get_contacts(user_1).await.unwrap(), - &[Contact::Outgoing { user_id: user_2 }], - ); - assert_eq!( - db.get_contacts(user_2).await.unwrap(), - &[Contact::Incoming { - user_id: user_1, - should_notify: true - }] - ); - - // User 2 dismisses the contact request notification without accepting or rejecting. - // We shouldn't notify them again. - db.dismiss_contact_notification(user_1, user_2) - .await - .unwrap_err(); - db.dismiss_contact_notification(user_2, user_1) - .await - .unwrap(); - assert_eq!( - db.get_contacts(user_2).await.unwrap(), - &[Contact::Incoming { - user_id: user_1, - should_notify: false - }] - ); - - // User can't accept their own contact request - db.respond_to_contact_request(user_1, user_2, true) - .await - .unwrap_err(); - - // User accepts a contact request. Both users see the contact. - db.respond_to_contact_request(user_2, user_1, true) - .await - .unwrap(); - assert_eq!( - db.get_contacts(user_1).await.unwrap(), - &[Contact::Accepted { - user_id: user_2, - should_notify: true - }], - ); - assert!(db.has_contact(user_1, user_2).await.unwrap()); - assert!(db.has_contact(user_2, user_1).await.unwrap()); - assert_eq!( - db.get_contacts(user_2).await.unwrap(), - &[Contact::Accepted { - user_id: user_1, - should_notify: false, - }] - ); - - // Users cannot re-request existing contacts. - db.send_contact_request(user_1, user_2).await.unwrap_err(); - db.send_contact_request(user_2, user_1).await.unwrap_err(); - - // Users can't dismiss notifications of them accepting other users' requests. - db.dismiss_contact_notification(user_2, user_1) - .await - .unwrap_err(); - assert_eq!( - db.get_contacts(user_1).await.unwrap(), - &[Contact::Accepted { - user_id: user_2, - should_notify: true, - }] - ); - - // Users can dismiss notifications of other users accepting their requests. - db.dismiss_contact_notification(user_1, user_2) - .await - .unwrap(); - assert_eq!( - db.get_contacts(user_1).await.unwrap(), - &[Contact::Accepted { - user_id: user_2, - should_notify: false, - }] - ); - - // Users send each other concurrent contact requests and - // see that they are immediately accepted. - db.send_contact_request(user_1, user_3).await.unwrap(); - db.send_contact_request(user_3, user_1).await.unwrap(); - assert_eq!( - db.get_contacts(user_1).await.unwrap(), - &[ - Contact::Accepted { - user_id: user_2, - should_notify: false, - }, - Contact::Accepted { - user_id: user_3, - should_notify: false - } - ] - ); - assert_eq!( - db.get_contacts(user_3).await.unwrap(), - &[Contact::Accepted { - user_id: user_1, - should_notify: false - }], - ); - - // User declines a contact request. Both users see that it is gone. - db.send_contact_request(user_2, user_3).await.unwrap(); - db.respond_to_contact_request(user_3, user_2, false) - .await - .unwrap(); - assert!(!db.has_contact(user_2, user_3).await.unwrap()); - assert!(!db.has_contact(user_3, user_2).await.unwrap()); - assert_eq!( - db.get_contacts(user_2).await.unwrap(), - &[Contact::Accepted { - user_id: user_1, - should_notify: false - }] - ); - assert_eq!( - db.get_contacts(user_3).await.unwrap(), - &[Contact::Accepted { - user_id: user_1, - should_notify: false - }], - ); - } -} - -#[tokio::test(flavor = "multi_thread")] +#[gpui::test] async fn test_invite_codes() { - let postgres = TestDb::postgres().await; - let db = postgres.db(); + let test_db = PostgresTestDb::new(build_background_executor()); + let db = test_db.db(); + let NewUserResult { user_id: user1, .. } = db .create_user( "user1@example.com", @@ -998,10 +639,10 @@ async fn test_invite_codes() { assert_eq!(invite_count, 1); } -#[tokio::test(flavor = "multi_thread")] +#[gpui::test] async fn test_signups() { - let postgres = TestDb::postgres().await; - let db = postgres.db(); + let test_db = PostgresTestDb::new(build_background_executor()); + let db = test_db.db(); // people sign up on the waitlist for i in 0..8 { @@ -1144,51 +785,6 @@ async fn test_signups() { .unwrap_err(); } -#[tokio::test(flavor = "multi_thread")] -async fn test_metrics_id() { - let postgres = TestDb::postgres().await; - let db = postgres.db(); - - let NewUserResult { - user_id: user1, - metrics_id: metrics_id1, - .. - } = db - .create_user( - "person1@example.com", - false, - NewUserParams { - github_login: "person1".into(), - github_user_id: 101, - invite_count: 5, - }, - ) - .await - .unwrap(); - let NewUserResult { - user_id: user2, - metrics_id: metrics_id2, - .. - } = db - .create_user( - "person2@example.com", - false, - NewUserParams { - github_login: "person2".into(), - github_user_id: 102, - invite_count: 5, - }, - ) - .await - .unwrap(); - - assert_eq!(db.get_user_metrics_id(user1).await.unwrap(), metrics_id1); - assert_eq!(db.get_user_metrics_id(user2).await.unwrap(), metrics_id2); - assert_eq!(metrics_id1.len(), 36); - assert_eq!(metrics_id2.len(), 36); - assert_ne!(metrics_id1, metrics_id2); -} - fn build_background_executor() -> Arc { Deterministic::new(0).build_background() } diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index 2bf2701f23bf45462e6b8b0e28e4560a72615c1e..a77345270baa8239f5fbfbf087fc9f1c4bc799a6 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -1,14 +1,14 @@ use crate::{ - db::{NewUserParams, ProjectId, TestDb, UserId}, - rpc::{Executor, Server, Store}, + db::{NewUserParams, ProjectId, SqliteTestDb as TestDb, UserId}, + rpc::{Executor, Server}, AppState, }; use ::rpc::Peer; use anyhow::anyhow; use call::{room, ActiveCall, ParticipantLocation, Room}; use client::{ - self, test::FakeHttpClient, Channel, ChannelDetails, ChannelList, Client, Connection, - Credentials, EstablishConnectionError, PeerId, User, UserStore, RECEIVE_TIMEOUT, + self, test::FakeHttpClient, Client, Connection, Credentials, EstablishConnectionError, PeerId, + User, UserStore, RECEIVE_TIMEOUT, }; use collections::{BTreeMap, HashMap, HashSet}; use editor::{ @@ -16,10 +16,7 @@ use editor::{ ToggleCodeActions, Undo, }; use fs::{FakeFs, Fs as _, HomeDir, LineEnding}; -use futures::{ - channel::{mpsc, oneshot}, - Future, StreamExt as _, -}; +use futures::{channel::oneshot, Future, StreamExt as _}; use gpui::{ executor::{self, Deterministic}, geometry::vector::vec2f, @@ -39,7 +36,6 @@ use project::{ use rand::prelude::*; use serde_json::json; use settings::{Formatter, Settings}; -use sqlx::types::time::OffsetDateTime; use std::{ cell::{Cell, RefCell}, env, mem, @@ -73,7 +69,10 @@ async fn test_basic_calls( cx_c: &mut TestAppContext, ) { deterministic.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; + + let start = std::time::Instant::now(); + let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; let client_c = server.create_client(cx_c, "user_c").await; @@ -259,6 +258,8 @@ async fn test_basic_calls( pending: Default::default() } ); + + eprintln!("finished test {:?}", start.elapsed()); } #[gpui::test(iterations = 10)] @@ -271,7 +272,7 @@ async fn test_room_uniqueness( cx_c: &mut TestAppContext, ) { deterministic.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let _client_a2 = server.create_client(cx_a2, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; @@ -376,7 +377,7 @@ async fn test_leaving_room_on_disconnection( cx_b: &mut TestAppContext, ) { deterministic.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -505,7 +506,7 @@ async fn test_calls_on_multiple_connections( cx_b2: &mut TestAppContext, ) { deterministic.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b1 = server.create_client(cx_b1, "user_b").await; let client_b2 = server.create_client(cx_b2, "user_b").await; @@ -654,7 +655,7 @@ async fn test_share_project( ) { deterministic.forbid_parking(); let (_, window_b) = cx_b.add_window(|_| EmptyView); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; let client_c = server.create_client(cx_c, "user_c").await; @@ -791,7 +792,7 @@ async fn test_unshare_project( cx_c: &mut TestAppContext, ) { deterministic.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; let client_c = server.create_client(cx_c, "user_c").await; @@ -874,7 +875,7 @@ async fn test_host_disconnect( ) { cx_b.update(editor::init); deterministic.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; let client_c = server.create_client(cx_c, "user_c").await; @@ -979,7 +980,7 @@ async fn test_active_call_events( cx_b: &mut TestAppContext, ) { deterministic.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; client_a.fs.insert_tree("/a", json!({})).await; @@ -1068,7 +1069,7 @@ async fn test_room_location( cx_b: &mut TestAppContext, ) { deterministic.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; client_a.fs.insert_tree("/a", json!({})).await; @@ -1234,7 +1235,7 @@ async fn test_propagate_saves_and_fs_changes( cx_c: &mut TestAppContext, ) { deterministic.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; let client_c = server.create_client(cx_c, "user_c").await; @@ -1409,7 +1410,7 @@ async fn test_git_diff_base_change( cx_b: &mut TestAppContext, ) { executor.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -1661,7 +1662,7 @@ async fn test_fs_operations( cx_b: &mut TestAppContext, ) { executor.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -1927,7 +1928,7 @@ async fn test_fs_operations( #[gpui::test(iterations = 10)] async fn test_buffer_conflict_after_save(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -1981,7 +1982,7 @@ async fn test_buffer_conflict_after_save(cx_a: &mut TestAppContext, cx_b: &mut T #[gpui::test(iterations = 10)] async fn test_buffer_reloading(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -2040,7 +2041,7 @@ async fn test_editing_while_guest_opens_buffer( cx_b: &mut TestAppContext, ) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -2087,7 +2088,7 @@ async fn test_leaving_worktree_while_opening_buffer( cx_b: &mut TestAppContext, ) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -2132,7 +2133,7 @@ async fn test_canceling_buffer_opening( ) { deterministic.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -2183,7 +2184,7 @@ async fn test_leaving_project( cx_c: &mut TestAppContext, ) { deterministic.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; let client_c = server.create_client(cx_c, "user_c").await; @@ -2316,7 +2317,7 @@ async fn test_collaborating_with_diagnostics( cx_c: &mut TestAppContext, ) { deterministic.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; let client_c = server.create_client(cx_c, "user_c").await; @@ -2581,7 +2582,7 @@ async fn test_collaborating_with_diagnostics( #[gpui::test(iterations = 10)] async fn test_collaborating_with_completion(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -2755,7 +2756,7 @@ async fn test_collaborating_with_completion(cx_a: &mut TestAppContext, cx_b: &mu #[gpui::test(iterations = 10)] async fn test_reloading_buffer_manually(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -2848,7 +2849,7 @@ async fn test_reloading_buffer_manually(cx_a: &mut TestAppContext, cx_b: &mut Te async fn test_formatting_buffer(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { use project::FormatTrigger; - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -2949,7 +2950,7 @@ async fn test_formatting_buffer(cx_a: &mut TestAppContext, cx_b: &mut TestAppCon #[gpui::test(iterations = 10)] async fn test_definition(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -3093,7 +3094,7 @@ async fn test_definition(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { #[gpui::test(iterations = 10)] async fn test_references(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -3194,7 +3195,7 @@ async fn test_references(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { #[gpui::test(iterations = 10)] async fn test_project_search(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -3273,7 +3274,7 @@ async fn test_project_search(cx_a: &mut TestAppContext, cx_b: &mut TestAppContex #[gpui::test(iterations = 10)] async fn test_document_highlights(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -3375,7 +3376,7 @@ async fn test_document_highlights(cx_a: &mut TestAppContext, cx_b: &mut TestAppC #[gpui::test(iterations = 10)] async fn test_lsp_hover(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -3478,7 +3479,7 @@ async fn test_lsp_hover(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { #[gpui::test(iterations = 10)] async fn test_project_symbols(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -3586,7 +3587,7 @@ async fn test_open_buffer_while_getting_definition_pointing_to_it( mut rng: StdRng, ) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -3662,7 +3663,7 @@ async fn test_collaborating_with_code_actions( ) { cx_a.foreground().forbid_parking(); cx_b.update(editor::init); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -3873,7 +3874,7 @@ async fn test_collaborating_with_code_actions( async fn test_collaborating_with_renames(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { cx_a.foreground().forbid_parking(); cx_b.update(editor::init); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -4065,7 +4066,7 @@ async fn test_language_server_statuses( deterministic.forbid_parking(); cx_b.update(editor::init); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -4169,415 +4170,6 @@ async fn test_language_server_statuses( }); } -#[gpui::test(iterations = 10)] -async fn test_basic_chat(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { - cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; - let client_a = server.create_client(cx_a, "user_a").await; - let client_b = server.create_client(cx_b, "user_b").await; - - // Create an org that includes these 2 users. - let db = &server.app_state.db; - let org_id = db.create_org("Test Org", "test-org").await.unwrap(); - db.add_org_member(org_id, client_a.current_user_id(cx_a), false) - .await - .unwrap(); - db.add_org_member(org_id, client_b.current_user_id(cx_b), false) - .await - .unwrap(); - - // Create a channel that includes all the users. - let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap(); - db.add_channel_member(channel_id, client_a.current_user_id(cx_a), false) - .await - .unwrap(); - db.add_channel_member(channel_id, client_b.current_user_id(cx_b), false) - .await - .unwrap(); - db.create_channel_message( - channel_id, - client_b.current_user_id(cx_b), - "hello A, it's B.", - OffsetDateTime::now_utc(), - 1, - ) - .await - .unwrap(); - - let channels_a = - cx_a.add_model(|cx| ChannelList::new(client_a.user_store.clone(), client_a.clone(), cx)); - channels_a - .condition(cx_a, |list, _| list.available_channels().is_some()) - .await; - channels_a.read_with(cx_a, |list, _| { - assert_eq!( - list.available_channels().unwrap(), - &[ChannelDetails { - id: channel_id.to_proto(), - name: "test-channel".to_string() - }] - ) - }); - let channel_a = channels_a.update(cx_a, |this, cx| { - this.get_channel(channel_id.to_proto(), cx).unwrap() - }); - channel_a.read_with(cx_a, |channel, _| assert!(channel.messages().is_empty())); - channel_a - .condition(cx_a, |channel, _| { - channel_messages(channel) - == [("user_b".to_string(), "hello A, it's B.".to_string(), false)] - }) - .await; - - let channels_b = - cx_b.add_model(|cx| ChannelList::new(client_b.user_store.clone(), client_b.clone(), cx)); - channels_b - .condition(cx_b, |list, _| list.available_channels().is_some()) - .await; - channels_b.read_with(cx_b, |list, _| { - assert_eq!( - list.available_channels().unwrap(), - &[ChannelDetails { - id: channel_id.to_proto(), - name: "test-channel".to_string() - }] - ) - }); - - let channel_b = channels_b.update(cx_b, |this, cx| { - this.get_channel(channel_id.to_proto(), cx).unwrap() - }); - channel_b.read_with(cx_b, |channel, _| assert!(channel.messages().is_empty())); - channel_b - .condition(cx_b, |channel, _| { - channel_messages(channel) - == [("user_b".to_string(), "hello A, it's B.".to_string(), false)] - }) - .await; - - channel_a - .update(cx_a, |channel, cx| { - channel - .send_message("oh, hi B.".to_string(), cx) - .unwrap() - .detach(); - let task = channel.send_message("sup".to_string(), cx).unwrap(); - assert_eq!( - channel_messages(channel), - &[ - ("user_b".to_string(), "hello A, it's B.".to_string(), false), - ("user_a".to_string(), "oh, hi B.".to_string(), true), - ("user_a".to_string(), "sup".to_string(), true) - ] - ); - task - }) - .await - .unwrap(); - - channel_b - .condition(cx_b, |channel, _| { - channel_messages(channel) - == [ - ("user_b".to_string(), "hello A, it's B.".to_string(), false), - ("user_a".to_string(), "oh, hi B.".to_string(), false), - ("user_a".to_string(), "sup".to_string(), false), - ] - }) - .await; - - assert_eq!( - server - .store() - .await - .channel(channel_id) - .unwrap() - .connection_ids - .len(), - 2 - ); - cx_b.update(|_| drop(channel_b)); - server - .condition(|state| state.channel(channel_id).unwrap().connection_ids.len() == 1) - .await; - - cx_a.update(|_| drop(channel_a)); - server - .condition(|state| state.channel(channel_id).is_none()) - .await; -} - -#[gpui::test(iterations = 10)] -async fn test_chat_message_validation(cx_a: &mut TestAppContext) { - cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; - let client_a = server.create_client(cx_a, "user_a").await; - - let db = &server.app_state.db; - let org_id = db.create_org("Test Org", "test-org").await.unwrap(); - let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap(); - db.add_org_member(org_id, client_a.current_user_id(cx_a), false) - .await - .unwrap(); - db.add_channel_member(channel_id, client_a.current_user_id(cx_a), false) - .await - .unwrap(); - - let channels_a = - cx_a.add_model(|cx| ChannelList::new(client_a.user_store.clone(), client_a.clone(), cx)); - channels_a - .condition(cx_a, |list, _| list.available_channels().is_some()) - .await; - let channel_a = channels_a.update(cx_a, |this, cx| { - this.get_channel(channel_id.to_proto(), cx).unwrap() - }); - - // Messages aren't allowed to be too long. - channel_a - .update(cx_a, |channel, cx| { - let long_body = "this is long.\n".repeat(1024); - channel.send_message(long_body, cx).unwrap() - }) - .await - .unwrap_err(); - - // Messages aren't allowed to be blank. - channel_a.update(cx_a, |channel, cx| { - channel.send_message(String::new(), cx).unwrap_err() - }); - - // Leading and trailing whitespace are trimmed. - channel_a - .update(cx_a, |channel, cx| { - channel - .send_message("\n surrounded by whitespace \n".to_string(), cx) - .unwrap() - }) - .await - .unwrap(); - assert_eq!( - db.get_channel_messages(channel_id, 10, None) - .await - .unwrap() - .iter() - .map(|m| &m.body) - .collect::>(), - &["surrounded by whitespace"] - ); -} - -#[gpui::test(iterations = 10)] -async fn test_chat_reconnection(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { - cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; - let client_a = server.create_client(cx_a, "user_a").await; - let client_b = server.create_client(cx_b, "user_b").await; - - let mut status_b = client_b.status(); - - // Create an org that includes these 2 users. - let db = &server.app_state.db; - let org_id = db.create_org("Test Org", "test-org").await.unwrap(); - db.add_org_member(org_id, client_a.current_user_id(cx_a), false) - .await - .unwrap(); - db.add_org_member(org_id, client_b.current_user_id(cx_b), false) - .await - .unwrap(); - - // Create a channel that includes all the users. - let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap(); - db.add_channel_member(channel_id, client_a.current_user_id(cx_a), false) - .await - .unwrap(); - db.add_channel_member(channel_id, client_b.current_user_id(cx_b), false) - .await - .unwrap(); - db.create_channel_message( - channel_id, - client_b.current_user_id(cx_b), - "hello A, it's B.", - OffsetDateTime::now_utc(), - 2, - ) - .await - .unwrap(); - - let channels_a = - cx_a.add_model(|cx| ChannelList::new(client_a.user_store.clone(), client_a.clone(), cx)); - channels_a - .condition(cx_a, |list, _| list.available_channels().is_some()) - .await; - - channels_a.read_with(cx_a, |list, _| { - assert_eq!( - list.available_channels().unwrap(), - &[ChannelDetails { - id: channel_id.to_proto(), - name: "test-channel".to_string() - }] - ) - }); - let channel_a = channels_a.update(cx_a, |this, cx| { - this.get_channel(channel_id.to_proto(), cx).unwrap() - }); - channel_a.read_with(cx_a, |channel, _| assert!(channel.messages().is_empty())); - channel_a - .condition(cx_a, |channel, _| { - channel_messages(channel) - == [("user_b".to_string(), "hello A, it's B.".to_string(), false)] - }) - .await; - - let channels_b = - cx_b.add_model(|cx| ChannelList::new(client_b.user_store.clone(), client_b.clone(), cx)); - channels_b - .condition(cx_b, |list, _| list.available_channels().is_some()) - .await; - channels_b.read_with(cx_b, |list, _| { - assert_eq!( - list.available_channels().unwrap(), - &[ChannelDetails { - id: channel_id.to_proto(), - name: "test-channel".to_string() - }] - ) - }); - - let channel_b = channels_b.update(cx_b, |this, cx| { - this.get_channel(channel_id.to_proto(), cx).unwrap() - }); - channel_b.read_with(cx_b, |channel, _| assert!(channel.messages().is_empty())); - channel_b - .condition(cx_b, |channel, _| { - channel_messages(channel) - == [("user_b".to_string(), "hello A, it's B.".to_string(), false)] - }) - .await; - - // Disconnect client B, ensuring we can still access its cached channel data. - server.forbid_connections(); - server.disconnect_client(client_b.peer_id().unwrap()); - cx_b.foreground().advance_clock(rpc::RECEIVE_TIMEOUT); - while !matches!( - status_b.next().await, - Some(client::Status::ReconnectionError { .. }) - ) {} - - channels_b.read_with(cx_b, |channels, _| { - assert_eq!( - channels.available_channels().unwrap(), - [ChannelDetails { - id: channel_id.to_proto(), - name: "test-channel".to_string() - }] - ) - }); - channel_b.read_with(cx_b, |channel, _| { - assert_eq!( - channel_messages(channel), - [("user_b".to_string(), "hello A, it's B.".to_string(), false)] - ) - }); - - // Send a message from client B while it is disconnected. - channel_b - .update(cx_b, |channel, cx| { - let task = channel - .send_message("can you see this?".to_string(), cx) - .unwrap(); - assert_eq!( - channel_messages(channel), - &[ - ("user_b".to_string(), "hello A, it's B.".to_string(), false), - ("user_b".to_string(), "can you see this?".to_string(), true) - ] - ); - task - }) - .await - .unwrap_err(); - - // Send a message from client A while B is disconnected. - channel_a - .update(cx_a, |channel, cx| { - channel - .send_message("oh, hi B.".to_string(), cx) - .unwrap() - .detach(); - let task = channel.send_message("sup".to_string(), cx).unwrap(); - assert_eq!( - channel_messages(channel), - &[ - ("user_b".to_string(), "hello A, it's B.".to_string(), false), - ("user_a".to_string(), "oh, hi B.".to_string(), true), - ("user_a".to_string(), "sup".to_string(), true) - ] - ); - task - }) - .await - .unwrap(); - - // Give client B a chance to reconnect. - server.allow_connections(); - cx_b.foreground().advance_clock(Duration::from_secs(10)); - - // Verify that B sees the new messages upon reconnection, as well as the message client B - // sent while offline. - channel_b - .condition(cx_b, |channel, _| { - channel_messages(channel) - == [ - ("user_b".to_string(), "hello A, it's B.".to_string(), false), - ("user_a".to_string(), "oh, hi B.".to_string(), false), - ("user_a".to_string(), "sup".to_string(), false), - ("user_b".to_string(), "can you see this?".to_string(), false), - ] - }) - .await; - - // Ensure client A and B can communicate normally after reconnection. - channel_a - .update(cx_a, |channel, cx| { - channel.send_message("you online?".to_string(), cx).unwrap() - }) - .await - .unwrap(); - channel_b - .condition(cx_b, |channel, _| { - channel_messages(channel) - == [ - ("user_b".to_string(), "hello A, it's B.".to_string(), false), - ("user_a".to_string(), "oh, hi B.".to_string(), false), - ("user_a".to_string(), "sup".to_string(), false), - ("user_b".to_string(), "can you see this?".to_string(), false), - ("user_a".to_string(), "you online?".to_string(), false), - ] - }) - .await; - - channel_b - .update(cx_b, |channel, cx| { - channel.send_message("yep".to_string(), cx).unwrap() - }) - .await - .unwrap(); - channel_a - .condition(cx_a, |channel, _| { - channel_messages(channel) - == [ - ("user_b".to_string(), "hello A, it's B.".to_string(), false), - ("user_a".to_string(), "oh, hi B.".to_string(), false), - ("user_a".to_string(), "sup".to_string(), false), - ("user_b".to_string(), "can you see this?".to_string(), false), - ("user_a".to_string(), "you online?".to_string(), false), - ("user_b".to_string(), "yep".to_string(), false), - ] - }) - .await; -} - #[gpui::test(iterations = 10)] async fn test_contacts( deterministic: Arc, @@ -4586,7 +4178,7 @@ async fn test_contacts( cx_c: &mut TestAppContext, ) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; let client_c = server.create_client(cx_c, "user_c").await; @@ -4912,7 +4504,7 @@ async fn test_contact_requests( cx_a.foreground().forbid_parking(); // Connect to a server as 3 clients. - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_a2 = server.create_client(cx_a2, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; @@ -5093,7 +4685,7 @@ async fn test_following( cx_a.update(editor::init); cx_b.update(editor::init); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -5367,7 +4959,7 @@ async fn test_peers_following_each_other(cx_a: &mut TestAppContext, cx_b: &mut T cx_a.update(editor::init); cx_b.update(editor::init); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -5545,7 +5137,7 @@ async fn test_auto_unfollowing(cx_a: &mut TestAppContext, cx_b: &mut TestAppCont cx_b.update(editor::init); // 2 clients connect to a server. - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -5719,7 +5311,7 @@ async fn test_peers_simultaneously_following_each_other( cx_a.update(editor::init); cx_b.update(editor::init); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -5789,7 +5381,7 @@ async fn test_random_collaboration( .map(|i| i.parse().expect("invalid `OPERATIONS` variable")) .unwrap_or(10); - let mut server = TestServer::start(cx.foreground(), cx.background()).await; + let mut server = TestServer::start(cx.background()).await; let db = server.app_state.db.clone(); let mut available_guests = Vec::new(); @@ -6076,8 +5668,6 @@ struct TestServer { peer: Arc, app_state: Arc, server: Arc, - foreground: Rc, - notifications: mpsc::UnboundedReceiver<()>, connection_killers: Arc>>>, forbid_connections: Arc, _test_db: TestDb, @@ -6085,13 +5675,10 @@ struct TestServer { } impl TestServer { - async fn start( - foreground: Rc, - background: Arc, - ) -> Self { + async fn start(background: Arc) -> Self { static NEXT_LIVE_KIT_SERVER_ID: AtomicUsize = AtomicUsize::new(0); - let test_db = TestDb::fake(background.clone()); + let test_db = TestDb::new(background.clone()); let live_kit_server_id = NEXT_LIVE_KIT_SERVER_ID.fetch_add(1, SeqCst); let live_kit_server = live_kit_client::TestServer::create( format!("http://livekit.{}.test", live_kit_server_id), @@ -6102,14 +5689,11 @@ impl TestServer { .unwrap(); let app_state = Self::build_app_state(&test_db, &live_kit_server).await; let peer = Peer::new(); - let notifications = mpsc::unbounded(); - let server = Server::new(app_state.clone(), Some(notifications.0)); + let server = Server::new(app_state.clone()); Self { peer, app_state, server, - foreground, - notifications: notifications.1, connection_killers: Default::default(), forbid_connections: Default::default(), _test_db: test_db, @@ -6147,7 +5731,7 @@ impl TestServer { }, ) .await - .unwrap() + .expect("creating user failed") .user_id }; let client_name = name.to_string(); @@ -6187,7 +5771,11 @@ impl TestServer { let (client_conn, server_conn, killed) = Connection::in_memory(cx.background()); let (connection_id_tx, connection_id_rx) = oneshot::channel(); - let user = db.get_user_by_id(user_id).await.unwrap().unwrap(); + let user = db + .get_user_by_id(user_id) + .await + .expect("retrieving user failed") + .unwrap(); cx.background() .spawn(server.handle_connection( server_conn, @@ -6221,7 +5809,6 @@ impl TestServer { default_item_factory: |_, _| unimplemented!(), }); - Channel::init(&client); Project::init(&client); cx.update(|cx| { workspace::init(app_state.clone(), cx); @@ -6322,21 +5909,6 @@ impl TestServer { config: Default::default(), }) } - - async fn condition(&mut self, mut predicate: F) - where - F: FnMut(&Store) -> bool, - { - assert!( - self.foreground.parking_forbidden(), - "you must call forbid_parking to use server conditions so we don't block indefinitely" - ); - while !(predicate)(&*self.server.store.lock().await) { - self.foreground.start_waiting(); - self.notifications.next().await; - self.foreground.finish_waiting(); - } - } } impl Deref for TestServer { @@ -7052,20 +6624,6 @@ impl Executor for Arc { } } -fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> { - channel - .messages() - .cursor::<()>() - .map(|m| { - ( - m.sender.github_login.clone(), - m.body.clone(), - m.is_pending(), - ) - }) - .collect() -} - #[derive(Debug, Eq, PartialEq)] struct RoomParticipants { remote: Vec, diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 8085fd8026f443ec80d15792fe62062f6193bae1..dc98a2ee6855c072f5adc9ed95dbad38626eca48 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -13,12 +13,12 @@ use crate::rpc::ResultExt as _; use anyhow::anyhow; use axum::{routing::get, Router}; use collab::{Error, Result}; -use db::{Db, PostgresDb}; +use db::DefaultDb as Db; use serde::Deserialize; use std::{ env::args, net::{SocketAddr, TcpListener}, - path::PathBuf, + path::{Path, PathBuf}, sync::Arc, time::Duration, }; @@ -49,14 +49,14 @@ pub struct MigrateConfig { } pub struct AppState { - db: Arc, + db: Arc, live_kit_client: Option>, config: Config, } impl AppState { async fn new(config: Config) -> Result> { - let db = PostgresDb::new(&config.database_url, 5).await?; + let db = Db::new(&config.database_url, 5).await?; let live_kit_client = if let Some(((server, key), secret)) = config .live_kit_server .as_ref() @@ -96,13 +96,12 @@ async fn main() -> Result<()> { } Some("migrate") => { let config = envy::from_env::().expect("error loading config"); - let db = PostgresDb::new(&config.database_url, 5).await?; + let db = Db::new(&config.database_url, 5).await?; let migrations_path = config .migrations_path .as_deref() - .or(db::DEFAULT_MIGRATIONS_PATH.map(|s| s.as_ref())) - .ok_or_else(|| anyhow!("missing MIGRATIONS_PATH environment variable"))?; + .unwrap_or_else(|| Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"))); let migrations = db.migrate(&migrations_path, false).await?; for (migration, duration) in migrations { @@ -122,9 +121,7 @@ async fn main() -> Result<()> { let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port)) .expect("failed to bind TCP listener"); - let rpc_server = rpc::Server::new(state.clone(), None); - rpc_server - .start_recording_project_activity(Duration::from_secs(5 * 60), rpc::RealExecutor); + let rpc_server = rpc::Server::new(state.clone()); let app = api::routes(rpc_server.clone(), state.clone()) .merge(rpc::routes(rpc_server.clone())) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 9fd9bef825e0176dc7db6ab71244a40b2272e191..7bc2b43b9b4c24cdf991e88c90a4e966927a8cfd 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -2,7 +2,7 @@ mod store; use crate::{ auth, - db::{self, ChannelId, MessageId, ProjectId, User, UserId}, + db::{self, ProjectId, User, UserId}, AppState, Result, }; use anyhow::anyhow; @@ -24,7 +24,7 @@ use axum::{ }; use collections::{HashMap, HashSet}; use futures::{ - channel::{mpsc, oneshot}, + channel::oneshot, future::{self, BoxFuture}, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt, TryStreamExt, @@ -51,7 +51,6 @@ use std::{ time::Duration, }; pub use store::{Store, Worktree}; -use time::OffsetDateTime; use tokio::{ sync::{Mutex, MutexGuard}, time::Sleep, @@ -62,10 +61,6 @@ use tracing::{info_span, instrument, Instrument}; lazy_static! { static ref METRIC_CONNECTIONS: IntGauge = register_int_gauge!("connections", "number of connections").unwrap(); - static ref METRIC_REGISTERED_PROJECTS: IntGauge = - register_int_gauge!("registered_projects", "number of registered projects").unwrap(); - static ref METRIC_ACTIVE_PROJECTS: IntGauge = - register_int_gauge!("active_projects", "number of active projects").unwrap(); static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!( "shared_projects", "number of open projects with one or more guests" @@ -95,7 +90,6 @@ pub struct Server { pub(crate) store: Mutex, app_state: Arc, handlers: HashMap, - notifications: Option>, } pub trait Executor: Send + Clone { @@ -107,9 +101,6 @@ pub trait Executor: Send + Clone { #[derive(Clone)] pub struct RealExecutor; -const MESSAGE_COUNT_PER_PAGE: usize = 100; -const MAX_MESSAGE_LEN: usize = 1024; - pub(crate) struct StoreGuard<'a> { guard: MutexGuard<'a, Store>, _not_send: PhantomData>, @@ -132,16 +123,12 @@ where } impl Server { - pub fn new( - app_state: Arc, - notifications: Option>, - ) -> Arc { + pub fn new(app_state: Arc) -> Arc { let mut server = Self { peer: Peer::new(), app_state, store: Default::default(), handlers: Default::default(), - notifications, }; server @@ -158,9 +145,7 @@ impl Server { .add_request_handler(Server::join_project) .add_message_handler(Server::leave_project) .add_message_handler(Server::update_project) - .add_message_handler(Server::register_project_activity) .add_request_handler(Server::update_worktree) - .add_message_handler(Server::update_worktree_extensions) .add_message_handler(Server::start_language_server) .add_message_handler(Server::update_language_server) .add_message_handler(Server::update_diagnostic_summary) @@ -194,19 +179,14 @@ impl Server { .add_message_handler(Server::buffer_reloaded) .add_message_handler(Server::buffer_saved) .add_request_handler(Server::save_buffer) - .add_request_handler(Server::get_channels) .add_request_handler(Server::get_users) .add_request_handler(Server::fuzzy_search_users) .add_request_handler(Server::request_contact) .add_request_handler(Server::remove_contact) .add_request_handler(Server::respond_to_contact_request) - .add_request_handler(Server::join_channel) - .add_message_handler(Server::leave_channel) - .add_request_handler(Server::send_channel_message) .add_request_handler(Server::follow) .add_message_handler(Server::unfollow) .add_message_handler(Server::update_followers) - .add_request_handler(Server::get_channel_messages) .add_message_handler(Server::update_diff_base) .add_request_handler(Server::get_private_user_info); @@ -290,58 +270,6 @@ impl Server { }) } - /// Start a long lived task that records which users are active in which projects. - pub fn start_recording_project_activity( - self: &Arc, - interval: Duration, - executor: E, - ) { - executor.spawn_detached({ - let this = Arc::downgrade(self); - let executor = executor.clone(); - async move { - let mut period_start = OffsetDateTime::now_utc(); - let mut active_projects = Vec::<(UserId, ProjectId)>::new(); - loop { - let sleep = executor.sleep(interval); - sleep.await; - let this = if let Some(this) = this.upgrade() { - this - } else { - break; - }; - - active_projects.clear(); - active_projects.extend(this.store().await.projects().flat_map( - |(project_id, project)| { - project.guests.values().chain([&project.host]).filter_map( - |collaborator| { - if !collaborator.admin - && collaborator - .last_activity - .map_or(false, |activity| activity > period_start) - { - Some((collaborator.user_id, *project_id)) - } else { - None - } - }, - ) - }, - )); - - let period_end = OffsetDateTime::now_utc(); - this.app_state - .db - .record_user_activity(period_start..period_end, &active_projects) - .await - .trace_err(); - period_start = period_end; - } - } - }); - } - pub fn handle_connection( self: &Arc, connection: Connection, @@ -432,18 +360,11 @@ impl Server { let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name); let span_enter = span.enter(); if let Some(handler) = this.handlers.get(&message.payload_type_id()) { - let notifications = this.notifications.clone(); let is_background = message.is_background(); let handle_message = (handler)(this.clone(), message); - drop(span_enter); - let handle_message = async move { - handle_message.await; - if let Some(mut notifications) = notifications { - let _ = notifications.send(()).await; - } - }.instrument(span); + let handle_message = handle_message.instrument(span); if is_background { executor.spawn_detached(handle_message); } else { @@ -1172,17 +1093,6 @@ impl Server { Ok(()) } - async fn register_project_activity( - self: Arc, - request: TypedEnvelope, - ) -> Result<()> { - self.store().await.register_project_activity( - ProjectId::from_proto(request.payload.project_id), - request.sender_id, - )?; - Ok(()) - } - async fn update_worktree( self: Arc, request: TypedEnvelope, @@ -1209,25 +1119,6 @@ impl Server { Ok(()) } - async fn update_worktree_extensions( - self: Arc, - request: TypedEnvelope, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); - let worktree_id = request.payload.worktree_id; - let extensions = request - .payload - .extensions - .into_iter() - .zip(request.payload.counts) - .collect(); - self.app_state - .db - .update_worktree_extensions(project_id, worktree_id, extensions) - .await?; - Ok(()) - } - async fn update_diagnostic_summary( self: Arc, request: TypedEnvelope, @@ -1363,8 +1254,7 @@ impl Server { ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); let receiver_ids = { - let mut store = self.store().await; - store.register_project_activity(project_id, request.sender_id)?; + let store = self.store().await; store.project_connection_ids(project_id, request.sender_id)? }; @@ -1430,15 +1320,13 @@ impl Server { let leader_id = ConnectionId(request.payload.leader_id); let follower_id = request.sender_id; { - let mut store = self.store().await; + let store = self.store().await; if !store .project_connection_ids(project_id, follower_id)? .contains(&leader_id) { Err(anyhow!("no such peer"))?; } - - store.register_project_activity(project_id, follower_id)?; } let mut response_payload = self @@ -1455,14 +1343,13 @@ impl Server { async fn unfollow(self: Arc, request: TypedEnvelope) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); let leader_id = ConnectionId(request.payload.leader_id); - let mut store = self.store().await; + let store = self.store().await; if !store .project_connection_ids(project_id, request.sender_id)? .contains(&leader_id) { Err(anyhow!("no such peer"))?; } - store.register_project_activity(project_id, request.sender_id)?; self.peer .forward_send(request.sender_id, leader_id, request.payload)?; Ok(()) @@ -1473,8 +1360,7 @@ impl Server { request: TypedEnvelope, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); - let mut store = self.store().await; - store.register_project_activity(project_id, request.sender_id)?; + let store = self.store().await; let connection_ids = store.project_connection_ids(project_id, request.sender_id)?; let leader_id = request .payload @@ -1495,28 +1381,6 @@ impl Server { Ok(()) } - async fn get_channels( - self: Arc, - request: TypedEnvelope, - response: Response, - ) -> Result<()> { - let user_id = self - .store() - .await - .user_id_for_connection(request.sender_id)?; - let channels = self.app_state.db.get_accessible_channels(user_id).await?; - response.send(proto::GetChannelsResponse { - channels: channels - .into_iter() - .map(|chan| proto::Channel { - id: chan.id.to_proto(), - name: chan.name, - }) - .collect(), - })?; - Ok(()) - } - async fn get_users( self: Arc, request: TypedEnvelope, @@ -1712,175 +1576,6 @@ impl Server { Ok(()) } - async fn join_channel( - self: Arc, - request: TypedEnvelope, - response: Response, - ) -> Result<()> { - let user_id = self - .store() - .await - .user_id_for_connection(request.sender_id)?; - let channel_id = ChannelId::from_proto(request.payload.channel_id); - if !self - .app_state - .db - .can_user_access_channel(user_id, channel_id) - .await? - { - Err(anyhow!("access denied"))?; - } - - self.store() - .await - .join_channel(request.sender_id, channel_id); - let messages = self - .app_state - .db - .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None) - .await? - .into_iter() - .map(|msg| proto::ChannelMessage { - id: msg.id.to_proto(), - body: msg.body, - timestamp: msg.sent_at.unix_timestamp() as u64, - sender_id: msg.sender_id.to_proto(), - nonce: Some(msg.nonce.as_u128().into()), - }) - .collect::>(); - response.send(proto::JoinChannelResponse { - done: messages.len() < MESSAGE_COUNT_PER_PAGE, - messages, - })?; - Ok(()) - } - - async fn leave_channel( - self: Arc, - request: TypedEnvelope, - ) -> Result<()> { - let user_id = self - .store() - .await - .user_id_for_connection(request.sender_id)?; - let channel_id = ChannelId::from_proto(request.payload.channel_id); - if !self - .app_state - .db - .can_user_access_channel(user_id, channel_id) - .await? - { - Err(anyhow!("access denied"))?; - } - - self.store() - .await - .leave_channel(request.sender_id, channel_id); - - Ok(()) - } - - async fn send_channel_message( - self: Arc, - request: TypedEnvelope, - response: Response, - ) -> Result<()> { - let channel_id = ChannelId::from_proto(request.payload.channel_id); - let user_id; - let connection_ids; - { - let state = self.store().await; - user_id = state.user_id_for_connection(request.sender_id)?; - connection_ids = state.channel_connection_ids(channel_id)?; - } - - // Validate the message body. - let body = request.payload.body.trim().to_string(); - if body.len() > MAX_MESSAGE_LEN { - return Err(anyhow!("message is too long"))?; - } - if body.is_empty() { - return Err(anyhow!("message can't be blank"))?; - } - - let timestamp = OffsetDateTime::now_utc(); - let nonce = request - .payload - .nonce - .ok_or_else(|| anyhow!("nonce can't be blank"))?; - - let message_id = self - .app_state - .db - .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into()) - .await? - .to_proto(); - let message = proto::ChannelMessage { - sender_id: user_id.to_proto(), - id: message_id, - body, - timestamp: timestamp.unix_timestamp() as u64, - nonce: Some(nonce), - }; - broadcast(request.sender_id, connection_ids, |conn_id| { - self.peer.send( - conn_id, - proto::ChannelMessageSent { - channel_id: channel_id.to_proto(), - message: Some(message.clone()), - }, - ) - }); - response.send(proto::SendChannelMessageResponse { - message: Some(message), - })?; - Ok(()) - } - - async fn get_channel_messages( - self: Arc, - request: TypedEnvelope, - response: Response, - ) -> Result<()> { - let user_id = self - .store() - .await - .user_id_for_connection(request.sender_id)?; - let channel_id = ChannelId::from_proto(request.payload.channel_id); - if !self - .app_state - .db - .can_user_access_channel(user_id, channel_id) - .await? - { - Err(anyhow!("access denied"))?; - } - - let messages = self - .app_state - .db - .get_channel_messages( - channel_id, - MESSAGE_COUNT_PER_PAGE, - Some(MessageId::from_proto(request.payload.before_message_id)), - ) - .await? - .into_iter() - .map(|msg| proto::ChannelMessage { - id: msg.id.to_proto(), - body: msg.body, - timestamp: msg.sent_at.unix_timestamp() as u64, - sender_id: msg.sender_id.to_proto(), - nonce: Some(msg.nonce.as_u128().into()), - }) - .collect::>(); - response.send(proto::GetChannelMessagesResponse { - done: messages.len() < MESSAGE_COUNT_PER_PAGE, - messages, - })?; - Ok(()) - } - async fn update_diff_base( self: Arc, request: TypedEnvelope, @@ -2061,11 +1756,8 @@ pub async fn handle_websocket_request( } pub async fn handle_metrics(Extension(server): Extension>) -> axum::response::Response { - // We call `store_mut` here for its side effects of updating metrics. let metrics = server.store().await.metrics(); METRIC_CONNECTIONS.set(metrics.connections as _); - METRIC_REGISTERED_PROJECTS.set(metrics.registered_projects as _); - METRIC_ACTIVE_PROJECTS.set(metrics.active_projects as _); METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _); let encoder = prometheus::TextEncoder::new(); diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/store.rs index c9358ddc2a356e3f1c3e624f1851d955287c840c..81ef594ccd75b4098ec48af7f1c8a93b260d523b 100644 --- a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/store.rs @@ -1,11 +1,10 @@ -use crate::db::{self, ChannelId, ProjectId, UserId}; +use crate::db::{self, ProjectId, UserId}; use anyhow::{anyhow, Result}; use collections::{btree_map, BTreeMap, BTreeSet, HashMap, HashSet}; use nanoid::nanoid; use rpc::{proto, ConnectionId}; use serde::Serialize; -use std::{borrow::Cow, mem, path::PathBuf, str, time::Duration}; -use time::OffsetDateTime; +use std::{borrow::Cow, mem, path::PathBuf, str}; use tracing::instrument; use util::post_inc; @@ -18,8 +17,6 @@ pub struct Store { next_room_id: RoomId, rooms: BTreeMap, projects: BTreeMap, - #[serde(skip)] - channels: BTreeMap, } #[derive(Default, Serialize)] @@ -33,7 +30,6 @@ struct ConnectionState { user_id: UserId, admin: bool, projects: BTreeSet, - channels: HashSet, } #[derive(Copy, Clone, Eq, PartialEq, Serialize)] @@ -60,8 +56,6 @@ pub struct Project { pub struct Collaborator { pub replica_id: ReplicaId, pub user_id: UserId, - #[serde(skip)] - pub last_activity: Option, pub admin: bool, } @@ -78,11 +72,6 @@ pub struct Worktree { pub is_complete: bool, } -#[derive(Default)] -pub struct Channel { - pub connection_ids: HashSet, -} - pub type ReplicaId = u16; #[derive(Default)] @@ -113,38 +102,23 @@ pub struct LeftRoom<'a> { #[derive(Copy, Clone)] pub struct Metrics { pub connections: usize, - pub registered_projects: usize, - pub active_projects: usize, pub shared_projects: usize, } impl Store { pub fn metrics(&self) -> Metrics { - const ACTIVE_PROJECT_TIMEOUT: Duration = Duration::from_secs(60); - let active_window_start = OffsetDateTime::now_utc() - ACTIVE_PROJECT_TIMEOUT; - let connections = self.connections.values().filter(|c| !c.admin).count(); - let mut registered_projects = 0; - let mut active_projects = 0; let mut shared_projects = 0; for project in self.projects.values() { if let Some(connection) = self.connections.get(&project.host_connection_id) { if !connection.admin { - registered_projects += 1; - if project.is_active_since(active_window_start) { - active_projects += 1; - if !project.guests.is_empty() { - shared_projects += 1; - } - } + shared_projects += 1; } } } Metrics { connections, - registered_projects, - active_projects, shared_projects, } } @@ -162,7 +136,6 @@ impl Store { user_id, admin, projects: Default::default(), - channels: Default::default(), }, ); let connected_user = self.connected_users.entry(user_id).or_default(); @@ -201,18 +174,12 @@ impl Store { .ok_or_else(|| anyhow!("no such connection"))?; let user_id = connection.user_id; - let connection_channels = mem::take(&mut connection.channels); let mut result = RemovedConnectionState { user_id, ..Default::default() }; - // Leave all channels. - for channel_id in connection_channels { - self.leave_channel(connection_id, channel_id); - } - let connected_user = self.connected_users.get(&user_id).unwrap(); if let Some(active_call) = connected_user.active_call.as_ref() { let room_id = active_call.room_id; @@ -238,34 +205,6 @@ impl Store { Ok(result) } - #[cfg(test)] - pub fn channel(&self, id: ChannelId) -> Option<&Channel> { - self.channels.get(&id) - } - - pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) { - if let Some(connection) = self.connections.get_mut(&connection_id) { - connection.channels.insert(channel_id); - self.channels - .entry(channel_id) - .or_default() - .connection_ids - .insert(connection_id); - } - } - - pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) { - if let Some(connection) = self.connections.get_mut(&connection_id) { - connection.channels.remove(&channel_id); - if let btree_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) { - entry.get_mut().connection_ids.remove(&connection_id); - if entry.get_mut().connection_ids.is_empty() { - entry.remove(); - } - } - } - } - pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> Result { Ok(self .connections @@ -760,7 +699,6 @@ impl Store { host: Collaborator { user_id: connection.user_id, replica_id: 0, - last_activity: None, admin: connection.admin, }, guests: Default::default(), @@ -959,12 +897,10 @@ impl Store { Collaborator { replica_id, user_id: connection.user_id, - last_activity: Some(OffsetDateTime::now_utc()), admin: connection.admin, }, ); - project.host.last_activity = Some(OffsetDateTime::now_utc()); Ok((project, replica_id)) } @@ -1056,44 +992,12 @@ impl Store { .connection_ids()) } - pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Result> { - Ok(self - .channels - .get(&channel_id) - .ok_or_else(|| anyhow!("no such channel"))? - .connection_ids()) - } - pub fn project(&self, project_id: ProjectId) -> Result<&Project> { self.projects .get(&project_id) .ok_or_else(|| anyhow!("no such project")) } - pub fn register_project_activity( - &mut self, - project_id: ProjectId, - connection_id: ConnectionId, - ) -> Result<()> { - let project = self - .projects - .get_mut(&project_id) - .ok_or_else(|| anyhow!("no such project"))?; - let collaborator = if connection_id == project.host_connection_id { - &mut project.host - } else if let Some(guest) = project.guests.get_mut(&connection_id) { - guest - } else { - return Err(anyhow!("no such project"))?; - }; - collaborator.last_activity = Some(OffsetDateTime::now_utc()); - Ok(()) - } - - pub fn projects(&self) -> impl Iterator { - self.projects.iter() - } - pub fn read_project( &self, project_id: ProjectId, @@ -1154,10 +1058,7 @@ impl Store { } } } - for channel_id in &connection.channels { - let channel = self.channels.get(channel_id).unwrap(); - assert!(channel.connection_ids.contains(connection_id)); - } + assert!(self .connected_users .get(&connection.user_id) @@ -1253,28 +1154,10 @@ impl Store { "project was not shared in room" ); } - - for (channel_id, channel) in &self.channels { - for connection_id in &channel.connection_ids { - let connection = self.connections.get(connection_id).unwrap(); - assert!(connection.channels.contains(channel_id)); - } - } } } impl Project { - fn is_active_since(&self, start_time: OffsetDateTime) -> bool { - self.guests - .values() - .chain([&self.host]) - .any(|collaborator| { - collaborator - .last_activity - .map_or(false, |active_time| active_time > start_time) - }) - } - pub fn guest_connection_ids(&self) -> Vec { self.guests.keys().copied().collect() } @@ -1287,9 +1170,3 @@ impl Project { .collect() } } - -impl Channel { - fn connection_ids(&self) -> Vec { - self.connection_ids.iter().copied().collect() - } -} diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 74a38599ece2fc139328c893a2af1ec4927f5487..e849632a2df38945fcf34bf8b5967491f19df9e9 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -115,7 +115,6 @@ fn main() { context_menu::init(cx); project::Project::init(&client); - client::Channel::init(&client); client::init(client.clone(), cx); command_palette::init(cx); editor::init(cx);