Detailed changes
@@ -45,7 +45,10 @@ jobs:
- name: Run tests
run: cargo test --workspace --no-fail-fast
- - name: Build collab binaries
+ - name: Build collab
+ run: cargo build -p collab
+
+ - name: Build other binaries
run: cargo build --bins --all-features
bundle:
@@ -1953,6 +1953,18 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7bad48618fdb549078c333a7a8528acb57af271d0433bdecd523eb620628364e"
+[[package]]
+name = "flume"
+version = "0.10.14"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1657b4441c3403d9f7b3409e47575237dac27b1b5726df654a6ecbf92f0f7577"
+dependencies = [
+ "futures-core",
+ "futures-sink",
+ "pin-project",
+ "spin 0.9.4",
+]
+
[[package]]
name = "fnv"
version = "1.0.7"
@@ -3022,7 +3034,7 @@ version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
dependencies = [
- "spin",
+ "spin 0.5.2",
]
[[package]]
@@ -4725,7 +4737,7 @@ dependencies = [
"cc",
"libc",
"once_cell",
- "spin",
+ "spin 0.5.2",
"untrusted",
"web-sys",
"winapi 0.3.9",
@@ -5563,6 +5575,15 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
+[[package]]
+name = "spin"
+version = "0.9.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7f6002a767bff9e83f8eeecf883ecb8011875a21ae8da43bffb817a57e78cc09"
+dependencies = [
+ "lock_api",
+]
+
[[package]]
name = "spsc-buffer"
version = "0.1.1"
@@ -5583,8 +5604,7 @@ dependencies = [
[[package]]
name = "sqlx"
version = "0.6.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9249290c05928352f71c077cc44a464d880c63f26f7534728cca008e135c0428"
+source = "git+https://github.com/launchbadge/sqlx?rev=4b7053807c705df312bcb9b6281e184bf7534eb3#4b7053807c705df312bcb9b6281e184bf7534eb3"
dependencies = [
"sqlx-core",
"sqlx-macros",
@@ -5593,8 +5613,7 @@ dependencies = [
[[package]]
name = "sqlx-core"
version = "0.6.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dcbc16ddba161afc99e14d1713a453747a2b07fc097d2009f4c300ec99286105"
+source = "git+https://github.com/launchbadge/sqlx?rev=4b7053807c705df312bcb9b6281e184bf7534eb3#4b7053807c705df312bcb9b6281e184bf7534eb3"
dependencies = [
"ahash",
"atoi",
@@ -5608,8 +5627,10 @@ dependencies = [
"dotenvy",
"either",
"event-listener",
+ "flume",
"futures-channel",
"futures-core",
+ "futures-executor",
"futures-intrusive",
"futures-util",
"hashlink",
@@ -5619,6 +5640,7 @@ dependencies = [
"indexmap",
"itoa",
"libc",
+ "libsqlite3-sys",
"log",
"md-5",
"memchr",
@@ -5648,8 +5670,7 @@ dependencies = [
[[package]]
name = "sqlx-macros"
version = "0.6.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b850fa514dc11f2ee85be9d055c512aa866746adfacd1cb42d867d68e6a5b0d9"
+source = "git+https://github.com/launchbadge/sqlx?rev=4b7053807c705df312bcb9b6281e184bf7534eb3#4b7053807c705df312bcb9b6281e184bf7534eb3"
dependencies = [
"dotenvy",
"either",
@@ -5657,6 +5678,7 @@ dependencies = [
"once_cell",
"proc-macro2",
"quote",
+ "serde_json",
"sha2 0.10.6",
"sqlx-core",
"sqlx-rt",
@@ -5667,8 +5689,7 @@ dependencies = [
[[package]]
name = "sqlx-rt"
version = "0.6.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "24c5b2d25fa654cc5f841750b8e1cdedbe21189bf9a9382ee90bfa9dd3562396"
+source = "git+https://github.com/launchbadge/sqlx?rev=4b7053807c705df312bcb9b6281e184bf7534eb3#4b7053807c705df312bcb9b6281e184bf7534eb3"
dependencies = [
"once_cell",
"tokio",
@@ -1,820 +0,0 @@
-use super::{
- proto,
- user::{User, UserStore},
- Client, Status, Subscription, TypedEnvelope,
-};
-use anyhow::{anyhow, Context, Result};
-use futures::lock::Mutex;
-use gpui::{
- AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task, WeakModelHandle,
-};
-use postage::prelude::Stream;
-use rand::prelude::*;
-use std::{
- collections::{HashMap, HashSet},
- mem,
- ops::Range,
- sync::Arc,
-};
-use sum_tree::{Bias, SumTree};
-use time::OffsetDateTime;
-use util::{post_inc, ResultExt as _, TryFutureExt};
-
-pub struct ChannelList {
- available_channels: Option<Vec<ChannelDetails>>,
- channels: HashMap<u64, WeakModelHandle<Channel>>,
- client: Arc<Client>,
- user_store: ModelHandle<UserStore>,
- _task: Task<Option<()>>,
-}
-
-#[derive(Clone, Debug, PartialEq)]
-pub struct ChannelDetails {
- pub id: u64,
- pub name: String,
-}
-
-pub struct Channel {
- details: ChannelDetails,
- messages: SumTree<ChannelMessage>,
- loaded_all_messages: bool,
- next_pending_message_id: usize,
- user_store: ModelHandle<UserStore>,
- rpc: Arc<Client>,
- outgoing_messages_lock: Arc<Mutex<()>>,
- rng: StdRng,
- _subscription: Subscription,
-}
-
-#[derive(Clone, Debug)]
-pub struct ChannelMessage {
- pub id: ChannelMessageId,
- pub body: String,
- pub timestamp: OffsetDateTime,
- pub sender: Arc<User>,
- pub nonce: u128,
-}
-
-#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
-pub enum ChannelMessageId {
- Saved(u64),
- Pending(usize),
-}
-
-#[derive(Clone, Debug, Default)]
-pub struct ChannelMessageSummary {
- max_id: ChannelMessageId,
- count: usize,
-}
-
-#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
-struct Count(usize);
-
-pub enum ChannelListEvent {}
-
-#[derive(Clone, Debug, PartialEq)]
-pub enum ChannelEvent {
- MessagesUpdated {
- old_range: Range<usize>,
- new_count: usize,
- },
-}
-
-impl Entity for ChannelList {
- type Event = ChannelListEvent;
-}
-
-impl ChannelList {
- pub fn new(
- user_store: ModelHandle<UserStore>,
- rpc: Arc<Client>,
- cx: &mut ModelContext<Self>,
- ) -> Self {
- let _task = cx.spawn_weak(|this, mut cx| {
- let rpc = rpc.clone();
- async move {
- let mut status = rpc.status();
- while let Some((status, this)) = status.recv().await.zip(this.upgrade(&cx)) {
- match status {
- Status::Connected { .. } => {
- let response = rpc
- .request(proto::GetChannels {})
- .await
- .context("failed to fetch available channels")?;
- this.update(&mut cx, |this, cx| {
- this.available_channels =
- Some(response.channels.into_iter().map(Into::into).collect());
-
- let mut to_remove = Vec::new();
- for (channel_id, channel) in &this.channels {
- if let Some(channel) = channel.upgrade(cx) {
- channel.update(cx, |channel, cx| channel.rejoin(cx))
- } else {
- to_remove.push(*channel_id);
- }
- }
-
- for channel_id in to_remove {
- this.channels.remove(&channel_id);
- }
- cx.notify();
- });
- }
- Status::SignedOut { .. } => {
- this.update(&mut cx, |this, cx| {
- this.available_channels = None;
- this.channels.clear();
- cx.notify();
- });
- }
- _ => {}
- }
- }
- Ok(())
- }
- .log_err()
- });
-
- Self {
- available_channels: None,
- channels: Default::default(),
- user_store,
- client: rpc,
- _task,
- }
- }
-
- pub fn available_channels(&self) -> Option<&[ChannelDetails]> {
- self.available_channels.as_deref()
- }
-
- pub fn get_channel(
- &mut self,
- id: u64,
- cx: &mut MutableAppContext,
- ) -> Option<ModelHandle<Channel>> {
- if let Some(channel) = self.channels.get(&id).and_then(|c| c.upgrade(cx)) {
- return Some(channel);
- }
-
- let channels = self.available_channels.as_ref()?;
- let details = channels.iter().find(|details| details.id == id)?.clone();
- let channel = cx.add_model(|cx| {
- Channel::new(details, self.user_store.clone(), self.client.clone(), cx)
- });
- self.channels.insert(id, channel.downgrade());
- Some(channel)
- }
-}
-
-impl Entity for Channel {
- type Event = ChannelEvent;
-
- fn release(&mut self, _: &mut MutableAppContext) {
- self.rpc
- .send(proto::LeaveChannel {
- channel_id: self.details.id,
- })
- .log_err();
- }
-}
-
-impl Channel {
- pub fn init(rpc: &Arc<Client>) {
- rpc.add_model_message_handler(Self::handle_message_sent);
- }
-
- pub fn new(
- details: ChannelDetails,
- user_store: ModelHandle<UserStore>,
- rpc: Arc<Client>,
- cx: &mut ModelContext<Self>,
- ) -> Self {
- let _subscription = rpc.add_model_for_remote_entity(details.id, cx);
-
- {
- let user_store = user_store.clone();
- let rpc = rpc.clone();
- let channel_id = details.id;
- cx.spawn(|channel, mut cx| {
- async move {
- let response = rpc.request(proto::JoinChannel { channel_id }).await?;
- let messages =
- messages_from_proto(response.messages, &user_store, &mut cx).await?;
- let loaded_all_messages = response.done;
-
- channel.update(&mut cx, |channel, cx| {
- channel.insert_messages(messages, cx);
- channel.loaded_all_messages = loaded_all_messages;
- });
-
- Ok(())
- }
- .log_err()
- })
- .detach();
- }
-
- Self {
- details,
- user_store,
- rpc,
- outgoing_messages_lock: Default::default(),
- messages: Default::default(),
- loaded_all_messages: false,
- next_pending_message_id: 0,
- rng: StdRng::from_entropy(),
- _subscription,
- }
- }
-
- pub fn name(&self) -> &str {
- &self.details.name
- }
-
- pub fn send_message(
- &mut self,
- body: String,
- cx: &mut ModelContext<Self>,
- ) -> Result<Task<Result<()>>> {
- if body.is_empty() {
- Err(anyhow!("message body can't be empty"))?;
- }
-
- let current_user = self
- .user_store
- .read(cx)
- .current_user()
- .ok_or_else(|| anyhow!("current_user is not present"))?;
-
- let channel_id = self.details.id;
- let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id));
- let nonce = self.rng.gen();
- self.insert_messages(
- SumTree::from_item(
- ChannelMessage {
- id: pending_id,
- body: body.clone(),
- sender: current_user,
- timestamp: OffsetDateTime::now_utc(),
- nonce,
- },
- &(),
- ),
- cx,
- );
- let user_store = self.user_store.clone();
- let rpc = self.rpc.clone();
- let outgoing_messages_lock = self.outgoing_messages_lock.clone();
- Ok(cx.spawn(|this, mut cx| async move {
- let outgoing_message_guard = outgoing_messages_lock.lock().await;
- let request = rpc.request(proto::SendChannelMessage {
- channel_id,
- body,
- nonce: Some(nonce.into()),
- });
- let response = request.await?;
- drop(outgoing_message_guard);
- let message = ChannelMessage::from_proto(
- response.message.ok_or_else(|| anyhow!("invalid message"))?,
- &user_store,
- &mut cx,
- )
- .await?;
- this.update(&mut cx, |this, cx| {
- this.insert_messages(SumTree::from_item(message, &()), cx);
- Ok(())
- })
- }))
- }
-
- pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> bool {
- if !self.loaded_all_messages {
- let rpc = self.rpc.clone();
- let user_store = self.user_store.clone();
- let channel_id = self.details.id;
- if let Some(before_message_id) =
- self.messages.first().and_then(|message| match message.id {
- ChannelMessageId::Saved(id) => Some(id),
- ChannelMessageId::Pending(_) => None,
- })
- {
- cx.spawn(|this, mut cx| {
- async move {
- let response = rpc
- .request(proto::GetChannelMessages {
- channel_id,
- before_message_id,
- })
- .await?;
- let loaded_all_messages = response.done;
- let messages =
- messages_from_proto(response.messages, &user_store, &mut cx).await?;
- this.update(&mut cx, |this, cx| {
- this.loaded_all_messages = loaded_all_messages;
- this.insert_messages(messages, cx);
- });
- Ok(())
- }
- .log_err()
- })
- .detach();
- return true;
- }
- }
- false
- }
-
- pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
- let user_store = self.user_store.clone();
- let rpc = self.rpc.clone();
- let channel_id = self.details.id;
- cx.spawn(|this, mut cx| {
- async move {
- let response = rpc.request(proto::JoinChannel { channel_id }).await?;
- let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
- let loaded_all_messages = response.done;
-
- let pending_messages = this.update(&mut cx, |this, cx| {
- if let Some((first_new_message, last_old_message)) =
- messages.first().zip(this.messages.last())
- {
- if first_new_message.id > last_old_message.id {
- let old_messages = mem::take(&mut this.messages);
- cx.emit(ChannelEvent::MessagesUpdated {
- old_range: 0..old_messages.summary().count,
- new_count: 0,
- });
- this.loaded_all_messages = loaded_all_messages;
- }
- }
-
- this.insert_messages(messages, cx);
- if loaded_all_messages {
- this.loaded_all_messages = loaded_all_messages;
- }
-
- this.pending_messages().cloned().collect::<Vec<_>>()
- });
-
- for pending_message in pending_messages {
- let request = rpc.request(proto::SendChannelMessage {
- channel_id,
- body: pending_message.body,
- nonce: Some(pending_message.nonce.into()),
- });
- let response = request.await?;
- let message = ChannelMessage::from_proto(
- response.message.ok_or_else(|| anyhow!("invalid message"))?,
- &user_store,
- &mut cx,
- )
- .await?;
- this.update(&mut cx, |this, cx| {
- this.insert_messages(SumTree::from_item(message, &()), cx);
- });
- }
-
- Ok(())
- }
- .log_err()
- })
- .detach();
- }
-
- pub fn message_count(&self) -> usize {
- self.messages.summary().count
- }
-
- pub fn messages(&self) -> &SumTree<ChannelMessage> {
- &self.messages
- }
-
- pub fn message(&self, ix: usize) -> &ChannelMessage {
- let mut cursor = self.messages.cursor::<Count>();
- cursor.seek(&Count(ix), Bias::Right, &());
- cursor.item().unwrap()
- }
-
- pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> {
- let mut cursor = self.messages.cursor::<Count>();
- cursor.seek(&Count(range.start), Bias::Right, &());
- cursor.take(range.len())
- }
-
- pub fn pending_messages(&self) -> impl Iterator<Item = &ChannelMessage> {
- let mut cursor = self.messages.cursor::<ChannelMessageId>();
- cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &());
- cursor
- }
-
- async fn handle_message_sent(
- this: ModelHandle<Self>,
- message: TypedEnvelope<proto::ChannelMessageSent>,
- _: Arc<Client>,
- mut cx: AsyncAppContext,
- ) -> Result<()> {
- let user_store = this.read_with(&cx, |this, _| this.user_store.clone());
- let message = message
- .payload
- .message
- .ok_or_else(|| anyhow!("empty message"))?;
-
- let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?;
- this.update(&mut cx, |this, cx| {
- this.insert_messages(SumTree::from_item(message, &()), cx)
- });
-
- Ok(())
- }
-
- fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
- if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
- let nonces = messages
- .cursor::<()>()
- .map(|m| m.nonce)
- .collect::<HashSet<_>>();
-
- let mut old_cursor = self.messages.cursor::<(ChannelMessageId, Count)>();
- let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
- let start_ix = old_cursor.start().1 .0;
- let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &());
- let removed_count = removed_messages.summary().count;
- let new_count = messages.summary().count;
- let end_ix = start_ix + removed_count;
-
- new_messages.push_tree(messages, &());
-
- let mut ranges = Vec::<Range<usize>>::new();
- if new_messages.last().unwrap().is_pending() {
- new_messages.push_tree(old_cursor.suffix(&()), &());
- } else {
- new_messages.push_tree(
- old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()),
- &(),
- );
-
- while let Some(message) = old_cursor.item() {
- let message_ix = old_cursor.start().1 .0;
- if nonces.contains(&message.nonce) {
- if ranges.last().map_or(false, |r| r.end == message_ix) {
- ranges.last_mut().unwrap().end += 1;
- } else {
- ranges.push(message_ix..message_ix + 1);
- }
- } else {
- new_messages.push(message.clone(), &());
- }
- old_cursor.next(&());
- }
- }
-
- drop(old_cursor);
- self.messages = new_messages;
-
- for range in ranges.into_iter().rev() {
- cx.emit(ChannelEvent::MessagesUpdated {
- old_range: range,
- new_count: 0,
- });
- }
- cx.emit(ChannelEvent::MessagesUpdated {
- old_range: start_ix..end_ix,
- new_count,
- });
- cx.notify();
- }
- }
-}
-
-async fn messages_from_proto(
- proto_messages: Vec<proto::ChannelMessage>,
- user_store: &ModelHandle<UserStore>,
- cx: &mut AsyncAppContext,
-) -> Result<SumTree<ChannelMessage>> {
- let unique_user_ids = proto_messages
- .iter()
- .map(|m| m.sender_id)
- .collect::<HashSet<_>>()
- .into_iter()
- .collect();
- user_store
- .update(cx, |user_store, cx| {
- user_store.get_users(unique_user_ids, cx)
- })
- .await?;
-
- let mut messages = Vec::with_capacity(proto_messages.len());
- for message in proto_messages {
- messages.push(ChannelMessage::from_proto(message, user_store, cx).await?);
- }
- let mut result = SumTree::new();
- result.extend(messages, &());
- Ok(result)
-}
-
-impl From<proto::Channel> for ChannelDetails {
- fn from(message: proto::Channel) -> Self {
- Self {
- id: message.id,
- name: message.name,
- }
- }
-}
-
-impl ChannelMessage {
- pub async fn from_proto(
- message: proto::ChannelMessage,
- user_store: &ModelHandle<UserStore>,
- cx: &mut AsyncAppContext,
- ) -> Result<Self> {
- let sender = user_store
- .update(cx, |user_store, cx| {
- user_store.get_user(message.sender_id, cx)
- })
- .await?;
- Ok(ChannelMessage {
- id: ChannelMessageId::Saved(message.id),
- body: message.body,
- timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?,
- sender,
- nonce: message
- .nonce
- .ok_or_else(|| anyhow!("nonce is required"))?
- .into(),
- })
- }
-
- pub fn is_pending(&self) -> bool {
- matches!(self.id, ChannelMessageId::Pending(_))
- }
-}
-
-impl sum_tree::Item for ChannelMessage {
- type Summary = ChannelMessageSummary;
-
- fn summary(&self) -> Self::Summary {
- ChannelMessageSummary {
- max_id: self.id,
- count: 1,
- }
- }
-}
-
-impl Default for ChannelMessageId {
- fn default() -> Self {
- Self::Saved(0)
- }
-}
-
-impl sum_tree::Summary for ChannelMessageSummary {
- type Context = ();
-
- fn add_summary(&mut self, summary: &Self, _: &()) {
- self.max_id = summary.max_id;
- self.count += summary.count;
- }
-}
-
-impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for ChannelMessageId {
- fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
- debug_assert!(summary.max_id > *self);
- *self = summary.max_id;
- }
-}
-
-impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count {
- fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
- self.0 += summary.count;
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use crate::test::{FakeHttpClient, FakeServer};
- use gpui::TestAppContext;
-
- #[gpui::test]
- async fn test_channel_messages(cx: &mut TestAppContext) {
- cx.foreground().forbid_parking();
-
- let user_id = 5;
- let http_client = FakeHttpClient::with_404_response();
- let client = cx.update(|cx| Client::new(http_client.clone(), cx));
- let server = FakeServer::for_client(user_id, &client, cx).await;
-
- Channel::init(&client);
- let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http_client, cx));
-
- let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx));
- channel_list.read_with(cx, |list, _| assert_eq!(list.available_channels(), None));
-
- // Get the available channels.
- let get_channels = server.receive::<proto::GetChannels>().await.unwrap();
- server
- .respond(
- get_channels.receipt(),
- proto::GetChannelsResponse {
- channels: vec![proto::Channel {
- id: 5,
- name: "the-channel".to_string(),
- }],
- },
- )
- .await;
- channel_list.next_notification(cx).await;
- channel_list.read_with(cx, |list, _| {
- assert_eq!(
- list.available_channels().unwrap(),
- &[ChannelDetails {
- id: 5,
- name: "the-channel".into(),
- }]
- )
- });
-
- let get_users = server.receive::<proto::GetUsers>().await.unwrap();
- assert_eq!(get_users.payload.user_ids, vec![5]);
- server
- .respond(
- get_users.receipt(),
- proto::UsersResponse {
- users: vec![proto::User {
- id: 5,
- github_login: "nathansobo".into(),
- avatar_url: "http://avatar.com/nathansobo".into(),
- }],
- },
- )
- .await;
-
- // Join a channel and populate its existing messages.
- let channel = channel_list
- .update(cx, |list, cx| {
- let channel_id = list.available_channels().unwrap()[0].id;
- list.get_channel(channel_id, cx)
- })
- .unwrap();
- channel.read_with(cx, |channel, _| assert!(channel.messages().is_empty()));
- let join_channel = server.receive::<proto::JoinChannel>().await.unwrap();
- server
- .respond(
- join_channel.receipt(),
- proto::JoinChannelResponse {
- messages: vec![
- proto::ChannelMessage {
- id: 10,
- body: "a".into(),
- timestamp: 1000,
- sender_id: 5,
- nonce: Some(1.into()),
- },
- proto::ChannelMessage {
- id: 11,
- body: "b".into(),
- timestamp: 1001,
- sender_id: 6,
- nonce: Some(2.into()),
- },
- ],
- done: false,
- },
- )
- .await;
-
- // Client requests all users for the received messages
- let mut get_users = server.receive::<proto::GetUsers>().await.unwrap();
- get_users.payload.user_ids.sort();
- assert_eq!(get_users.payload.user_ids, vec![6]);
- server
- .respond(
- get_users.receipt(),
- proto::UsersResponse {
- users: vec![proto::User {
- id: 6,
- github_login: "maxbrunsfeld".into(),
- avatar_url: "http://avatar.com/maxbrunsfeld".into(),
- }],
- },
- )
- .await;
-
- assert_eq!(
- channel.next_event(cx).await,
- ChannelEvent::MessagesUpdated {
- old_range: 0..0,
- new_count: 2,
- }
- );
- channel.read_with(cx, |channel, _| {
- assert_eq!(
- channel
- .messages_in_range(0..2)
- .map(|message| (message.sender.github_login.clone(), message.body.clone()))
- .collect::<Vec<_>>(),
- &[
- ("nathansobo".into(), "a".into()),
- ("maxbrunsfeld".into(), "b".into())
- ]
- );
- });
-
- // Receive a new message.
- server.send(proto::ChannelMessageSent {
- channel_id: channel.read_with(cx, |channel, _| channel.details.id),
- message: Some(proto::ChannelMessage {
- id: 12,
- body: "c".into(),
- timestamp: 1002,
- sender_id: 7,
- nonce: Some(3.into()),
- }),
- });
-
- // Client requests user for message since they haven't seen them yet
- let get_users = server.receive::<proto::GetUsers>().await.unwrap();
- assert_eq!(get_users.payload.user_ids, vec![7]);
- server
- .respond(
- get_users.receipt(),
- proto::UsersResponse {
- users: vec![proto::User {
- id: 7,
- github_login: "as-cii".into(),
- avatar_url: "http://avatar.com/as-cii".into(),
- }],
- },
- )
- .await;
-
- assert_eq!(
- channel.next_event(cx).await,
- ChannelEvent::MessagesUpdated {
- old_range: 2..2,
- new_count: 1,
- }
- );
- channel.read_with(cx, |channel, _| {
- assert_eq!(
- channel
- .messages_in_range(2..3)
- .map(|message| (message.sender.github_login.clone(), message.body.clone()))
- .collect::<Vec<_>>(),
- &[("as-cii".into(), "c".into())]
- )
- });
-
- // Scroll up to view older messages.
- channel.update(cx, |channel, cx| {
- assert!(channel.load_more_messages(cx));
- });
- let get_messages = server.receive::<proto::GetChannelMessages>().await.unwrap();
- assert_eq!(get_messages.payload.channel_id, 5);
- assert_eq!(get_messages.payload.before_message_id, 10);
- server
- .respond(
- get_messages.receipt(),
- proto::GetChannelMessagesResponse {
- done: true,
- messages: vec![
- proto::ChannelMessage {
- id: 8,
- body: "y".into(),
- timestamp: 998,
- sender_id: 5,
- nonce: Some(4.into()),
- },
- proto::ChannelMessage {
- id: 9,
- body: "z".into(),
- timestamp: 999,
- sender_id: 6,
- nonce: Some(5.into()),
- },
- ],
- },
- )
- .await;
-
- assert_eq!(
- channel.next_event(cx).await,
- ChannelEvent::MessagesUpdated {
- old_range: 0..0,
- new_count: 2,
- }
- );
- channel.read_with(cx, |channel, _| {
- assert_eq!(
- channel
- .messages_in_range(0..2)
- .map(|message| (message.sender.github_login.clone(), message.body.clone()))
- .collect::<Vec<_>>(),
- &[
- ("nathansobo".into(), "y".into()),
- ("maxbrunsfeld".into(), "z".into())
- ]
- );
- });
- }
-}
@@ -1,7 +1,6 @@
#[cfg(any(test, feature = "test-support"))]
pub mod test;
-pub mod channel;
pub mod http;
pub mod telemetry;
pub mod user;
@@ -44,7 +43,6 @@ use thiserror::Error;
use url::Url;
use util::{ResultExt, TryFutureExt};
-pub use channel::*;
pub use rpc::*;
pub use user::*;
@@ -50,8 +50,9 @@ tracing-log = "0.1.3"
tracing-subscriber = { version = "0.3.11", features = ["env-filter", "json"] }
[dependencies.sqlx]
-version = "0.6"
-features = ["runtime-tokio-rustls", "postgres", "time", "uuid"]
+git = "https://github.com/launchbadge/sqlx"
+rev = "4b7053807c705df312bcb9b6281e184bf7534eb3"
+features = ["runtime-tokio-rustls", "postgres", "json", "time", "uuid"]
[dev-dependencies]
collections = { path = "../collections", features = ["test-support"] }
@@ -78,5 +79,10 @@ lazy_static = "1.4"
serde_json = { version = "1.0", features = ["preserve_order"] }
unindent = "0.1"
+[dev-dependencies.sqlx]
+git = "https://github.com/launchbadge/sqlx"
+rev = "4b7053807c705df312bcb9b6281e184bf7534eb3"
+features = ["sqlite"]
+
[features]
seed-support = ["clap", "lipsum", "reqwest"]
@@ -0,0 +1,41 @@
+CREATE TABLE IF NOT EXISTS "users" (
+ "id" INTEGER PRIMARY KEY,
+ "github_login" VARCHAR,
+ "admin" BOOLEAN,
+ "email_address" VARCHAR(255) DEFAULT NULL,
+ "invite_code" VARCHAR(64),
+ "invite_count" INTEGER NOT NULL DEFAULT 0,
+ "inviter_id" INTEGER REFERENCES users (id),
+ "connected_once" BOOLEAN NOT NULL DEFAULT false,
+    "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ "metrics_id" VARCHAR(255),
+ "github_user_id" INTEGER
+);
+CREATE UNIQUE INDEX "index_users_github_login" ON "users" ("github_login");
+CREATE UNIQUE INDEX "index_invite_code_users" ON "users" ("invite_code");
+CREATE INDEX "index_users_on_email_address" ON "users" ("email_address");
+CREATE INDEX "index_users_on_github_user_id" ON "users" ("github_user_id");
+
+CREATE TABLE IF NOT EXISTS "access_tokens" (
+ "id" INTEGER PRIMARY KEY,
+ "user_id" INTEGER REFERENCES users (id),
+ "hash" VARCHAR(128)
+);
+CREATE INDEX "index_access_tokens_user_id" ON "access_tokens" ("user_id");
+
+CREATE TABLE IF NOT EXISTS "contacts" (
+ "id" INTEGER PRIMARY KEY,
+ "user_id_a" INTEGER REFERENCES users (id) NOT NULL,
+ "user_id_b" INTEGER REFERENCES users (id) NOT NULL,
+ "a_to_b" BOOLEAN NOT NULL,
+ "should_notify" BOOLEAN NOT NULL,
+ "accepted" BOOLEAN NOT NULL
+);
+CREATE UNIQUE INDEX "index_contacts_user_ids" ON "contacts" ("user_id_a", "user_id_b");
+CREATE INDEX "index_contacts_user_id_b" ON "contacts" ("user_id_b");
+
+CREATE TABLE IF NOT EXISTS "projects" (
+ "id" INTEGER PRIMARY KEY,
+ "host_user_id" INTEGER REFERENCES users (id) NOT NULL,
+ "unregistered" BOOLEAN NOT NULL DEFAULT false
+);
@@ -1,6 +1,6 @@
use crate::{
auth,
- db::{Invite, NewUserParams, ProjectId, Signup, User, UserId, WaitlistSummary},
+ db::{Invite, NewUserParams, Signup, User, UserId, WaitlistSummary},
rpc::{self, ResultExt},
AppState, Error, Result,
};
@@ -16,9 +16,7 @@ use axum::{
};
use axum_extra::response::ErasedJson;
use serde::{Deserialize, Serialize};
-use serde_json::json;
-use std::{sync::Arc, time::Duration};
-use time::OffsetDateTime;
+use std::sync::Arc;
use tower::ServiceBuilder;
use tracing::instrument;
@@ -32,16 +30,6 @@ pub fn routes(rpc_server: Arc<rpc::Server>, state: Arc<AppState>) -> Router<Body
.route("/invite_codes/:code", get(get_user_for_invite_code))
.route("/panic", post(trace_panic))
.route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
- .route(
- "/user_activity/summary",
- get(get_top_users_activity_summary),
- )
- .route(
- "/user_activity/timeline/:user_id",
- get(get_user_activity_timeline),
- )
- .route("/user_activity/counts", get(get_active_user_counts))
- .route("/project_metadata", get(get_project_metadata))
.route("/signups", post(create_signup))
.route("/signups_summary", get(get_waitlist_summary))
.route("/user_invites", post(create_invite_from_code))
@@ -283,93 +271,6 @@ async fn get_rpc_server_snapshot(
Ok(ErasedJson::pretty(rpc_server.snapshot().await))
}
-#[derive(Deserialize)]
-struct TimePeriodParams {
- #[serde(with = "time::serde::iso8601")]
- start: OffsetDateTime,
- #[serde(with = "time::serde::iso8601")]
- end: OffsetDateTime,
-}
-
-async fn get_top_users_activity_summary(
- Query(params): Query<TimePeriodParams>,
- Extension(app): Extension<Arc<AppState>>,
-) -> Result<ErasedJson> {
- let summary = app
- .db
- .get_top_users_activity_summary(params.start..params.end, 100)
- .await?;
- Ok(ErasedJson::pretty(summary))
-}
-
-async fn get_user_activity_timeline(
- Path(user_id): Path<i32>,
- Query(params): Query<TimePeriodParams>,
- Extension(app): Extension<Arc<AppState>>,
-) -> Result<ErasedJson> {
- let summary = app
- .db
- .get_user_activity_timeline(params.start..params.end, UserId(user_id))
- .await?;
- Ok(ErasedJson::pretty(summary))
-}
-
-#[derive(Deserialize)]
-struct ActiveUserCountParams {
- #[serde(flatten)]
- period: TimePeriodParams,
- durations_in_minutes: String,
- #[serde(default)]
- only_collaborative: bool,
-}
-
-#[derive(Serialize)]
-struct ActiveUserSet {
- active_time_in_minutes: u64,
- user_count: usize,
-}
-
-async fn get_active_user_counts(
- Query(params): Query<ActiveUserCountParams>,
- Extension(app): Extension<Arc<AppState>>,
-) -> Result<ErasedJson> {
- let durations_in_minutes = params.durations_in_minutes.split(',');
- let mut user_sets = Vec::new();
- for duration in durations_in_minutes {
- let duration = duration
- .parse()
- .map_err(|_| anyhow!("invalid duration: {duration}"))?;
- user_sets.push(ActiveUserSet {
- active_time_in_minutes: duration,
- user_count: app
- .db
- .get_active_user_count(
- params.period.start..params.period.end,
- Duration::from_secs(duration * 60),
- params.only_collaborative,
- )
- .await?,
- })
- }
- Ok(ErasedJson::pretty(user_sets))
-}
-
-#[derive(Deserialize)]
-struct GetProjectMetadataParams {
- project_id: u64,
-}
-
-async fn get_project_metadata(
- Query(params): Query<GetProjectMetadataParams>,
- Extension(app): Extension<Arc<AppState>>,
-) -> Result<ErasedJson> {
- let extensions = app
- .db
- .get_project_extensions(ProjectId::from_proto(params.project_id))
- .await?;
- Ok(ErasedJson::pretty(json!({ "extensions": extensions })))
-}
-
#[derive(Deserialize)]
struct CreateAccessTokenQueryParams {
public_key: String,
@@ -75,7 +75,7 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
-pub async fn create_access_token(db: &dyn db::Db, user_id: UserId) -> Result<String> {
+pub async fn create_access_token(db: &db::DefaultDb, user_id: UserId) -> Result<String> {
let access_token = rpc::auth::random_token();
let access_token_hash =
hash_access_token(&access_token).context("failed to hash access token")?;
@@ -1,1609 +1,1192 @@
use crate::{Error, Result};
-use anyhow::{anyhow, Context};
-use async_trait::async_trait;
+use anyhow::anyhow;
use axum::http::StatusCode;
use collections::HashMap;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
-pub use sqlx::postgres::PgPoolOptions as DbOptions;
use sqlx::{
migrate::{Migrate as _, Migration, MigrationSource},
types::Uuid,
- FromRow, QueryBuilder,
+ FromRow,
};
-use std::{cmp, ops::Range, path::Path, time::Duration};
+use std::{path::Path, time::Duration};
use time::{OffsetDateTime, PrimitiveDateTime};
-#[async_trait]
-pub trait Db: Send + Sync {
- async fn create_user(
- &self,
- email_address: &str,
- admin: bool,
- params: NewUserParams,
- ) -> Result<NewUserResult>;
- async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
- async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
- async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
- async fn get_user_metrics_id(&self, id: UserId) -> Result<String>;
- async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
- async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
- async fn get_user_by_github_account(
- &self,
- github_login: &str,
- github_user_id: Option<i32>,
- ) -> Result<Option<User>>;
- async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
- async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
- async fn destroy_user(&self, id: UserId) -> Result<()>;
-
- async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()>;
- async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
- async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
- async fn create_invite_from_code(
- &self,
- code: &str,
- email_address: &str,
- device_id: Option<&str>,
- ) -> Result<Invite>;
-
- async fn create_signup(&self, signup: Signup) -> Result<()>;
- async fn get_waitlist_summary(&self) -> Result<WaitlistSummary>;
- async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>>;
- async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()>;
- async fn create_user_from_invite(
- &self,
- invite: &Invite,
- user: NewUserParams,
- ) -> Result<Option<NewUserResult>>;
-
- /// Registers a new project for the given user.
- async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;
-
- /// Unregisters a project for the given project id.
- async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;
-
- /// Update file counts by extension for the given project and worktree.
- async fn update_worktree_extensions(
- &self,
- project_id: ProjectId,
- worktree_id: u64,
- extensions: HashMap<String, u32>,
- ) -> Result<()>;
-
- /// Get the file counts on the given project keyed by their worktree and extension.
- async fn get_project_extensions(
- &self,
- project_id: ProjectId,
- ) -> Result<HashMap<u64, HashMap<String, usize>>>;
-
- /// Record which users have been active in which projects during
- /// a given period of time.
- async fn record_user_activity(
- &self,
- time_period: Range<OffsetDateTime>,
- active_projects: &[(UserId, ProjectId)],
- ) -> Result<()>;
+#[cfg(test)]
+pub type DefaultDb = Db<sqlx::Sqlite>;
- /// Get the number of users who have been active in the given
- /// time period for at least the given time duration.
- async fn get_active_user_count(
- &self,
- time_period: Range<OffsetDateTime>,
- min_duration: Duration,
- only_collaborative: bool,
- ) -> Result<usize>;
-
- /// Get the users that have been most active during the given time period,
- /// along with the amount of time they have been active in each project.
- async fn get_top_users_activity_summary(
- &self,
- time_period: Range<OffsetDateTime>,
- max_user_count: usize,
- ) -> Result<Vec<UserActivitySummary>>;
+#[cfg(not(test))]
+pub type DefaultDb = Db<sqlx::Postgres>;
- /// Get the project activity for the given user and time period.
- async fn get_user_activity_timeline(
- &self,
- time_period: Range<OffsetDateTime>,
- user_id: UserId,
- ) -> Result<Vec<UserActivityPeriod>>;
+pub struct Db<D: sqlx::Database> {
+ pool: sqlx::Pool<D>,
+ #[cfg(test)]
+ background: Option<std::sync::Arc<gpui::executor::Background>>,
+ #[cfg(test)]
+ runtime: Option<tokio::runtime::Runtime>,
+}
- async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
- async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
- async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
- async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
- async fn dismiss_contact_notification(
- &self,
- responder_id: UserId,
- requester_id: UserId,
- ) -> Result<()>;
- async fn respond_to_contact_request(
- &self,
- responder_id: UserId,
- requester_id: UserId,
- accept: bool,
- ) -> Result<()>;
+macro_rules! test_support {
+ ($self:ident, { $($token:tt)* }) => {{
+ let body = async {
+ $($token)*
+ };
- async fn create_access_token_hash(
- &self,
- user_id: UserId,
- access_token_hash: &str,
- max_access_token_count: usize,
- ) -> Result<()>;
- async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
-
- #[cfg(any(test, feature = "seed-support"))]
- async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
- #[cfg(any(test, feature = "seed-support"))]
- async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
- #[cfg(any(test, feature = "seed-support"))]
- async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
- #[cfg(any(test, feature = "seed-support"))]
- async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
- #[cfg(any(test, feature = "seed-support"))]
-
- async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
- async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
- async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
- -> Result<bool>;
-
- #[cfg(any(test, feature = "seed-support"))]
- async fn add_channel_member(
- &self,
- channel_id: ChannelId,
- user_id: UserId,
- is_admin: bool,
- ) -> Result<()>;
- async fn create_channel_message(
- &self,
- channel_id: ChannelId,
- sender_id: UserId,
- body: &str,
- timestamp: OffsetDateTime,
- nonce: u128,
- ) -> Result<MessageId>;
- async fn get_channel_messages(
- &self,
- channel_id: ChannelId,
- count: usize,
- before_id: Option<MessageId>,
- ) -> Result<Vec<ChannelMessage>>;
+ if cfg!(test) {
+ #[cfg(not(test))]
+ unreachable!();
- #[cfg(test)]
- async fn teardown(&self, url: &str);
+ #[cfg(test)]
+ if let Some(background) = $self.background.as_ref() {
+ background.simulate_random_delay().await;
+ }
- #[cfg(test)]
- fn as_fake(&self) -> Option<&FakeDb>;
+ #[cfg(test)]
+ $self.runtime.as_ref().unwrap().block_on(body)
+ } else {
+ body.await
+ }
+ }};
}
-#[cfg(any(test, debug_assertions))]
-pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> =
- Some(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
+pub trait RowsAffected {
+ fn rows_affected(&self) -> u64;
+}
-#[cfg(not(any(test, debug_assertions)))]
-pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> = None;
+#[cfg(test)]
+impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
+ fn rows_affected(&self) -> u64 {
+ self.rows_affected()
+ }
+}
-pub struct PostgresDb {
- pool: sqlx::PgPool,
+impl RowsAffected for sqlx::postgres::PgQueryResult {
+ fn rows_affected(&self) -> u64 {
+ self.rows_affected()
+ }
}
-impl PostgresDb {
+#[cfg(test)]
+impl Db<sqlx::Sqlite> {
pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
- let pool = DbOptions::new()
+ use std::str::FromStr as _;
+ let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)
+ .unwrap()
+ .create_if_missing(true)
+ .shared_cache(true);
+ let pool = sqlx::sqlite::SqlitePoolOptions::new()
+ .min_connections(2)
.max_connections(max_connections)
- .connect(url)
- .await
- .context("failed to connect to postgres database")?;
- Ok(Self { pool })
+ .connect_with(options)
+ .await?;
+ Ok(Self {
+ pool,
+ background: None,
+ runtime: None,
+ })
}
- pub async fn migrate(
- &self,
- migrations_path: &Path,
- ignore_checksum_mismatch: bool,
- ) -> anyhow::Result<Vec<(Migration, Duration)>> {
- let migrations = MigrationSource::resolve(migrations_path)
- .await
- .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
-
- let mut conn = self.pool.acquire().await?;
-
- conn.ensure_migrations_table().await?;
- let applied_migrations: HashMap<_, _> = conn
- .list_applied_migrations()
- .await?
- .into_iter()
- .map(|m| (m.version, m))
- .collect();
-
- let mut new_migrations = Vec::new();
- for migration in migrations {
- match applied_migrations.get(&migration.version) {
- Some(applied_migration) => {
- if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
- {
- Err(anyhow!(
- "checksum mismatch for applied migration {}",
- migration.description
- ))?;
- }
- }
- None => {
- let elapsed = conn.apply(&migration).await?;
- new_migrations.push((migration, elapsed));
- }
- }
- }
-
- Ok(new_migrations)
+ pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
+ test_support!(self, {
+ let query = "
+ SELECT users.*
+ FROM users
+ WHERE users.id IN (SELECT value from json_each($1))
+ ";
+ Ok(sqlx::query_as(query)
+ .bind(&serde_json::json!(ids))
+ .fetch_all(&self.pool)
+ .await?)
+ })
}
- pub fn fuzzy_like_string(string: &str) -> String {
- let mut result = String::with_capacity(string.len() * 2 + 1);
- for c in string.chars() {
- if c.is_alphanumeric() {
- result.push('%');
- result.push(c);
- }
- }
- result.push('%');
- result
+ pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
+ test_support!(self, {
+ let query = "
+ SELECT metrics_id
+ FROM users
+ WHERE id = $1
+ ";
+ Ok(sqlx::query_scalar(query)
+ .bind(id)
+ .fetch_one(&self.pool)
+ .await?)
+ })
}
-}
-
-#[async_trait]
-impl Db for PostgresDb {
- // users
- async fn create_user(
+ pub async fn create_user(
&self,
email_address: &str,
admin: bool,
params: NewUserParams,
) -> Result<NewUserResult> {
- let query = "
- INSERT INTO users (email_address, github_login, github_user_id, admin)
- VALUES ($1, $2, $3, $4)
- ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
- RETURNING id, metrics_id::text
- ";
- let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
- .bind(email_address)
- .bind(params.github_login)
- .bind(params.github_user_id)
- .bind(admin)
- .fetch_one(&self.pool)
- .await?;
- Ok(NewUserResult {
- user_id,
- metrics_id,
- signup_device_id: None,
- inviting_user_id: None,
- })
- }
+ test_support!(self, {
+ let query = "
+ INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
+ VALUES ($1, $2, $3, $4, $5)
+ ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
+ RETURNING id, metrics_id
+ ";
- async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
- let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
- Ok(sqlx::query_as(query)
- .bind(limit as i32)
- .bind((page * limit) as i32)
- .fetch_all(&self.pool)
- .await?)
+ let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
+ .bind(email_address)
+ .bind(params.github_login)
+ .bind(params.github_user_id)
+ .bind(admin)
+ .bind(Uuid::new_v4().to_string())
+ .fetch_one(&self.pool)
+ .await?;
+ Ok(NewUserResult {
+ user_id,
+ metrics_id,
+ signup_device_id: None,
+ inviting_user_id: None,
+ })
+ })
}
- async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
- let like_string = Self::fuzzy_like_string(name_query);
- let query = "
- SELECT users.*
- FROM users
- WHERE github_login ILIKE $1
- ORDER BY github_login <-> $2
- LIMIT $3
- ";
- Ok(sqlx::query_as(query)
- .bind(like_string)
- .bind(name_query)
- .bind(limit as i32)
- .fetch_all(&self.pool)
- .await?)
+ pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
+ unimplemented!()
}
- async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
- let users = self.get_users_by_ids(vec![id]).await?;
- Ok(users.into_iter().next())
+ pub async fn create_user_from_invite(
+ &self,
+ _invite: &Invite,
+ _user: NewUserParams,
+ ) -> Result<Option<NewUserResult>> {
+ unimplemented!()
}
- async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
- let query = "
- SELECT metrics_id::text
- FROM users
- WHERE id = $1
- ";
- Ok(sqlx::query_scalar(query)
- .bind(id)
- .fetch_one(&self.pool)
- .await?)
+ pub async fn create_signup(&self, _signup: Signup) -> Result<()> {
+ unimplemented!()
}
- async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
- let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
- let query = "
- SELECT users.*
- FROM users
- WHERE users.id = ANY ($1)
- ";
- Ok(sqlx::query_as(query)
- .bind(&ids)
- .fetch_all(&self.pool)
- .await?)
+ pub async fn create_invite_from_code(
+ &self,
+ _code: &str,
+ _email_address: &str,
+ _device_id: Option<&str>,
+ ) -> Result<Invite> {
+ unimplemented!()
}
- async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
- let query = format!(
- "
- SELECT users.*
- FROM users
- WHERE invite_count = 0
- AND inviter_id IS{} NULL
- ",
- if invited_by_another_user { " NOT" } else { "" }
- );
-
- Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
+ pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
+ unimplemented!()
}
+}
- async fn get_user_by_github_account(
- &self,
- github_login: &str,
- github_user_id: Option<i32>,
- ) -> Result<Option<User>> {
- if let Some(github_user_id) = github_user_id {
- let mut user = sqlx::query_as::<_, User>(
- "
- UPDATE users
- SET github_login = $1
- WHERE github_user_id = $2
- RETURNING *
- ",
- )
- .bind(github_login)
- .bind(github_user_id)
- .fetch_optional(&self.pool)
+impl Db<sqlx::Postgres> {
+ pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
+ let pool = sqlx::postgres::PgPoolOptions::new()
+ .max_connections(max_connections)
+ .connect(url)
.await?;
-
- if user.is_none() {
- user = sqlx::query_as::<_, User>(
- "
- UPDATE users
- SET github_user_id = $1
- WHERE github_login = $2
- RETURNING *
- ",
- )
- .bind(github_user_id)
- .bind(github_login)
- .fetch_optional(&self.pool)
- .await?;
- }
-
- Ok(user)
- } else {
- Ok(sqlx::query_as(
- "
- SELECT * FROM users
- WHERE github_login = $1
- LIMIT 1
- ",
- )
- .bind(github_login)
- .fetch_optional(&self.pool)
- .await?)
- }
- }
-
- async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
- let query = "UPDATE users SET admin = $1 WHERE id = $2";
- Ok(sqlx::query(query)
- .bind(is_admin)
- .bind(id.0)
- .execute(&self.pool)
- .await
- .map(drop)?)
+ Ok(Self {
+ pool,
+ #[cfg(test)]
+ background: None,
+ #[cfg(test)]
+ runtime: None,
+ })
}
- async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
- let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
- Ok(sqlx::query(query)
- .bind(connected_once)
- .bind(id.0)
- .execute(&self.pool)
- .await
- .map(drop)?)
+ #[cfg(test)]
+ pub fn teardown(&self, url: &str) {
+ self.runtime.as_ref().unwrap().block_on(async {
+ use util::ResultExt;
+ let query = "
+ SELECT pg_terminate_backend(pg_stat_activity.pid)
+ FROM pg_stat_activity
+ WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
+ ";
+ sqlx::query(query).execute(&self.pool).await.log_err();
+ self.pool.close().await;
+            <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
+ .await
+ .log_err();
+ })
}
- async fn destroy_user(&self, id: UserId) -> Result<()> {
- let query = "DELETE FROM access_tokens WHERE user_id = $1;";
- sqlx::query(query)
- .bind(id.0)
- .execute(&self.pool)
- .await
- .map(drop)?;
- let query = "DELETE FROM users WHERE id = $1;";
- Ok(sqlx::query(query)
- .bind(id.0)
- .execute(&self.pool)
- .await
- .map(drop)?)
+ pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
+ test_support!(self, {
+ let like_string = Self::fuzzy_like_string(name_query);
+ let query = "
+ SELECT users.*
+ FROM users
+ WHERE github_login ILIKE $1
+ ORDER BY github_login <-> $2
+ LIMIT $3
+ ";
+ Ok(sqlx::query_as(query)
+ .bind(like_string)
+ .bind(name_query)
+ .bind(limit as i32)
+ .fetch_all(&self.pool)
+ .await?)
+ })
}
- // signups
-
- async fn create_signup(&self, signup: Signup) -> Result<()> {
- sqlx::query(
- "
- INSERT INTO signups
- (
- email_address,
- email_confirmation_code,
- email_confirmation_sent,
- platform_linux,
- platform_mac,
- platform_windows,
- platform_unknown,
- editor_features,
- programming_languages,
- device_id
- )
- VALUES
- ($1, $2, 'f', $3, $4, $5, 'f', $6, $7, $8)
- RETURNING id
- ",
- )
- .bind(&signup.email_address)
- .bind(&random_email_confirmation_code())
- .bind(&signup.platform_linux)
- .bind(&signup.platform_mac)
- .bind(&signup.platform_windows)
- .bind(&signup.editor_features)
- .bind(&signup.programming_languages)
- .bind(&signup.device_id)
- .execute(&self.pool)
- .await?;
- Ok(())
+ pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
+ test_support!(self, {
+ let query = "
+ SELECT users.*
+ FROM users
+ WHERE users.id = ANY ($1)
+ ";
+ Ok(sqlx::query_as(query)
+ .bind(&ids.into_iter().map(|id| id.0).collect::<Vec<_>>())
+ .fetch_all(&self.pool)
+ .await?)
+ })
}
- async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
- Ok(sqlx::query_as(
- "
- SELECT
- COUNT(*) as count,
- COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
- COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
- COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
- COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
- FROM (
- SELECT *
- FROM signups
- WHERE
- NOT email_confirmation_sent
- ) AS unsent
- ",
- )
- .fetch_one(&self.pool)
- .await?)
+ pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
+ test_support!(self, {
+ let query = "
+ SELECT metrics_id::text
+ FROM users
+ WHERE id = $1
+ ";
+ Ok(sqlx::query_scalar(query)
+ .bind(id)
+ .fetch_one(&self.pool)
+ .await?)
+ })
}
- async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
- Ok(sqlx::query_as(
- "
- SELECT
- email_address, email_confirmation_code
- FROM signups
- WHERE
- NOT email_confirmation_sent AND
- (platform_mac OR platform_unknown)
- LIMIT $1
- ",
- )
- .bind(count as i32)
- .fetch_all(&self.pool)
- .await?)
- }
+ pub async fn create_user(
+ &self,
+ email_address: &str,
+ admin: bool,
+ params: NewUserParams,
+ ) -> Result<NewUserResult> {
+ test_support!(self, {
+ let query = "
+ INSERT INTO users (email_address, github_login, github_user_id, admin)
+ VALUES ($1, $2, $3, $4)
+ ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
+ RETURNING id, metrics_id::text
+ ";
- async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
- sqlx::query(
- "
- UPDATE signups
- SET email_confirmation_sent = 't'
- WHERE email_address = ANY ($1)
- ",
- )
- .bind(
- &invites
- .iter()
- .map(|s| s.email_address.as_str())
- .collect::<Vec<_>>(),
- )
- .execute(&self.pool)
- .await?;
- Ok(())
+ let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
+ .bind(email_address)
+ .bind(params.github_login)
+ .bind(params.github_user_id)
+ .bind(admin)
+ .fetch_one(&self.pool)
+ .await?;
+ Ok(NewUserResult {
+ user_id,
+ metrics_id,
+ signup_device_id: None,
+ inviting_user_id: None,
+ })
+ })
}
- async fn create_user_from_invite(
+ pub async fn create_user_from_invite(
&self,
invite: &Invite,
user: NewUserParams,
) -> Result<Option<NewUserResult>> {
- let mut tx = self.pool.begin().await?;
-
- let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
- i32,
- Option<UserId>,
- Option<UserId>,
- Option<String>,
- ) = sqlx::query_as(
- "
- SELECT id, user_id, inviting_user_id, device_id
- FROM signups
- WHERE
- email_address = $1 AND
- email_confirmation_code = $2
- ",
- )
- .bind(&invite.email_address)
- .bind(&invite.email_confirmation_code)
- .fetch_optional(&mut tx)
- .await?
- .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
-
- if existing_user_id.is_some() {
- return Ok(None);
- }
+ test_support!(self, {
+ let mut tx = self.pool.begin().await?;
- let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
- "
- INSERT INTO users
- (email_address, github_login, github_user_id, admin, invite_count, invite_code)
- VALUES
- ($1, $2, $3, 'f', $4, $5)
- ON CONFLICT (github_login) DO UPDATE SET
- email_address = excluded.email_address,
- github_user_id = excluded.github_user_id,
- admin = excluded.admin
- RETURNING id, metrics_id::text
- ",
- )
- .bind(&invite.email_address)
- .bind(&user.github_login)
- .bind(&user.github_user_id)
- .bind(&user.invite_count)
- .bind(random_invite_code())
- .fetch_one(&mut tx)
- .await?;
-
- sqlx::query(
- "
- UPDATE signups
- SET user_id = $1
- WHERE id = $2
- ",
- )
- .bind(&user_id)
- .bind(&signup_id)
- .execute(&mut tx)
- .await?;
-
- if let Some(inviting_user_id) = inviting_user_id {
- let id: Option<UserId> = sqlx::query_scalar(
+ let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
+ i32,
+ Option<UserId>,
+ Option<UserId>,
+ Option<String>,
+ ) = sqlx::query_as(
"
- UPDATE users
- SET invite_count = invite_count - 1
- WHERE id = $1 AND invite_count > 0
- RETURNING id
+ SELECT id, user_id, inviting_user_id, device_id
+ FROM signups
+ WHERE
+ email_address = $1 AND
+ email_confirmation_code = $2
",
)
- .bind(&inviting_user_id)
+ .bind(&invite.email_address)
+ .bind(&invite.email_confirmation_code)
.fetch_optional(&mut tx)
- .await?;
+ .await?
+ .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
- if id.is_none() {
- Err(Error::Http(
- StatusCode::UNAUTHORIZED,
- "no invites remaining".to_string(),
- ))?;
+ if existing_user_id.is_some() {
+ return Ok(None);
}
- sqlx::query(
+ let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
"
- INSERT INTO contacts
- (user_id_a, user_id_b, a_to_b, should_notify, accepted)
+ INSERT INTO users
+ (email_address, github_login, github_user_id, admin, invite_count, invite_code)
VALUES
- ($1, $2, 't', 't', 't')
- ON CONFLICT DO NOTHING
+ ($1, $2, $3, FALSE, $4, $5)
+ ON CONFLICT (github_login) DO UPDATE SET
+ email_address = excluded.email_address,
+ github_user_id = excluded.github_user_id,
+ admin = excluded.admin
+ RETURNING id, metrics_id::text
",
)
- .bind(inviting_user_id)
- .bind(user_id)
- .execute(&mut tx)
+ .bind(&invite.email_address)
+ .bind(&user.github_login)
+ .bind(&user.github_user_id)
+ .bind(&user.invite_count)
+ .bind(random_invite_code())
+ .fetch_one(&mut tx)
.await?;
- }
-
- tx.commit().await?;
- Ok(Some(NewUserResult {
- user_id,
- metrics_id,
- inviting_user_id,
- signup_device_id,
- }))
- }
- // invite codes
-
- async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
- let mut tx = self.pool.begin().await?;
- if count > 0 {
sqlx::query(
"
- UPDATE users
- SET invite_code = $1
- WHERE id = $2 AND invite_code IS NULL
- ",
+ UPDATE signups
+ SET user_id = $1
+ WHERE id = $2
+ ",
)
- .bind(random_invite_code())
- .bind(id)
+ .bind(&user_id)
+ .bind(&signup_id)
.execute(&mut tx)
.await?;
- }
- sqlx::query(
- "
- UPDATE users
- SET invite_count = $1
- WHERE id = $2
- ",
- )
- .bind(count as i32)
- .bind(id)
- .execute(&mut tx)
- .await?;
- tx.commit().await?;
- Ok(())
- }
+ if let Some(inviting_user_id) = inviting_user_id {
+ let id: Option<UserId> = sqlx::query_scalar(
+ "
+ UPDATE users
+ SET invite_count = invite_count - 1
+ WHERE id = $1 AND invite_count > 0
+ RETURNING id
+ ",
+ )
+ .bind(&inviting_user_id)
+ .fetch_optional(&mut tx)
+ .await?;
- async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
- let result: Option<(String, i32)> = sqlx::query_as(
- "
- SELECT invite_code, invite_count
- FROM users
- WHERE id = $1 AND invite_code IS NOT NULL
- ",
- )
- .bind(id)
- .fetch_optional(&self.pool)
- .await?;
- if let Some((code, count)) = result {
- Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
- } else {
- Ok(None)
- }
+ if id.is_none() {
+ Err(Error::Http(
+ StatusCode::UNAUTHORIZED,
+ "no invites remaining".to_string(),
+ ))?;
+ }
+
+ sqlx::query(
+ "
+ INSERT INTO contacts
+ (user_id_a, user_id_b, a_to_b, should_notify, accepted)
+ VALUES
+ ($1, $2, TRUE, TRUE, TRUE)
+ ON CONFLICT DO NOTHING
+ ",
+ )
+ .bind(inviting_user_id)
+ .bind(user_id)
+ .execute(&mut tx)
+ .await?;
+ }
+
+ tx.commit().await?;
+ Ok(Some(NewUserResult {
+ user_id,
+ metrics_id,
+ inviting_user_id,
+ signup_device_id,
+ }))
+ })
}
- async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
- sqlx::query_as(
- "
- SELECT *
- FROM users
- WHERE invite_code = $1
- ",
- )
- .bind(code)
- .fetch_optional(&self.pool)
- .await?
- .ok_or_else(|| {
- Error::Http(
- StatusCode::NOT_FOUND,
- "that invite code does not exist".to_string(),
+ pub async fn create_signup(&self, signup: Signup) -> Result<()> {
+ test_support!(self, {
+ sqlx::query(
+ "
+ INSERT INTO signups
+ (
+ email_address,
+ email_confirmation_code,
+ email_confirmation_sent,
+ platform_linux,
+ platform_mac,
+ platform_windows,
+ platform_unknown,
+ editor_features,
+ programming_languages,
+ device_id
+ )
+ VALUES
+ ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8)
+ RETURNING id
+ ",
)
+ .bind(&signup.email_address)
+ .bind(&random_email_confirmation_code())
+ .bind(&signup.platform_linux)
+ .bind(&signup.platform_mac)
+ .bind(&signup.platform_windows)
+ .bind(&signup.editor_features)
+ .bind(&signup.programming_languages)
+ .bind(&signup.device_id)
+ .execute(&self.pool)
+ .await?;
+ Ok(())
})
}
- async fn create_invite_from_code(
+ pub async fn create_invite_from_code(
&self,
code: &str,
email_address: &str,
device_id: Option<&str>,
) -> Result<Invite> {
- let mut tx = self.pool.begin().await?;
-
- let existing_user: Option<UserId> = sqlx::query_scalar(
- "
- SELECT id
- FROM users
- WHERE email_address = $1
- ",
- )
- .bind(email_address)
- .fetch_optional(&mut tx)
- .await?;
- if existing_user.is_some() {
- Err(anyhow!("email address is already in use"))?;
- }
+ test_support!(self, {
+ let mut tx = self.pool.begin().await?;
- let row: Option<(UserId, i32)> = sqlx::query_as(
- "
- SELECT id, invite_count
- FROM users
- WHERE invite_code = $1
- ",
- )
- .bind(code)
- .fetch_optional(&mut tx)
- .await?;
-
- let (inviter_id, invite_count) = match row {
- Some(row) => row,
- None => Err(Error::Http(
- StatusCode::NOT_FOUND,
- "invite code not found".to_string(),
- ))?,
- };
+ let existing_user: Option<UserId> = sqlx::query_scalar(
+ "
+ SELECT id
+ FROM users
+ WHERE email_address = $1
+ ",
+ )
+ .bind(email_address)
+ .fetch_optional(&mut tx)
+ .await?;
+ if existing_user.is_some() {
+ Err(anyhow!("email address is already in use"))?;
+ }
- if invite_count == 0 {
- Err(Error::Http(
- StatusCode::UNAUTHORIZED,
- "no invites remaining".to_string(),
- ))?;
- }
+ let row: Option<(UserId, i32)> = sqlx::query_as(
+ "
+ SELECT id, invite_count
+ FROM users
+ WHERE invite_code = $1
+ ",
+ )
+ .bind(code)
+ .fetch_optional(&mut tx)
+ .await?;
- let email_confirmation_code: String = sqlx::query_scalar(
- "
- INSERT INTO signups
- (
- email_address,
- email_confirmation_code,
- email_confirmation_sent,
- inviting_user_id,
- platform_linux,
- platform_mac,
- platform_windows,
- platform_unknown,
- device_id
+ let (inviter_id, invite_count) = match row {
+ Some(row) => row,
+ None => Err(Error::Http(
+ StatusCode::NOT_FOUND,
+ "invite code not found".to_string(),
+ ))?,
+ };
+
+ if invite_count == 0 {
+ Err(Error::Http(
+ StatusCode::UNAUTHORIZED,
+ "no invites remaining".to_string(),
+ ))?;
+ }
+
+ let email_confirmation_code: String = sqlx::query_scalar(
+ "
+ INSERT INTO signups
+ (
+ email_address,
+ email_confirmation_code,
+ email_confirmation_sent,
+ inviting_user_id,
+ platform_linux,
+ platform_mac,
+ platform_windows,
+ platform_unknown,
+ device_id
+ )
+ VALUES
+ ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
+ ON CONFLICT (email_address)
+ DO UPDATE SET
+ inviting_user_id = excluded.inviting_user_id
+ RETURNING email_confirmation_code
+ ",
)
- VALUES
- ($1, $2, 'f', $3, 'f', 'f', 'f', 't', $4)
- ON CONFLICT (email_address)
- DO UPDATE SET
- inviting_user_id = excluded.inviting_user_id
- RETURNING email_confirmation_code
- ",
- )
- .bind(&email_address)
- .bind(&random_email_confirmation_code())
- .bind(&inviter_id)
- .bind(&device_id)
- .fetch_one(&mut tx)
- .await?;
-
- tx.commit().await?;
-
- Ok(Invite {
- email_address: email_address.into(),
- email_confirmation_code,
- })
- }
+ .bind(&email_address)
+ .bind(&random_email_confirmation_code())
+ .bind(&inviter_id)
+ .bind(&device_id)
+ .fetch_one(&mut tx)
+ .await?;
- // projects
+ tx.commit().await?;
- async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
- Ok(sqlx::query_scalar(
- "
- INSERT INTO projects(host_user_id)
- VALUES ($1)
- RETURNING id
- ",
- )
- .bind(host_user_id)
- .fetch_one(&self.pool)
- .await
- .map(ProjectId)?)
+ Ok(Invite {
+ email_address: email_address.into(),
+ email_confirmation_code,
+ })
+ })
}
- async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
- sqlx::query(
- "
- UPDATE projects
- SET unregistered = 't'
- WHERE id = $1
- ",
- )
- .bind(project_id)
- .execute(&self.pool)
- .await?;
- Ok(())
+ pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
+ test_support!(self, {
+ let emails = invites
+ .iter()
+ .map(|s| s.email_address.as_str())
+ .collect::<Vec<_>>();
+ sqlx::query(
+ "
+ UPDATE signups
+ SET email_confirmation_sent = TRUE
+ WHERE email_address = ANY ($1)
+ ",
+ )
+ .bind(&emails)
+ .execute(&self.pool)
+ .await?;
+ Ok(())
+ })
}
+}
- async fn update_worktree_extensions(
+impl<D> Db<D>
+where
+ D: sqlx::Database + sqlx::migrate::MigrateDatabase,
+ D::Connection: sqlx::migrate::Migrate,
+ for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
+ for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
+ for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
+ D::QueryResult: RowsAffected,
+ String: sqlx::Type<D>,
+ i32: sqlx::Type<D>,
+ i64: sqlx::Type<D>,
+ bool: sqlx::Type<D>,
+ str: sqlx::Type<D>,
+ Uuid: sqlx::Type<D>,
+ sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
+ OffsetDateTime: sqlx::Type<D>,
+ PrimitiveDateTime: sqlx::Type<D>,
+ usize: sqlx::ColumnIndex<D::Row>,
+ for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
+ for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
+ for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
+ for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
+ for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
+ for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
+ for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
+ for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
+ for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
+ for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
+ for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
+    for<'a> PrimitiveDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
+{
+ pub async fn migrate(
&self,
- project_id: ProjectId,
- worktree_id: u64,
- extensions: HashMap<String, u32>,
- ) -> Result<()> {
- if extensions.is_empty() {
- return Ok(());
- }
+ migrations_path: &Path,
+ ignore_checksum_mismatch: bool,
+ ) -> anyhow::Result<Vec<(Migration, Duration)>> {
+ let migrations = MigrationSource::resolve(migrations_path)
+ .await
+ .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
- let mut query = QueryBuilder::new(
- "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
- );
- query.push_values(extensions, |mut query, (extension, count)| {
- query
- .push_bind(project_id)
- .push_bind(worktree_id as i32)
- .push_bind(extension)
- .push_bind(count as i32);
- });
- query.push(
- "
- ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
- count = excluded.count
- ",
- );
- query.build().execute(&self.pool).await?;
-
- Ok(())
- }
+ let mut conn = self.pool.acquire().await?;
- async fn get_project_extensions(
- &self,
- project_id: ProjectId,
- ) -> Result<HashMap<u64, HashMap<String, usize>>> {
- #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
- struct WorktreeExtension {
- worktree_id: i32,
- extension: String,
- count: i32,
+ conn.ensure_migrations_table().await?;
+ let applied_migrations: HashMap<_, _> = conn
+ .list_applied_migrations()
+ .await?
+ .into_iter()
+ .map(|m| (m.version, m))
+ .collect();
+
+ let mut new_migrations = Vec::new();
+ for migration in migrations {
+ match applied_migrations.get(&migration.version) {
+ Some(applied_migration) => {
+ if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
+ {
+ Err(anyhow!(
+ "checksum mismatch for applied migration {}",
+ migration.description
+ ))?;
+ }
+ }
+ None => {
+ let elapsed = conn.apply(&migration).await?;
+ new_migrations.push((migration, elapsed));
+ }
+ }
}
- let query = "
- SELECT worktree_id, extension, count
- FROM worktree_extensions
- WHERE project_id = $1
- ";
- let counts = sqlx::query_as::<_, WorktreeExtension>(query)
- .bind(&project_id)
- .fetch_all(&self.pool)
- .await?;
+ Ok(new_migrations)
+ }
- let mut extension_counts = HashMap::default();
- for count in counts {
- extension_counts
- .entry(count.worktree_id as u64)
- .or_insert_with(HashMap::default)
- .insert(count.extension, count.count as usize);
+ pub fn fuzzy_like_string(string: &str) -> String {
+ let mut result = String::with_capacity(string.len() * 2 + 1);
+ for c in string.chars() {
+ if c.is_alphanumeric() {
+ result.push('%');
+ result.push(c);
+ }
}
- Ok(extension_counts)
+ result.push('%');
+ result
}
- async fn record_user_activity(
- &self,
- time_period: Range<OffsetDateTime>,
- projects: &[(UserId, ProjectId)],
- ) -> Result<()> {
- let query = "
- INSERT INTO project_activity_periods
- (ended_at, duration_millis, user_id, project_id)
- VALUES
- ($1, $2, $3, $4);
- ";
-
- let mut tx = self.pool.begin().await?;
- let duration_millis =
- ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
- for (user_id, project_id) in projects {
- sqlx::query(query)
- .bind(time_period.end)
- .bind(duration_millis)
- .bind(user_id)
- .bind(project_id)
- .execute(&mut tx)
- .await?;
- }
- tx.commit().await?;
+ // users
- Ok(())
+ pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
+ test_support!(self, {
+ let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
+ Ok(sqlx::query_as(query)
+ .bind(limit as i32)
+ .bind((page * limit) as i32)
+ .fetch_all(&self.pool)
+ .await?)
+ })
}
- async fn get_active_user_count(
+ pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
+ test_support!(self, {
+ let query = "
+ SELECT users.*
+ FROM users
+ WHERE id = $1
+ LIMIT 1
+ ";
+ Ok(sqlx::query_as(query)
+ .bind(&id)
+ .fetch_optional(&self.pool)
+ .await?)
+ })
+ }
+
+ pub async fn get_users_with_no_invites(
&self,
- time_period: Range<OffsetDateTime>,
- min_duration: Duration,
- only_collaborative: bool,
- ) -> Result<usize> {
- let mut with_clause = String::new();
- with_clause.push_str("WITH\n");
- with_clause.push_str(
- "
- project_durations AS (
- SELECT user_id, project_id, SUM(duration_millis) AS project_duration
- FROM project_activity_periods
- WHERE $1 < ended_at AND ended_at <= $2
- GROUP BY user_id, project_id
- ),
- ",
- );
- with_clause.push_str(
- "
- project_collaborators as (
- SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
- FROM project_durations
- GROUP BY project_id
- ),
- ",
- );
-
- if only_collaborative {
- with_clause.push_str(
- "
- user_durations AS (
- SELECT user_id, SUM(project_duration) as total_duration
- FROM project_durations, project_collaborators
- WHERE
- project_durations.project_id = project_collaborators.project_id AND
- max_collaborators > 1
- GROUP BY user_id
- ORDER BY total_duration DESC
- LIMIT $3
- )
- ",
- );
- } else {
- with_clause.push_str(
+ invited_by_another_user: bool,
+ ) -> Result<Vec<User>> {
+ test_support!(self, {
+ let query = format!(
"
- user_durations AS (
- SELECT user_id, SUM(project_duration) as total_duration
- FROM project_durations
- GROUP BY user_id
- ORDER BY total_duration DESC
- LIMIT $3
- )
+ SELECT users.*
+ FROM users
+ WHERE invite_count = 0
+ AND inviter_id IS{} NULL
",
+ if invited_by_another_user { " NOT" } else { "" }
);
- }
- let query = format!(
- "
- {with_clause}
- SELECT count(user_durations.user_id)
- FROM user_durations
- WHERE user_durations.total_duration >= $3
- "
- );
-
- let count: i64 = sqlx::query_scalar(&query)
- .bind(time_period.start)
- .bind(time_period.end)
- .bind(min_duration.as_millis() as i64)
- .fetch_one(&self.pool)
- .await?;
- Ok(count as usize)
+ Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
+ })
}
- async fn get_top_users_activity_summary(
+ pub async fn get_user_by_github_account(
&self,
- time_period: Range<OffsetDateTime>,
- max_user_count: usize,
- ) -> Result<Vec<UserActivitySummary>> {
- let query = "
- WITH
- project_durations AS (
- SELECT user_id, project_id, SUM(duration_millis) AS project_duration
- FROM project_activity_periods
- WHERE $1 < ended_at AND ended_at <= $2
- GROUP BY user_id, project_id
- ),
- user_durations AS (
- SELECT user_id, SUM(project_duration) as total_duration
- FROM project_durations
- GROUP BY user_id
- ORDER BY total_duration DESC
- LIMIT $3
- ),
- project_collaborators as (
- SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
- FROM project_durations
- GROUP BY project_id
+ github_login: &str,
+ github_user_id: Option<i32>,
+ ) -> Result<Option<User>> {
+ test_support!(self, {
+ if let Some(github_user_id) = github_user_id {
+ let mut user = sqlx::query_as::<_, User>(
+ "
+ UPDATE users
+ SET github_login = $1
+ WHERE github_user_id = $2
+ RETURNING *
+ ",
)
- SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, max_collaborators
- FROM user_durations, project_durations, project_collaborators, users
- WHERE
- user_durations.user_id = project_durations.user_id AND
- user_durations.user_id = users.id AND
- project_durations.project_id = project_collaborators.project_id
- ORDER BY total_duration DESC, user_id ASC, project_id ASC
- ";
-
- let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query)
- .bind(time_period.start)
- .bind(time_period.end)
- .bind(max_user_count as i32)
- .fetch(&self.pool);
-
- let mut result = Vec::<UserActivitySummary>::new();
- while let Some(row) = rows.next().await {
- let (user_id, github_login, project_id, duration_millis, project_collaborators) = row?;
- let project_id = project_id;
- let duration = Duration::from_millis(duration_millis as u64);
- let project_activity = ProjectActivitySummary {
- id: project_id,
- duration,
- max_collaborators: project_collaborators as usize,
- };
- if let Some(last_summary) = result.last_mut() {
- if last_summary.id == user_id {
- last_summary.project_activity.push(project_activity);
- continue;
- }
- }
- result.push(UserActivitySummary {
- id: user_id,
- project_activity: vec![project_activity],
- github_login,
- });
- }
-
- Ok(result)
- }
+ .bind(github_login)
+ .bind(github_user_id)
+ .fetch_optional(&self.pool)
+ .await?;
- async fn get_user_activity_timeline(
- &self,
- time_period: Range<OffsetDateTime>,
- user_id: UserId,
- ) -> Result<Vec<UserActivityPeriod>> {
- const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);
-
- let query = "
- SELECT
- project_activity_periods.ended_at,
- project_activity_periods.duration_millis,
- project_activity_periods.project_id,
- worktree_extensions.extension,
- worktree_extensions.count
- FROM project_activity_periods
- LEFT OUTER JOIN
- worktree_extensions
- ON
- project_activity_periods.project_id = worktree_extensions.project_id
- WHERE
- project_activity_periods.user_id = $1 AND
- $2 < project_activity_periods.ended_at AND
- project_activity_periods.ended_at <= $3
- ORDER BY project_activity_periods.id ASC
- ";
-
- let mut rows = sqlx::query_as::<
- _,
- (
- PrimitiveDateTime,
- i32,
- ProjectId,
- Option<String>,
- Option<i32>,
- ),
- >(query)
- .bind(user_id)
- .bind(time_period.start)
- .bind(time_period.end)
- .fetch(&self.pool);
-
- let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
- while let Some(row) = rows.next().await {
- let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
- let ended_at = ended_at.assume_utc();
- let duration = Duration::from_millis(duration_millis as u64);
- let started_at = ended_at - duration;
- let project_time_periods = time_periods.entry(project_id).or_default();
-
- if let Some(prev_duration) = project_time_periods.last_mut() {
- if started_at <= prev_duration.end + COALESCE_THRESHOLD
- && ended_at >= prev_duration.start
- {
- prev_duration.end = cmp::max(prev_duration.end, ended_at);
- } else {
- project_time_periods.push(UserActivityPeriod {
- project_id,
- start: started_at,
- end: ended_at,
- extensions: Default::default(),
- });
+ if user.is_none() {
+ user = sqlx::query_as::<_, User>(
+ "
+ UPDATE users
+ SET github_user_id = $1
+ WHERE github_login = $2
+ RETURNING *
+ ",
+ )
+ .bind(github_user_id)
+ .bind(github_login)
+ .fetch_optional(&self.pool)
+ .await?;
}
+
+ Ok(user)
} else {
- project_time_periods.push(UserActivityPeriod {
- project_id,
- start: started_at,
- end: ended_at,
- extensions: Default::default(),
- });
+ let user = sqlx::query_as(
+ "
+ SELECT * FROM users
+ WHERE github_login = $1
+ LIMIT 1
+ ",
+ )
+ .bind(github_login)
+ .fetch_optional(&self.pool)
+ .await?;
+ Ok(user)
}
+ })
+ }
- if let Some((extension, extension_count)) = extension.zip(extension_count) {
- project_time_periods
- .last_mut()
- .unwrap()
- .extensions
- .insert(extension, extension_count as usize);
- }
- }
+ pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
+ test_support!(self, {
+ let query = "UPDATE users SET admin = $1 WHERE id = $2";
+ Ok(sqlx::query(query)
+ .bind(is_admin)
+ .bind(id.0)
+ .execute(&self.pool)
+ .await
+ .map(drop)?)
+ })
+ }
- let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
- durations.sort_unstable_by_key(|duration| duration.start);
- Ok(durations)
+ pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
+ test_support!(self, {
+ let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
+ Ok(sqlx::query(query)
+ .bind(connected_once)
+ .bind(id.0)
+ .execute(&self.pool)
+ .await
+ .map(drop)?)
+ })
}
- // contacts
+ pub async fn destroy_user(&self, id: UserId) -> Result<()> {
+ test_support!(self, {
+ let query = "DELETE FROM access_tokens WHERE user_id = $1;";
+ sqlx::query(query)
+ .bind(id.0)
+ .execute(&self.pool)
+ .await
+ .map(drop)?;
+ let query = "DELETE FROM users WHERE id = $1;";
+ Ok(sqlx::query(query)
+ .bind(id.0)
+ .execute(&self.pool)
+ .await
+ .map(drop)?)
+ })
+ }
+
+ // signups
- async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
- let query = "
- SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
- FROM contacts
- WHERE user_id_a = $1 OR user_id_b = $1;
- ";
+ pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
+ test_support!(self, {
+ Ok(sqlx::query_as(
+ "
+ SELECT
+ COUNT(*) as count,
+ COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
+ COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
+ COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
+ COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
+ FROM (
+ SELECT *
+ FROM signups
+ WHERE
+ NOT email_confirmation_sent
+ ) AS unsent
+ ",
+ )
+ .fetch_one(&self.pool)
+ .await?)
+ })
+ }
- let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
- .bind(user_id)
- .fetch(&self.pool);
+ pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
+ test_support!(self, {
+ Ok(sqlx::query_as(
+ "
+ SELECT
+ email_address, email_confirmation_code
+ FROM signups
+ WHERE
+ NOT email_confirmation_sent AND
+ (platform_mac OR platform_unknown)
+ LIMIT $1
+ ",
+ )
+ .bind(count as i32)
+ .fetch_all(&self.pool)
+ .await?)
+ })
+ }
- let mut contacts = Vec::new();
- while let Some(row) = rows.next().await {
- let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
+ // invite codes
- if user_id_a == user_id {
- if accepted {
- contacts.push(Contact::Accepted {
- user_id: user_id_b,
- should_notify: should_notify && a_to_b,
- });
- } else if a_to_b {
- contacts.push(Contact::Outgoing { user_id: user_id_b })
- } else {
- contacts.push(Contact::Incoming {
- user_id: user_id_b,
- should_notify,
- });
- }
- } else if accepted {
- contacts.push(Contact::Accepted {
- user_id: user_id_a,
- should_notify: should_notify && !a_to_b,
- });
- } else if a_to_b {
- contacts.push(Contact::Incoming {
- user_id: user_id_a,
- should_notify,
- });
- } else {
- contacts.push(Contact::Outgoing { user_id: user_id_a });
+ pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
+ test_support!(self, {
+ let mut tx = self.pool.begin().await?;
+ if count > 0 {
+ sqlx::query(
+ "
+ UPDATE users
+ SET invite_code = $1
+ WHERE id = $2 AND invite_code IS NULL
+ ",
+ )
+ .bind(random_invite_code())
+ .bind(id)
+ .execute(&mut tx)
+ .await?;
}
- }
-
- contacts.sort_unstable_by_key(|contact| contact.user_id());
- Ok(contacts)
+ sqlx::query(
+ "
+ UPDATE users
+ SET invite_count = $1
+ WHERE id = $2
+ ",
+ )
+ .bind(count as i32)
+ .bind(id)
+ .execute(&mut tx)
+ .await?;
+ tx.commit().await?;
+ Ok(())
+ })
}
- async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
- let (id_a, id_b) = if user_id_1 < user_id_2 {
- (user_id_1, user_id_2)
- } else {
- (user_id_2, user_id_1)
- };
+ pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
+ test_support!(self, {
+ let result: Option<(String, i32)> = sqlx::query_as(
+ "
+ SELECT invite_code, invite_count
+ FROM users
+ WHERE id = $1 AND invite_code IS NOT NULL
+ ",
+ )
+ .bind(id)
+ .fetch_optional(&self.pool)
+ .await?;
+ if let Some((code, count)) = result {
+ Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
+ } else {
+ Ok(None)
+ }
+ })
+ }
- let query = "
- SELECT 1 FROM contacts
- WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
- LIMIT 1
- ";
- Ok(sqlx::query_scalar::<_, i32>(query)
- .bind(id_a.0)
- .bind(id_b.0)
+ pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
+ test_support!(self, {
+ sqlx::query_as(
+ "
+ SELECT *
+ FROM users
+ WHERE invite_code = $1
+ ",
+ )
+ .bind(code)
.fetch_optional(&self.pool)
.await?
- .is_some())
+ .ok_or_else(|| {
+ Error::Http(
+ StatusCode::NOT_FOUND,
+ "that invite code does not exist".to_string(),
+ )
+ })
+ })
}
- async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
- let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
- (sender_id, receiver_id, true)
- } else {
- (receiver_id, sender_id, false)
- };
- let query = "
- INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
- VALUES ($1, $2, $3, 'f', 't')
- ON CONFLICT (user_id_a, user_id_b) DO UPDATE
- SET
- accepted = 't',
- should_notify = 'f'
- WHERE
- NOT contacts.accepted AND
- ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
- (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
- ";
- let result = sqlx::query(query)
- .bind(id_a.0)
- .bind(id_b.0)
- .bind(a_to_b)
- .execute(&self.pool)
- .await?;
+ // projects
- if result.rows_affected() == 1 {
- Ok(())
- } else {
- Err(anyhow!("contact already requested"))?
- }
+ /// Registers a new project for the given user.
+ pub async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
+ test_support!(self, {
+ Ok(sqlx::query_scalar(
+ "
+ INSERT INTO projects(host_user_id)
+ VALUES ($1)
+ RETURNING id
+ ",
+ )
+ .bind(host_user_id)
+ .fetch_one(&self.pool)
+ .await
+ .map(ProjectId)?)
+ })
}
- async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
- let (id_a, id_b) = if responder_id < requester_id {
- (responder_id, requester_id)
- } else {
- (requester_id, responder_id)
- };
- let query = "
- DELETE FROM contacts
- WHERE user_id_a = $1 AND user_id_b = $2;
- ";
- let result = sqlx::query(query)
- .bind(id_a.0)
- .bind(id_b.0)
+ /// Unregisters a project for the given project id.
+ pub async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
+ test_support!(self, {
+ sqlx::query(
+ "
+ UPDATE projects
+ SET unregistered = TRUE
+ WHERE id = $1
+ ",
+ )
+ .bind(project_id)
.execute(&self.pool)
.await?;
-
- if result.rows_affected() == 1 {
Ok(())
- } else {
- Err(anyhow!("no such contact"))?
- }
+ })
}
- async fn dismiss_contact_notification(
- &self,
- user_id: UserId,
- contact_user_id: UserId,
- ) -> Result<()> {
- let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
- (user_id, contact_user_id, true)
- } else {
- (contact_user_id, user_id, false)
- };
+ // contacts
- let query = "
- UPDATE contacts
- SET should_notify = 'f'
- WHERE
- user_id_a = $1 AND user_id_b = $2 AND
- (
- (a_to_b = $3 AND accepted) OR
- (a_to_b != $3 AND NOT accepted)
- );
- ";
-
- let result = sqlx::query(query)
- .bind(id_a.0)
- .bind(id_b.0)
- .bind(a_to_b)
- .execute(&self.pool)
- .await?;
+ pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
+ test_support!(self, {
+ let query = "
+ SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
+ FROM contacts
+ WHERE user_id_a = $1 OR user_id_b = $1;
+ ";
- if result.rows_affected() == 0 {
- Err(anyhow!("no such contact request"))?;
- }
+ let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
+ .bind(user_id)
+ .fetch(&self.pool);
+
+ let mut contacts = Vec::new();
+ while let Some(row) = rows.next().await {
+ let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
+
+ if user_id_a == user_id {
+ if accepted {
+ contacts.push(Contact::Accepted {
+ user_id: user_id_b,
+ should_notify: should_notify && a_to_b,
+ });
+ } else if a_to_b {
+ contacts.push(Contact::Outgoing { user_id: user_id_b })
+ } else {
+ contacts.push(Contact::Incoming {
+ user_id: user_id_b,
+ should_notify,
+ });
+ }
+ } else if accepted {
+ contacts.push(Contact::Accepted {
+ user_id: user_id_a,
+ should_notify: should_notify && !a_to_b,
+ });
+ } else if a_to_b {
+ contacts.push(Contact::Incoming {
+ user_id: user_id_a,
+ should_notify,
+ });
+ } else {
+ contacts.push(Contact::Outgoing { user_id: user_id_a });
+ }
+ }
+
+ contacts.sort_unstable_by_key(|contact| contact.user_id());
- Ok(())
+ Ok(contacts)
+ })
}
- async fn respond_to_contact_request(
- &self,
- responder_id: UserId,
- requester_id: UserId,
- accept: bool,
- ) -> Result<()> {
- let (id_a, id_b, a_to_b) = if responder_id < requester_id {
- (responder_id, requester_id, false)
- } else {
- (requester_id, responder_id, true)
- };
- let result = if accept {
+ pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
+ test_support!(self, {
+ let (id_a, id_b) = if user_id_1 < user_id_2 {
+ (user_id_1, user_id_2)
+ } else {
+ (user_id_2, user_id_1)
+ };
+
let query = "
- UPDATE contacts
- SET accepted = 't', should_notify = 't'
- WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
+ SELECT 1 FROM contacts
+ WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
+ LIMIT 1
";
- sqlx::query(query)
+ Ok(sqlx::query_scalar::<_, i32>(query)
.bind(id_a.0)
.bind(id_b.0)
- .bind(a_to_b)
- .execute(&self.pool)
+ .fetch_optional(&self.pool)
.await?
- } else {
+ .is_some())
+ })
+ }
+
+ pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
+ test_support!(self, {
+ let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
+ (sender_id, receiver_id, true)
+ } else {
+ (receiver_id, sender_id, false)
+ };
let query = "
- DELETE FROM contacts
- WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
+ INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
+ VALUES ($1, $2, $3, FALSE, TRUE)
+ ON CONFLICT (user_id_a, user_id_b) DO UPDATE
+ SET
+ accepted = TRUE,
+ should_notify = FALSE
+ WHERE
+ NOT contacts.accepted AND
+ ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
+ (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
";
- sqlx::query(query)
+ let result = sqlx::query(query)
.bind(id_a.0)
.bind(id_b.0)
.bind(a_to_b)
.execute(&self.pool)
- .await?
- };
- if result.rows_affected() == 1 {
- Ok(())
- } else {
- Err(anyhow!("no such contact request"))?
- }
+ .await?;
+
+ if result.rows_affected() == 1 {
+ Ok(())
+ } else {
+ Err(anyhow!("contact already requested"))?
+ }
+ })
}
- // access tokens
+ pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
+ test_support!(self, {
+ let (id_a, id_b) = if responder_id < requester_id {
+ (responder_id, requester_id)
+ } else {
+ (requester_id, responder_id)
+ };
+ let query = "
+ DELETE FROM contacts
+ WHERE user_id_a = $1 AND user_id_b = $2;
+ ";
+ let result = sqlx::query(query)
+ .bind(id_a.0)
+ .bind(id_b.0)
+ .execute(&self.pool)
+ .await?;
+
+ if result.rows_affected() == 1 {
+ Ok(())
+ } else {
+ Err(anyhow!("no such contact"))?
+ }
+ })
+ }
- async fn create_access_token_hash(
+ pub async fn dismiss_contact_notification(
&self,
user_id: UserId,
- access_token_hash: &str,
- max_access_token_count: usize,
+ contact_user_id: UserId,
) -> Result<()> {
- let insert_query = "
- INSERT INTO access_tokens (user_id, hash)
- VALUES ($1, $2);
- ";
- let cleanup_query = "
- DELETE FROM access_tokens
- WHERE id IN (
- SELECT id from access_tokens
- WHERE user_id = $1
- ORDER BY id DESC
- OFFSET $3
- )
- ";
-
- let mut tx = self.pool.begin().await?;
- sqlx::query(insert_query)
- .bind(user_id.0)
- .bind(access_token_hash)
- .execute(&mut tx)
- .await?;
- sqlx::query(cleanup_query)
- .bind(user_id.0)
- .bind(access_token_hash)
- .bind(max_access_token_count as i32)
- .execute(&mut tx)
- .await?;
- Ok(tx.commit().await?)
- }
-
- async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
- let query = "
- SELECT hash
- FROM access_tokens
- WHERE user_id = $1
- ORDER BY id DESC
- ";
- Ok(sqlx::query_scalar(query)
- .bind(user_id.0)
- .fetch_all(&self.pool)
- .await?)
- }
-
- // orgs
-
- #[allow(unused)] // Help rust-analyzer
- #[cfg(any(test, feature = "seed-support"))]
- async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
- let query = "
- SELECT *
- FROM orgs
- WHERE slug = $1
- ";
- Ok(sqlx::query_as(query)
- .bind(slug)
- .fetch_optional(&self.pool)
- .await?)
- }
+ test_support!(self, {
+ let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
+ (user_id, contact_user_id, true)
+ } else {
+ (contact_user_id, user_id, false)
+ };
- #[cfg(any(test, feature = "seed-support"))]
- async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
- let query = "
- INSERT INTO orgs (name, slug)
- VALUES ($1, $2)
- RETURNING id
- ";
- Ok(sqlx::query_scalar(query)
- .bind(name)
- .bind(slug)
- .fetch_one(&self.pool)
- .await
- .map(OrgId)?)
- }
+ let query = "
+ UPDATE contacts
+ SET should_notify = FALSE
+ WHERE
+ user_id_a = $1 AND user_id_b = $2 AND
+ (
+ (a_to_b = $3 AND accepted) OR
+ (a_to_b != $3 AND NOT accepted)
+ );
+ ";
- #[cfg(any(test, feature = "seed-support"))]
- async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
- let query = "
- INSERT INTO org_memberships (org_id, user_id, admin)
- VALUES ($1, $2, $3)
- ON CONFLICT DO NOTHING
- ";
- Ok(sqlx::query(query)
- .bind(org_id.0)
- .bind(user_id.0)
- .bind(is_admin)
- .execute(&self.pool)
- .await
- .map(drop)?)
- }
+ let result = sqlx::query(query)
+ .bind(id_a.0)
+ .bind(id_b.0)
+ .bind(a_to_b)
+ .execute(&self.pool)
+ .await?;
- // channels
-
- #[cfg(any(test, feature = "seed-support"))]
- async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
- let query = "
- INSERT INTO channels (owner_id, owner_is_user, name)
- VALUES ($1, false, $2)
- RETURNING id
- ";
- Ok(sqlx::query_scalar(query)
- .bind(org_id.0)
- .bind(name)
- .fetch_one(&self.pool)
- .await
- .map(ChannelId)?)
- }
+ if result.rows_affected() == 0 {
+ Err(anyhow!("no such contact request"))?;
+ }
- #[allow(unused)] // Help rust-analyzer
- #[cfg(any(test, feature = "seed-support"))]
- async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
- let query = "
- SELECT *
- FROM channels
- WHERE
- channels.owner_is_user = false AND
- channels.owner_id = $1
- ";
- Ok(sqlx::query_as(query)
- .bind(org_id.0)
- .fetch_all(&self.pool)
- .await?)
+ Ok(())
+ })
}
- async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
- let query = "
- SELECT
- channels.*
- FROM
- channel_memberships, channels
- WHERE
- channel_memberships.user_id = $1 AND
- channel_memberships.channel_id = channels.id
- ";
- Ok(sqlx::query_as(query)
- .bind(user_id.0)
- .fetch_all(&self.pool)
- .await?)
+ pub async fn respond_to_contact_request(
+ &self,
+ responder_id: UserId,
+ requester_id: UserId,
+ accept: bool,
+ ) -> Result<()> {
+ test_support!(self, {
+ let (id_a, id_b, a_to_b) = if responder_id < requester_id {
+ (responder_id, requester_id, false)
+ } else {
+ (requester_id, responder_id, true)
+ };
+ let result = if accept {
+ let query = "
+ UPDATE contacts
+ SET accepted = TRUE, should_notify = TRUE
+ WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
+ ";
+ sqlx::query(query)
+ .bind(id_a.0)
+ .bind(id_b.0)
+ .bind(a_to_b)
+ .execute(&self.pool)
+ .await?
+ } else {
+ let query = "
+ DELETE FROM contacts
+ WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
+ ";
+ sqlx::query(query)
+ .bind(id_a.0)
+ .bind(id_b.0)
+ .bind(a_to_b)
+ .execute(&self.pool)
+ .await?
+ };
+ if result.rows_affected() == 1 {
+ Ok(())
+ } else {
+ Err(anyhow!("no such contact request"))?
+ }
+ })
}
- async fn can_user_access_channel(
- &self,
- user_id: UserId,
- channel_id: ChannelId,
- ) -> Result<bool> {
- let query = "
- SELECT id
- FROM channel_memberships
- WHERE user_id = $1 AND channel_id = $2
- LIMIT 1
- ";
- Ok(sqlx::query_scalar::<_, i32>(query)
- .bind(user_id.0)
- .bind(channel_id.0)
- .fetch_optional(&self.pool)
- .await
- .map(|e| e.is_some())?)
- }
+ // access tokens
- #[cfg(any(test, feature = "seed-support"))]
- async fn add_channel_member(
+ pub async fn create_access_token_hash(
&self,
- channel_id: ChannelId,
user_id: UserId,
- is_admin: bool,
+ access_token_hash: &str,
+ max_access_token_count: usize,
) -> Result<()> {
- let query = "
- INSERT INTO channel_memberships (channel_id, user_id, admin)
- VALUES ($1, $2, $3)
- ON CONFLICT DO NOTHING
- ";
- Ok(sqlx::query(query)
- .bind(channel_id.0)
- .bind(user_id.0)
- .bind(is_admin)
- .execute(&self.pool)
- .await
- .map(drop)?)
- }
-
- // messages
+ test_support!(self, {
+ let insert_query = "
+ INSERT INTO access_tokens (user_id, hash)
+ VALUES ($1, $2);
+ ";
+ let cleanup_query = "
+ DELETE FROM access_tokens
+ WHERE id IN (
+ SELECT id from access_tokens
+ WHERE user_id = $1
+ ORDER BY id DESC
+ LIMIT 10000
+ OFFSET $3
+ )
+ ";
- async fn create_channel_message(
- &self,
- channel_id: ChannelId,
- sender_id: UserId,
- body: &str,
- timestamp: OffsetDateTime,
- nonce: u128,
- ) -> Result<MessageId> {
- let query = "
- INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
- VALUES ($1, $2, $3, $4, $5)
- ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
- RETURNING id
- ";
- Ok(sqlx::query_scalar(query)
- .bind(channel_id.0)
- .bind(sender_id.0)
- .bind(body)
- .bind(timestamp)
- .bind(Uuid::from_u128(nonce))
- .fetch_one(&self.pool)
- .await
- .map(MessageId)?)
+ let mut tx = self.pool.begin().await?;
+ sqlx::query(insert_query)
+ .bind(user_id.0)
+ .bind(access_token_hash)
+ .execute(&mut tx)
+ .await?;
+ sqlx::query(cleanup_query)
+ .bind(user_id.0)
+ .bind(access_token_hash)
+ .bind(max_access_token_count as i32)
+ .execute(&mut tx)
+ .await?;
+ Ok(tx.commit().await?)
+ })
}
- async fn get_channel_messages(
- &self,
- channel_id: ChannelId,
- count: usize,
- before_id: Option<MessageId>,
- ) -> Result<Vec<ChannelMessage>> {
- let query = r#"
- SELECT * FROM (
- SELECT
- id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
- FROM
- channel_messages
- WHERE
- channel_id = $1 AND
- id < $2
+ pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
+ test_support!(self, {
+ let query = "
+ SELECT hash
+ FROM access_tokens
+ WHERE user_id = $1
ORDER BY id DESC
- LIMIT $3
- ) as recent_messages
- ORDER BY id ASC
- "#;
- Ok(sqlx::query_as(query)
- .bind(channel_id.0)
- .bind(before_id.unwrap_or(MessageId::MAX))
- .bind(count as i64)
- .fetch_all(&self.pool)
- .await?)
- }
-
- #[cfg(test)]
- async fn teardown(&self, url: &str) {
- use util::ResultExt;
-
- let query = "
- SELECT pg_terminate_backend(pg_stat_activity.pid)
- FROM pg_stat_activity
- WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
- ";
- sqlx::query(query).execute(&self.pool).await.log_err();
- self.pool.close().await;
- <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
- .await
- .log_err();
- }
-
- #[cfg(test)]
- fn as_fake(&self) -> Option<&FakeDb> {
- None
+ ";
+ Ok(sqlx::query_scalar(query)
+ .bind(user_id.0)
+ .fetch_all(&self.pool)
+ .await?)
+ })
}
}
@@ -1,17 +1,30 @@
use super::db::*;
-use collections::HashMap;
use gpui::executor::{Background, Deterministic};
-use std::{sync::Arc, time::Duration};
-use time::OffsetDateTime;
+use std::sync::Arc;
+
+macro_rules! test_both_dbs {
+ ($postgres_test_name:ident, $sqlite_test_name:ident, $db:ident, $body:block) => {
+ #[gpui::test]
+ async fn $postgres_test_name() {
+ let test_db = PostgresTestDb::new(Deterministic::new(0).build_background());
+ let $db = test_db.db();
+ $body
+ }
-#[tokio::test(flavor = "multi_thread")]
-async fn test_get_users_by_ids() {
- for test_db in [
- TestDb::postgres().await,
- TestDb::fake(build_background_executor()),
- ] {
- let db = test_db.db();
+ #[gpui::test]
+ async fn $sqlite_test_name() {
+ let test_db = SqliteTestDb::new(Deterministic::new(0).build_background());
+ let $db = test_db.db();
+ $body
+ }
+ };
+}
+test_both_dbs!(
+ test_get_users_by_ids_postgres,
+ test_get_users_by_ids_sqlite,
+ db,
+ {
let mut user_ids = Vec::new();
for i in 1..=4 {
user_ids.push(
@@ -68,15 +81,13 @@ async fn test_get_users_by_ids() {
]
);
}
-}
+);
-#[tokio::test(flavor = "multi_thread")]
-async fn test_get_user_by_github_account() {
- for test_db in [
- TestDb::postgres().await,
- TestDb::fake(build_background_executor()),
- ] {
- let db = test_db.db();
+test_both_dbs!(
+ test_get_user_by_github_account_postgres,
+ test_get_user_by_github_account_sqlite,
+ db,
+ {
let user_id1 = db
.create_user(
"user1@example.com",
@@ -128,87 +139,57 @@ async fn test_get_user_by_github_account() {
assert_eq!(&user.github_login, "the-new-login2");
assert_eq!(user.github_user_id, Some(102));
}
-}
+);
-#[tokio::test(flavor = "multi_thread")]
-async fn test_worktree_extensions() {
- let test_db = TestDb::postgres().await;
- let db = test_db.db();
+test_both_dbs!(
+ test_create_access_tokens_postgres,
+ test_create_access_tokens_sqlite,
+ db,
+ {
+ let user = db
+ .create_user(
+ "u1@example.com",
+ false,
+ NewUserParams {
+ github_login: "u1".into(),
+ github_user_id: 1,
+ invite_count: 0,
+ },
+ )
+ .await
+ .unwrap()
+ .user_id;
- let user = db
- .create_user(
- "u1@example.com",
- false,
- NewUserParams {
- github_login: "u1".into(),
- github_user_id: 0,
- invite_count: 0,
- },
- )
- .await
- .unwrap()
- .user_id;
- let project = db.register_project(user).await.unwrap();
+ db.create_access_token_hash(user, "h1", 3).await.unwrap();
+ db.create_access_token_hash(user, "h2", 3).await.unwrap();
+ assert_eq!(
+ db.get_access_token_hashes(user).await.unwrap(),
+ &["h2".to_string(), "h1".to_string()]
+ );
- db.update_worktree_extensions(project, 100, Default::default())
- .await
- .unwrap();
- db.update_worktree_extensions(
- project,
- 100,
- [("rs".to_string(), 5), ("md".to_string(), 3)]
- .into_iter()
- .collect(),
- )
- .await
- .unwrap();
- db.update_worktree_extensions(
- project,
- 100,
- [("rs".to_string(), 6), ("md".to_string(), 5)]
- .into_iter()
- .collect(),
- )
- .await
- .unwrap();
- db.update_worktree_extensions(
- project,
- 101,
- [("ts".to_string(), 2), ("md".to_string(), 1)]
- .into_iter()
- .collect(),
- )
- .await
- .unwrap();
+ db.create_access_token_hash(user, "h3", 3).await.unwrap();
+ assert_eq!(
+ db.get_access_token_hashes(user).await.unwrap(),
+ &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
+ );
- assert_eq!(
- db.get_project_extensions(project).await.unwrap(),
- [
- (
- 100,
- [("rs".into(), 6), ("md".into(), 5),]
- .into_iter()
- .collect::<HashMap<_, _>>()
- ),
- (
- 101,
- [("ts".into(), 2), ("md".into(), 1),]
- .into_iter()
- .collect::<HashMap<_, _>>()
- )
- ]
- .into_iter()
- .collect()
- );
-}
+ db.create_access_token_hash(user, "h4", 3).await.unwrap();
+ assert_eq!(
+ db.get_access_token_hashes(user).await.unwrap(),
+ &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
+ );
-#[tokio::test(flavor = "multi_thread")]
-async fn test_user_activity() {
- let test_db = TestDb::postgres().await;
- let db = test_db.db();
+ db.create_access_token_hash(user, "h5", 3).await.unwrap();
+ assert_eq!(
+ db.get_access_token_hashes(user).await.unwrap(),
+ &["h5".to_string(), "h4".to_string(), "h3".to_string()]
+ );
+ }
+);
+test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, {
let mut user_ids = Vec::new();
- for i in 0..=2 {
+ for i in 0..3 {
user_ids.push(
db.create_user(
&format!("user{i}@example.com"),
@@ -225,371 +206,198 @@ async fn test_user_activity() {
);
}
- let project_1 = db.register_project(user_ids[0]).await.unwrap();
- db.update_worktree_extensions(
- project_1,
- 1,
- HashMap::from_iter([("rs".into(), 5), ("md".into(), 7)]),
- )
- .await
- .unwrap();
- let project_2 = db.register_project(user_ids[1]).await.unwrap();
- let t0 = OffsetDateTime::now_utc() - Duration::from_secs(60 * 60);
-
- // User 2 opens a project
- let t1 = t0 + Duration::from_secs(10);
- db.record_user_activity(t0..t1, &[(user_ids[1], project_2)])
- .await
- .unwrap();
-
- let t2 = t1 + Duration::from_secs(10);
- db.record_user_activity(t1..t2, &[(user_ids[1], project_2)])
- .await
- .unwrap();
-
- // User 1 joins the project
- let t3 = t2 + Duration::from_secs(10);
- db.record_user_activity(
- t2..t3,
- &[(user_ids[1], project_2), (user_ids[0], project_2)],
- )
- .await
- .unwrap();
+ let user_1 = user_ids[0];
+ let user_2 = user_ids[1];
+ let user_3 = user_ids[2];
- // User 1 opens another project
- let t4 = t3 + Duration::from_secs(10);
- db.record_user_activity(
- t3..t4,
- &[
- (user_ids[1], project_2),
- (user_ids[0], project_2),
- (user_ids[0], project_1),
- ],
- )
- .await
- .unwrap();
+ // User starts with no contacts
+ assert_eq!(db.get_contacts(user_1).await.unwrap(), &[]);
- // User 3 joins that project
- let t5 = t4 + Duration::from_secs(10);
- db.record_user_activity(
- t4..t5,
- &[
- (user_ids[1], project_2),
- (user_ids[0], project_2),
- (user_ids[0], project_1),
- (user_ids[2], project_1),
- ],
- )
- .await
- .unwrap();
-
- // User 2 leaves
- let t6 = t5 + Duration::from_secs(5);
- db.record_user_activity(
- t5..t6,
- &[(user_ids[0], project_1), (user_ids[2], project_1)],
- )
- .await
- .unwrap();
-
- let t7 = t6 + Duration::from_secs(60);
- let t8 = t7 + Duration::from_secs(10);
- db.record_user_activity(t7..t8, &[(user_ids[0], project_1)])
- .await
- .unwrap();
-
- assert_eq!(
- db.get_top_users_activity_summary(t0..t6, 10).await.unwrap(),
- &[
- UserActivitySummary {
- id: user_ids[0],
- github_login: "user0".to_string(),
- project_activity: vec![
- ProjectActivitySummary {
- id: project_1,
- duration: Duration::from_secs(25),
- max_collaborators: 2
- },
- ProjectActivitySummary {
- id: project_2,
- duration: Duration::from_secs(30),
- max_collaborators: 2
- }
- ]
- },
- UserActivitySummary {
- id: user_ids[1],
- github_login: "user1".to_string(),
- project_activity: vec![ProjectActivitySummary {
- id: project_2,
- duration: Duration::from_secs(50),
- max_collaborators: 2
- }]
- },
- UserActivitySummary {
- id: user_ids[2],
- github_login: "user2".to_string(),
- project_activity: vec![ProjectActivitySummary {
- id: project_1,
- duration: Duration::from_secs(15),
- max_collaborators: 2
- }]
- },
- ]
- );
-
- assert_eq!(
- db.get_active_user_count(t0..t6, Duration::from_secs(56), false)
- .await
- .unwrap(),
- 0
- );
- assert_eq!(
- db.get_active_user_count(t0..t6, Duration::from_secs(56), true)
- .await
- .unwrap(),
- 0
- );
- assert_eq!(
- db.get_active_user_count(t0..t6, Duration::from_secs(54), false)
- .await
- .unwrap(),
- 1
- );
+ // User requests a contact. Both users see the pending request.
+ db.send_contact_request(user_1, user_2).await.unwrap();
+ assert!(!db.has_contact(user_1, user_2).await.unwrap());
+ assert!(!db.has_contact(user_2, user_1).await.unwrap());
assert_eq!(
- db.get_active_user_count(t0..t6, Duration::from_secs(54), true)
- .await
- .unwrap(),
- 1
+ db.get_contacts(user_1).await.unwrap(),
+ &[Contact::Outgoing { user_id: user_2 }],
);
assert_eq!(
- db.get_active_user_count(t0..t6, Duration::from_secs(30), false)
- .await
- .unwrap(),
- 2
+ db.get_contacts(user_2).await.unwrap(),
+ &[Contact::Incoming {
+ user_id: user_1,
+ should_notify: true
+ }]
);
+
+ // User 2 dismisses the contact request notification without accepting or rejecting.
+ // We shouldn't notify them again.
+ db.dismiss_contact_notification(user_1, user_2)
+ .await
+ .unwrap_err();
+ db.dismiss_contact_notification(user_2, user_1)
+ .await
+ .unwrap();
assert_eq!(
- db.get_active_user_count(t0..t6, Duration::from_secs(30), true)
- .await
- .unwrap(),
- 2
+ db.get_contacts(user_2).await.unwrap(),
+ &[Contact::Incoming {
+ user_id: user_1,
+ should_notify: false
+ }]
);
+
+ // User can't accept their own contact request
+ db.respond_to_contact_request(user_1, user_2, true)
+ .await
+ .unwrap_err();
+
+ // User accepts a contact request. Both users see the contact.
+ db.respond_to_contact_request(user_2, user_1, true)
+ .await
+ .unwrap();
assert_eq!(
- db.get_active_user_count(t0..t6, Duration::from_secs(10), false)
- .await
- .unwrap(),
- 3
+ db.get_contacts(user_1).await.unwrap(),
+ &[Contact::Accepted {
+ user_id: user_2,
+ should_notify: true
+ }],
);
+ assert!(db.has_contact(user_1, user_2).await.unwrap());
+ assert!(db.has_contact(user_2, user_1).await.unwrap());
assert_eq!(
- db.get_active_user_count(t0..t6, Duration::from_secs(10), true)
- .await
- .unwrap(),
- 3
+ db.get_contacts(user_2).await.unwrap(),
+ &[Contact::Accepted {
+ user_id: user_1,
+ should_notify: false,
+ }]
);
+
+ // Users cannot re-request existing contacts.
+ db.send_contact_request(user_1, user_2).await.unwrap_err();
+ db.send_contact_request(user_2, user_1).await.unwrap_err();
+
+ // Users can't dismiss notifications of them accepting other users' requests.
+ db.dismiss_contact_notification(user_2, user_1)
+ .await
+ .unwrap_err();
assert_eq!(
- db.get_active_user_count(t0..t1, Duration::from_secs(5), false)
- .await
- .unwrap(),
- 1
+ db.get_contacts(user_1).await.unwrap(),
+ &[Contact::Accepted {
+ user_id: user_2,
+ should_notify: true,
+ }]
);
+
+ // Users can dismiss notifications of other users accepting their requests.
+ db.dismiss_contact_notification(user_1, user_2)
+ .await
+ .unwrap();
assert_eq!(
- db.get_active_user_count(t0..t1, Duration::from_secs(5), true)
- .await
- .unwrap(),
- 0
+ db.get_contacts(user_1).await.unwrap(),
+ &[Contact::Accepted {
+ user_id: user_2,
+ should_notify: false,
+ }]
);
+ // Users send each other concurrent contact requests and
+ // see that they are immediately accepted.
+ db.send_contact_request(user_1, user_3).await.unwrap();
+ db.send_contact_request(user_3, user_1).await.unwrap();
assert_eq!(
- db.get_user_activity_timeline(t3..t6, user_ids[0])
- .await
- .unwrap(),
+ db.get_contacts(user_1).await.unwrap(),
&[
- UserActivityPeriod {
- project_id: project_1,
- start: t3,
- end: t6,
- extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
- },
- UserActivityPeriod {
- project_id: project_2,
- start: t3,
- end: t5,
- extensions: Default::default(),
+ Contact::Accepted {
+ user_id: user_2,
+ should_notify: false,
},
+ Contact::Accepted {
+ user_id: user_3,
+ should_notify: false
+ }
]
);
assert_eq!(
- db.get_user_activity_timeline(t0..t8, user_ids[0])
- .await
- .unwrap(),
- &[
- UserActivityPeriod {
- project_id: project_2,
- start: t2,
- end: t5,
- extensions: Default::default(),
- },
- UserActivityPeriod {
- project_id: project_1,
- start: t3,
- end: t6,
- extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
- },
- UserActivityPeriod {
- project_id: project_1,
- start: t7,
- end: t8,
- extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
- },
- ]
+ db.get_contacts(user_3).await.unwrap(),
+ &[Contact::Accepted {
+ user_id: user_1,
+ should_notify: false
+ }],
);
-}
-
-#[tokio::test(flavor = "multi_thread")]
-async fn test_recent_channel_messages() {
- for test_db in [
- TestDb::postgres().await,
- TestDb::fake(build_background_executor()),
- ] {
- let db = test_db.db();
- let user = db
- .create_user(
- "u@example.com",
- false,
- NewUserParams {
- github_login: "u".into(),
- github_user_id: 1,
- invite_count: 0,
- },
- )
- .await
- .unwrap()
- .user_id;
- let org = db.create_org("org", "org").await.unwrap();
- let channel = db.create_org_channel(org, "channel").await.unwrap();
- for i in 0..10 {
- db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i)
- .await
- .unwrap();
- }
-
- let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
- assert_eq!(
- messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
- ["5", "6", "7", "8", "9"]
- );
-
- let prev_messages = db
- .get_channel_messages(channel, 4, Some(messages[0].id))
- .await
- .unwrap();
- assert_eq!(
- prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
- ["1", "2", "3", "4"]
- );
- }
-}
-
-#[tokio::test(flavor = "multi_thread")]
-async fn test_channel_message_nonces() {
- for test_db in [
- TestDb::postgres().await,
- TestDb::fake(build_background_executor()),
- ] {
- let db = test_db.db();
- let user = db
- .create_user(
- "user@example.com",
- false,
- NewUserParams {
- github_login: "user".into(),
- github_user_id: 1,
- invite_count: 0,
- },
- )
- .await
- .unwrap()
- .user_id;
- let org = db.create_org("org", "org").await.unwrap();
- let channel = db.create_org_channel(org, "channel").await.unwrap();
-
- let msg1_id = db
- .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
- .await
- .unwrap();
- let msg2_id = db
- .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
- .await
- .unwrap();
- let msg3_id = db
- .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
- .await
- .unwrap();
- let msg4_id = db
- .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
- .await
- .unwrap();
- assert_ne!(msg1_id, msg2_id);
- assert_eq!(msg1_id, msg3_id);
- assert_eq!(msg2_id, msg4_id);
- }
-}
+ // User declines a contact request. Both users see that it is gone.
+ db.send_contact_request(user_2, user_3).await.unwrap();
+ db.respond_to_contact_request(user_3, user_2, false)
+ .await
+ .unwrap();
+ assert!(!db.has_contact(user_2, user_3).await.unwrap());
+ assert!(!db.has_contact(user_3, user_2).await.unwrap());
+ assert_eq!(
+ db.get_contacts(user_2).await.unwrap(),
+ &[Contact::Accepted {
+ user_id: user_1,
+ should_notify: false
+ }]
+ );
+ assert_eq!(
+ db.get_contacts(user_3).await.unwrap(),
+ &[Contact::Accepted {
+ user_id: user_1,
+ should_notify: false
+ }],
+ );
+});
-#[tokio::test(flavor = "multi_thread")]
-async fn test_create_access_tokens() {
- let test_db = TestDb::postgres().await;
- let db = test_db.db();
- let user = db
+test_both_dbs!(test_metrics_id_postgres, test_metrics_id_sqlite, db, {
+ let NewUserResult {
+ user_id: user1,
+ metrics_id: metrics_id1,
+ ..
+ } = db
.create_user(
- "u1@example.com",
+ "person1@example.com",
false,
NewUserParams {
- github_login: "u1".into(),
- github_user_id: 1,
- invite_count: 0,
+ github_login: "person1".into(),
+ github_user_id: 101,
+ invite_count: 5,
},
)
.await
- .unwrap()
- .user_id;
-
- db.create_access_token_hash(user, "h1", 3).await.unwrap();
- db.create_access_token_hash(user, "h2", 3).await.unwrap();
- assert_eq!(
- db.get_access_token_hashes(user).await.unwrap(),
- &["h2".to_string(), "h1".to_string()]
- );
-
- db.create_access_token_hash(user, "h3", 3).await.unwrap();
- assert_eq!(
- db.get_access_token_hashes(user).await.unwrap(),
- &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
- );
-
- db.create_access_token_hash(user, "h4", 3).await.unwrap();
- assert_eq!(
- db.get_access_token_hashes(user).await.unwrap(),
- &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
- );
+ .unwrap();
+ let NewUserResult {
+ user_id: user2,
+ metrics_id: metrics_id2,
+ ..
+ } = db
+ .create_user(
+ "person2@example.com",
+ false,
+ NewUserParams {
+ github_login: "person2".into(),
+ github_user_id: 102,
+ invite_count: 5,
+ },
+ )
+ .await
+ .unwrap();
- db.create_access_token_hash(user, "h5", 3).await.unwrap();
- assert_eq!(
- db.get_access_token_hashes(user).await.unwrap(),
- &["h5".to_string(), "h4".to_string(), "h3".to_string()]
- );
-}
+ assert_eq!(db.get_user_metrics_id(user1).await.unwrap(), metrics_id1);
+ assert_eq!(db.get_user_metrics_id(user2).await.unwrap(), metrics_id2);
+ assert_eq!(metrics_id1.len(), 36);
+ assert_eq!(metrics_id2.len(), 36);
+ assert_ne!(metrics_id1, metrics_id2);
+});
#[test]
fn test_fuzzy_like_string() {
- assert_eq!(PostgresDb::fuzzy_like_string("abcd"), "%a%b%c%d%");
- assert_eq!(PostgresDb::fuzzy_like_string("x y"), "%x%y%");
- assert_eq!(PostgresDb::fuzzy_like_string(" z "), "%z%");
+ assert_eq!(DefaultDb::fuzzy_like_string("abcd"), "%a%b%c%d%");
+ assert_eq!(DefaultDb::fuzzy_like_string("x y"), "%x%y%");
+ assert_eq!(DefaultDb::fuzzy_like_string(" z "), "%z%");
}
-#[tokio::test(flavor = "multi_thread")]
+#[gpui::test]
async fn test_fuzzy_search_users() {
- let test_db = TestDb::postgres().await;
+ let test_db = PostgresTestDb::new(build_background_executor());
let db = test_db.db();
for (i, github_login) in [
"California",
@@ -625,7 +433,7 @@ async fn test_fuzzy_search_users() {
&["rhode-island", "colorado", "oregon"],
);
- async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
+ async fn fuzzy_search_user_names(db: &Db<sqlx::Postgres>, query: &str) -> Vec<String> {
db.fuzzy_search_users(query, 10)
.await
.unwrap()
@@ -635,178 +443,11 @@ async fn test_fuzzy_search_users() {
}
}
-#[tokio::test(flavor = "multi_thread")]
-async fn test_add_contacts() {
- for test_db in [
- TestDb::postgres().await,
- TestDb::fake(build_background_executor()),
- ] {
- let db = test_db.db();
-
- let mut user_ids = Vec::new();
- for i in 0..3 {
- user_ids.push(
- db.create_user(
- &format!("user{i}@example.com"),
- false,
- NewUserParams {
- github_login: format!("user{i}"),
- github_user_id: i,
- invite_count: 0,
- },
- )
- .await
- .unwrap()
- .user_id,
- );
- }
-
- let user_1 = user_ids[0];
- let user_2 = user_ids[1];
- let user_3 = user_ids[2];
-
- // User starts with no contacts
- assert_eq!(db.get_contacts(user_1).await.unwrap(), &[]);
-
- // User requests a contact. Both users see the pending request.
- db.send_contact_request(user_1, user_2).await.unwrap();
- assert!(!db.has_contact(user_1, user_2).await.unwrap());
- assert!(!db.has_contact(user_2, user_1).await.unwrap());
- assert_eq!(
- db.get_contacts(user_1).await.unwrap(),
- &[Contact::Outgoing { user_id: user_2 }],
- );
- assert_eq!(
- db.get_contacts(user_2).await.unwrap(),
- &[Contact::Incoming {
- user_id: user_1,
- should_notify: true
- }]
- );
-
- // User 2 dismisses the contact request notification without accepting or rejecting.
- // We shouldn't notify them again.
- db.dismiss_contact_notification(user_1, user_2)
- .await
- .unwrap_err();
- db.dismiss_contact_notification(user_2, user_1)
- .await
- .unwrap();
- assert_eq!(
- db.get_contacts(user_2).await.unwrap(),
- &[Contact::Incoming {
- user_id: user_1,
- should_notify: false
- }]
- );
-
- // User can't accept their own contact request
- db.respond_to_contact_request(user_1, user_2, true)
- .await
- .unwrap_err();
-
- // User accepts a contact request. Both users see the contact.
- db.respond_to_contact_request(user_2, user_1, true)
- .await
- .unwrap();
- assert_eq!(
- db.get_contacts(user_1).await.unwrap(),
- &[Contact::Accepted {
- user_id: user_2,
- should_notify: true
- }],
- );
- assert!(db.has_contact(user_1, user_2).await.unwrap());
- assert!(db.has_contact(user_2, user_1).await.unwrap());
- assert_eq!(
- db.get_contacts(user_2).await.unwrap(),
- &[Contact::Accepted {
- user_id: user_1,
- should_notify: false,
- }]
- );
-
- // Users cannot re-request existing contacts.
- db.send_contact_request(user_1, user_2).await.unwrap_err();
- db.send_contact_request(user_2, user_1).await.unwrap_err();
-
- // Users can't dismiss notifications of them accepting other users' requests.
- db.dismiss_contact_notification(user_2, user_1)
- .await
- .unwrap_err();
- assert_eq!(
- db.get_contacts(user_1).await.unwrap(),
- &[Contact::Accepted {
- user_id: user_2,
- should_notify: true,
- }]
- );
-
- // Users can dismiss notifications of other users accepting their requests.
- db.dismiss_contact_notification(user_1, user_2)
- .await
- .unwrap();
- assert_eq!(
- db.get_contacts(user_1).await.unwrap(),
- &[Contact::Accepted {
- user_id: user_2,
- should_notify: false,
- }]
- );
-
- // Users send each other concurrent contact requests and
- // see that they are immediately accepted.
- db.send_contact_request(user_1, user_3).await.unwrap();
- db.send_contact_request(user_3, user_1).await.unwrap();
- assert_eq!(
- db.get_contacts(user_1).await.unwrap(),
- &[
- Contact::Accepted {
- user_id: user_2,
- should_notify: false,
- },
- Contact::Accepted {
- user_id: user_3,
- should_notify: false
- }
- ]
- );
- assert_eq!(
- db.get_contacts(user_3).await.unwrap(),
- &[Contact::Accepted {
- user_id: user_1,
- should_notify: false
- }],
- );
-
- // User declines a contact request. Both users see that it is gone.
- db.send_contact_request(user_2, user_3).await.unwrap();
- db.respond_to_contact_request(user_3, user_2, false)
- .await
- .unwrap();
- assert!(!db.has_contact(user_2, user_3).await.unwrap());
- assert!(!db.has_contact(user_3, user_2).await.unwrap());
- assert_eq!(
- db.get_contacts(user_2).await.unwrap(),
- &[Contact::Accepted {
- user_id: user_1,
- should_notify: false
- }]
- );
- assert_eq!(
- db.get_contacts(user_3).await.unwrap(),
- &[Contact::Accepted {
- user_id: user_1,
- should_notify: false
- }],
- );
- }
-}
-
-#[tokio::test(flavor = "multi_thread")]
+#[gpui::test]
async fn test_invite_codes() {
- let postgres = TestDb::postgres().await;
- let db = postgres.db();
+ let test_db = PostgresTestDb::new(build_background_executor());
+ let db = test_db.db();
+
let NewUserResult { user_id: user1, .. } = db
.create_user(
"user1@example.com",
@@ -998,10 +639,10 @@ async fn test_invite_codes() {
assert_eq!(invite_count, 1);
}
-#[tokio::test(flavor = "multi_thread")]
+#[gpui::test]
async fn test_signups() {
- let postgres = TestDb::postgres().await;
- let db = postgres.db();
+ let test_db = PostgresTestDb::new(build_background_executor());
+ let db = test_db.db();
// people sign up on the waitlist
for i in 0..8 {
@@ -1144,51 +785,6 @@ async fn test_signups() {
.unwrap_err();
}
-#[tokio::test(flavor = "multi_thread")]
-async fn test_metrics_id() {
- let postgres = TestDb::postgres().await;
- let db = postgres.db();
-
- let NewUserResult {
- user_id: user1,
- metrics_id: metrics_id1,
- ..
- } = db
- .create_user(
- "person1@example.com",
- false,
- NewUserParams {
- github_login: "person1".into(),
- github_user_id: 101,
- invite_count: 5,
- },
- )
- .await
- .unwrap();
- let NewUserResult {
- user_id: user2,
- metrics_id: metrics_id2,
- ..
- } = db
- .create_user(
- "person2@example.com",
- false,
- NewUserParams {
- github_login: "person2".into(),
- github_user_id: 102,
- invite_count: 5,
- },
- )
- .await
- .unwrap();
-
- assert_eq!(db.get_user_metrics_id(user1).await.unwrap(), metrics_id1);
- assert_eq!(db.get_user_metrics_id(user2).await.unwrap(), metrics_id2);
- assert_eq!(metrics_id1.len(), 36);
- assert_eq!(metrics_id2.len(), 36);
- assert_ne!(metrics_id1, metrics_id2);
-}
-
fn build_background_executor() -> Arc<Background> {
Deterministic::new(0).build_background()
}
@@ -1,14 +1,14 @@
use crate::{
- db::{NewUserParams, ProjectId, TestDb, UserId},
- rpc::{Executor, Server, Store},
+ db::{NewUserParams, ProjectId, SqliteTestDb as TestDb, UserId},
+ rpc::{Executor, Server},
AppState,
};
use ::rpc::Peer;
use anyhow::anyhow;
use call::{room, ActiveCall, ParticipantLocation, Room};
use client::{
- self, test::FakeHttpClient, Channel, ChannelDetails, ChannelList, Client, Connection,
- Credentials, EstablishConnectionError, PeerId, User, UserStore, RECEIVE_TIMEOUT,
+ self, test::FakeHttpClient, Client, Connection, Credentials, EstablishConnectionError, PeerId,
+ User, UserStore, RECEIVE_TIMEOUT,
};
use collections::{BTreeMap, HashMap, HashSet};
use editor::{
@@ -16,10 +16,7 @@ use editor::{
ToggleCodeActions, Undo,
};
use fs::{FakeFs, Fs as _, HomeDir, LineEnding};
-use futures::{
- channel::{mpsc, oneshot},
- Future, StreamExt as _,
-};
+use futures::{channel::oneshot, Future, StreamExt as _};
use gpui::{
executor::{self, Deterministic},
geometry::vector::vec2f,
@@ -39,7 +36,6 @@ use project::{
use rand::prelude::*;
use serde_json::json;
use settings::{Formatter, Settings};
-use sqlx::types::time::OffsetDateTime;
use std::{
cell::{Cell, RefCell},
env, mem,
@@ -73,7 +69,10 @@ async fn test_basic_calls(
cx_c: &mut TestAppContext,
) {
deterministic.forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
+
+ let start = std::time::Instant::now();
+
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
let client_c = server.create_client(cx_c, "user_c").await;
@@ -259,6 +258,8 @@ async fn test_basic_calls(
pending: Default::default()
}
);
+
+ eprintln!("finished test {:?}", start.elapsed());
}
#[gpui::test(iterations = 10)]
@@ -271,7 +272,7 @@ async fn test_room_uniqueness(
cx_c: &mut TestAppContext,
) {
deterministic.forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let _client_a2 = server.create_client(cx_a2, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
@@ -376,7 +377,7 @@ async fn test_leaving_room_on_disconnection(
cx_b: &mut TestAppContext,
) {
deterministic.forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
@@ -505,7 +506,7 @@ async fn test_calls_on_multiple_connections(
cx_b2: &mut TestAppContext,
) {
deterministic.forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b1 = server.create_client(cx_b1, "user_b").await;
let client_b2 = server.create_client(cx_b2, "user_b").await;
@@ -654,7 +655,7 @@ async fn test_share_project(
) {
deterministic.forbid_parking();
let (_, window_b) = cx_b.add_window(|_| EmptyView);
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
let client_c = server.create_client(cx_c, "user_c").await;
@@ -791,7 +792,7 @@ async fn test_unshare_project(
cx_c: &mut TestAppContext,
) {
deterministic.forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
let client_c = server.create_client(cx_c, "user_c").await;
@@ -874,7 +875,7 @@ async fn test_host_disconnect(
) {
cx_b.update(editor::init);
deterministic.forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
let client_c = server.create_client(cx_c, "user_c").await;
@@ -979,7 +980,7 @@ async fn test_active_call_events(
cx_b: &mut TestAppContext,
) {
deterministic.forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
client_a.fs.insert_tree("/a", json!({})).await;
@@ -1068,7 +1069,7 @@ async fn test_room_location(
cx_b: &mut TestAppContext,
) {
deterministic.forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
client_a.fs.insert_tree("/a", json!({})).await;
@@ -1234,7 +1235,7 @@ async fn test_propagate_saves_and_fs_changes(
cx_c: &mut TestAppContext,
) {
deterministic.forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
let client_c = server.create_client(cx_c, "user_c").await;
@@ -1409,7 +1410,7 @@ async fn test_git_diff_base_change(
cx_b: &mut TestAppContext,
) {
executor.forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
@@ -1661,7 +1662,7 @@ async fn test_fs_operations(
cx_b: &mut TestAppContext,
) {
executor.forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
@@ -1927,7 +1928,7 @@ async fn test_fs_operations(
#[gpui::test(iterations = 10)]
async fn test_buffer_conflict_after_save(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
cx_a.foreground().forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
@@ -1981,7 +1982,7 @@ async fn test_buffer_conflict_after_save(cx_a: &mut TestAppContext, cx_b: &mut T
#[gpui::test(iterations = 10)]
async fn test_buffer_reloading(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
cx_a.foreground().forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
@@ -2040,7 +2041,7 @@ async fn test_editing_while_guest_opens_buffer(
cx_b: &mut TestAppContext,
) {
cx_a.foreground().forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
@@ -2087,7 +2088,7 @@ async fn test_leaving_worktree_while_opening_buffer(
cx_b: &mut TestAppContext,
) {
cx_a.foreground().forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
@@ -2132,7 +2133,7 @@ async fn test_canceling_buffer_opening(
) {
deterministic.forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
@@ -2183,7 +2184,7 @@ async fn test_leaving_project(
cx_c: &mut TestAppContext,
) {
deterministic.forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
let client_c = server.create_client(cx_c, "user_c").await;
@@ -2316,7 +2317,7 @@ async fn test_collaborating_with_diagnostics(
cx_c: &mut TestAppContext,
) {
deterministic.forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
let client_c = server.create_client(cx_c, "user_c").await;
@@ -2581,7 +2582,7 @@ async fn test_collaborating_with_diagnostics(
#[gpui::test(iterations = 10)]
async fn test_collaborating_with_completion(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
cx_a.foreground().forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
@@ -2755,7 +2756,7 @@ async fn test_collaborating_with_completion(cx_a: &mut TestAppContext, cx_b: &mu
#[gpui::test(iterations = 10)]
async fn test_reloading_buffer_manually(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
cx_a.foreground().forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
@@ -2848,7 +2849,7 @@ async fn test_reloading_buffer_manually(cx_a: &mut TestAppContext, cx_b: &mut Te
async fn test_formatting_buffer(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
use project::FormatTrigger;
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
@@ -2949,7 +2950,7 @@ async fn test_formatting_buffer(cx_a: &mut TestAppContext, cx_b: &mut TestAppCon
#[gpui::test(iterations = 10)]
async fn test_definition(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
cx_a.foreground().forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
@@ -3093,7 +3094,7 @@ async fn test_definition(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
#[gpui::test(iterations = 10)]
async fn test_references(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
cx_a.foreground().forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
@@ -3194,7 +3195,7 @@ async fn test_references(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
#[gpui::test(iterations = 10)]
async fn test_project_search(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
cx_a.foreground().forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
@@ -3273,7 +3274,7 @@ async fn test_project_search(cx_a: &mut TestAppContext, cx_b: &mut TestAppContex
#[gpui::test(iterations = 10)]
async fn test_document_highlights(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
cx_a.foreground().forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
@@ -3375,7 +3376,7 @@ async fn test_document_highlights(cx_a: &mut TestAppContext, cx_b: &mut TestAppC
#[gpui::test(iterations = 10)]
async fn test_lsp_hover(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
cx_a.foreground().forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
@@ -3478,7 +3479,7 @@ async fn test_lsp_hover(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
#[gpui::test(iterations = 10)]
async fn test_project_symbols(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
cx_a.foreground().forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
@@ -3586,7 +3587,7 @@ async fn test_open_buffer_while_getting_definition_pointing_to_it(
mut rng: StdRng,
) {
cx_a.foreground().forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
@@ -3662,7 +3663,7 @@ async fn test_collaborating_with_code_actions(
) {
cx_a.foreground().forbid_parking();
cx_b.update(editor::init);
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
@@ -3873,7 +3874,7 @@ async fn test_collaborating_with_code_actions(
async fn test_collaborating_with_renames(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
cx_a.foreground().forbid_parking();
cx_b.update(editor::init);
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
@@ -4065,7 +4066,7 @@ async fn test_language_server_statuses(
deterministic.forbid_parking();
cx_b.update(editor::init);
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
@@ -4169,415 +4170,6 @@ async fn test_language_server_statuses(
});
}
-#[gpui::test(iterations = 10)]
-async fn test_basic_chat(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
- cx_a.foreground().forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
- let client_a = server.create_client(cx_a, "user_a").await;
- let client_b = server.create_client(cx_b, "user_b").await;
-
- // Create an org that includes these 2 users.
- let db = &server.app_state.db;
- let org_id = db.create_org("Test Org", "test-org").await.unwrap();
- db.add_org_member(org_id, client_a.current_user_id(cx_a), false)
- .await
- .unwrap();
- db.add_org_member(org_id, client_b.current_user_id(cx_b), false)
- .await
- .unwrap();
-
- // Create a channel that includes all the users.
- let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
- db.add_channel_member(channel_id, client_a.current_user_id(cx_a), false)
- .await
- .unwrap();
- db.add_channel_member(channel_id, client_b.current_user_id(cx_b), false)
- .await
- .unwrap();
- db.create_channel_message(
- channel_id,
- client_b.current_user_id(cx_b),
- "hello A, it's B.",
- OffsetDateTime::now_utc(),
- 1,
- )
- .await
- .unwrap();
-
- let channels_a =
- cx_a.add_model(|cx| ChannelList::new(client_a.user_store.clone(), client_a.clone(), cx));
- channels_a
- .condition(cx_a, |list, _| list.available_channels().is_some())
- .await;
- channels_a.read_with(cx_a, |list, _| {
- assert_eq!(
- list.available_channels().unwrap(),
- &[ChannelDetails {
- id: channel_id.to_proto(),
- name: "test-channel".to_string()
- }]
- )
- });
- let channel_a = channels_a.update(cx_a, |this, cx| {
- this.get_channel(channel_id.to_proto(), cx).unwrap()
- });
- channel_a.read_with(cx_a, |channel, _| assert!(channel.messages().is_empty()));
- channel_a
- .condition(cx_a, |channel, _| {
- channel_messages(channel)
- == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
- })
- .await;
-
- let channels_b =
- cx_b.add_model(|cx| ChannelList::new(client_b.user_store.clone(), client_b.clone(), cx));
- channels_b
- .condition(cx_b, |list, _| list.available_channels().is_some())
- .await;
- channels_b.read_with(cx_b, |list, _| {
- assert_eq!(
- list.available_channels().unwrap(),
- &[ChannelDetails {
- id: channel_id.to_proto(),
- name: "test-channel".to_string()
- }]
- )
- });
-
- let channel_b = channels_b.update(cx_b, |this, cx| {
- this.get_channel(channel_id.to_proto(), cx).unwrap()
- });
- channel_b.read_with(cx_b, |channel, _| assert!(channel.messages().is_empty()));
- channel_b
- .condition(cx_b, |channel, _| {
- channel_messages(channel)
- == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
- })
- .await;
-
- channel_a
- .update(cx_a, |channel, cx| {
- channel
- .send_message("oh, hi B.".to_string(), cx)
- .unwrap()
- .detach();
- let task = channel.send_message("sup".to_string(), cx).unwrap();
- assert_eq!(
- channel_messages(channel),
- &[
- ("user_b".to_string(), "hello A, it's B.".to_string(), false),
- ("user_a".to_string(), "oh, hi B.".to_string(), true),
- ("user_a".to_string(), "sup".to_string(), true)
- ]
- );
- task
- })
- .await
- .unwrap();
-
- channel_b
- .condition(cx_b, |channel, _| {
- channel_messages(channel)
- == [
- ("user_b".to_string(), "hello A, it's B.".to_string(), false),
- ("user_a".to_string(), "oh, hi B.".to_string(), false),
- ("user_a".to_string(), "sup".to_string(), false),
- ]
- })
- .await;
-
- assert_eq!(
- server
- .store()
- .await
- .channel(channel_id)
- .unwrap()
- .connection_ids
- .len(),
- 2
- );
- cx_b.update(|_| drop(channel_b));
- server
- .condition(|state| state.channel(channel_id).unwrap().connection_ids.len() == 1)
- .await;
-
- cx_a.update(|_| drop(channel_a));
- server
- .condition(|state| state.channel(channel_id).is_none())
- .await;
-}
-
-#[gpui::test(iterations = 10)]
-async fn test_chat_message_validation(cx_a: &mut TestAppContext) {
- cx_a.foreground().forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
- let client_a = server.create_client(cx_a, "user_a").await;
-
- let db = &server.app_state.db;
- let org_id = db.create_org("Test Org", "test-org").await.unwrap();
- let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
- db.add_org_member(org_id, client_a.current_user_id(cx_a), false)
- .await
- .unwrap();
- db.add_channel_member(channel_id, client_a.current_user_id(cx_a), false)
- .await
- .unwrap();
-
- let channels_a =
- cx_a.add_model(|cx| ChannelList::new(client_a.user_store.clone(), client_a.clone(), cx));
- channels_a
- .condition(cx_a, |list, _| list.available_channels().is_some())
- .await;
- let channel_a = channels_a.update(cx_a, |this, cx| {
- this.get_channel(channel_id.to_proto(), cx).unwrap()
- });
-
- // Messages aren't allowed to be too long.
- channel_a
- .update(cx_a, |channel, cx| {
- let long_body = "this is long.\n".repeat(1024);
- channel.send_message(long_body, cx).unwrap()
- })
- .await
- .unwrap_err();
-
- // Messages aren't allowed to be blank.
- channel_a.update(cx_a, |channel, cx| {
- channel.send_message(String::new(), cx).unwrap_err()
- });
-
- // Leading and trailing whitespace are trimmed.
- channel_a
- .update(cx_a, |channel, cx| {
- channel
- .send_message("\n surrounded by whitespace \n".to_string(), cx)
- .unwrap()
- })
- .await
- .unwrap();
- assert_eq!(
- db.get_channel_messages(channel_id, 10, None)
- .await
- .unwrap()
- .iter()
- .map(|m| &m.body)
- .collect::<Vec<_>>(),
- &["surrounded by whitespace"]
- );
-}
-
-#[gpui::test(iterations = 10)]
-async fn test_chat_reconnection(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
- cx_a.foreground().forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
- let client_a = server.create_client(cx_a, "user_a").await;
- let client_b = server.create_client(cx_b, "user_b").await;
-
- let mut status_b = client_b.status();
-
- // Create an org that includes these 2 users.
- let db = &server.app_state.db;
- let org_id = db.create_org("Test Org", "test-org").await.unwrap();
- db.add_org_member(org_id, client_a.current_user_id(cx_a), false)
- .await
- .unwrap();
- db.add_org_member(org_id, client_b.current_user_id(cx_b), false)
- .await
- .unwrap();
-
- // Create a channel that includes all the users.
- let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
- db.add_channel_member(channel_id, client_a.current_user_id(cx_a), false)
- .await
- .unwrap();
- db.add_channel_member(channel_id, client_b.current_user_id(cx_b), false)
- .await
- .unwrap();
- db.create_channel_message(
- channel_id,
- client_b.current_user_id(cx_b),
- "hello A, it's B.",
- OffsetDateTime::now_utc(),
- 2,
- )
- .await
- .unwrap();
-
- let channels_a =
- cx_a.add_model(|cx| ChannelList::new(client_a.user_store.clone(), client_a.clone(), cx));
- channels_a
- .condition(cx_a, |list, _| list.available_channels().is_some())
- .await;
-
- channels_a.read_with(cx_a, |list, _| {
- assert_eq!(
- list.available_channels().unwrap(),
- &[ChannelDetails {
- id: channel_id.to_proto(),
- name: "test-channel".to_string()
- }]
- )
- });
- let channel_a = channels_a.update(cx_a, |this, cx| {
- this.get_channel(channel_id.to_proto(), cx).unwrap()
- });
- channel_a.read_with(cx_a, |channel, _| assert!(channel.messages().is_empty()));
- channel_a
- .condition(cx_a, |channel, _| {
- channel_messages(channel)
- == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
- })
- .await;
-
- let channels_b =
- cx_b.add_model(|cx| ChannelList::new(client_b.user_store.clone(), client_b.clone(), cx));
- channels_b
- .condition(cx_b, |list, _| list.available_channels().is_some())
- .await;
- channels_b.read_with(cx_b, |list, _| {
- assert_eq!(
- list.available_channels().unwrap(),
- &[ChannelDetails {
- id: channel_id.to_proto(),
- name: "test-channel".to_string()
- }]
- )
- });
-
- let channel_b = channels_b.update(cx_b, |this, cx| {
- this.get_channel(channel_id.to_proto(), cx).unwrap()
- });
- channel_b.read_with(cx_b, |channel, _| assert!(channel.messages().is_empty()));
- channel_b
- .condition(cx_b, |channel, _| {
- channel_messages(channel)
- == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
- })
- .await;
-
- // Disconnect client B, ensuring we can still access its cached channel data.
- server.forbid_connections();
- server.disconnect_client(client_b.peer_id().unwrap());
- cx_b.foreground().advance_clock(rpc::RECEIVE_TIMEOUT);
- while !matches!(
- status_b.next().await,
- Some(client::Status::ReconnectionError { .. })
- ) {}
-
- channels_b.read_with(cx_b, |channels, _| {
- assert_eq!(
- channels.available_channels().unwrap(),
- [ChannelDetails {
- id: channel_id.to_proto(),
- name: "test-channel".to_string()
- }]
- )
- });
- channel_b.read_with(cx_b, |channel, _| {
- assert_eq!(
- channel_messages(channel),
- [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
- )
- });
-
- // Send a message from client B while it is disconnected.
- channel_b
- .update(cx_b, |channel, cx| {
- let task = channel
- .send_message("can you see this?".to_string(), cx)
- .unwrap();
- assert_eq!(
- channel_messages(channel),
- &[
- ("user_b".to_string(), "hello A, it's B.".to_string(), false),
- ("user_b".to_string(), "can you see this?".to_string(), true)
- ]
- );
- task
- })
- .await
- .unwrap_err();
-
- // Send a message from client A while B is disconnected.
- channel_a
- .update(cx_a, |channel, cx| {
- channel
- .send_message("oh, hi B.".to_string(), cx)
- .unwrap()
- .detach();
- let task = channel.send_message("sup".to_string(), cx).unwrap();
- assert_eq!(
- channel_messages(channel),
- &[
- ("user_b".to_string(), "hello A, it's B.".to_string(), false),
- ("user_a".to_string(), "oh, hi B.".to_string(), true),
- ("user_a".to_string(), "sup".to_string(), true)
- ]
- );
- task
- })
- .await
- .unwrap();
-
- // Give client B a chance to reconnect.
- server.allow_connections();
- cx_b.foreground().advance_clock(Duration::from_secs(10));
-
- // Verify that B sees the new messages upon reconnection, as well as the message client B
- // sent while offline.
- channel_b
- .condition(cx_b, |channel, _| {
- channel_messages(channel)
- == [
- ("user_b".to_string(), "hello A, it's B.".to_string(), false),
- ("user_a".to_string(), "oh, hi B.".to_string(), false),
- ("user_a".to_string(), "sup".to_string(), false),
- ("user_b".to_string(), "can you see this?".to_string(), false),
- ]
- })
- .await;
-
- // Ensure client A and B can communicate normally after reconnection.
- channel_a
- .update(cx_a, |channel, cx| {
- channel.send_message("you online?".to_string(), cx).unwrap()
- })
- .await
- .unwrap();
- channel_b
- .condition(cx_b, |channel, _| {
- channel_messages(channel)
- == [
- ("user_b".to_string(), "hello A, it's B.".to_string(), false),
- ("user_a".to_string(), "oh, hi B.".to_string(), false),
- ("user_a".to_string(), "sup".to_string(), false),
- ("user_b".to_string(), "can you see this?".to_string(), false),
- ("user_a".to_string(), "you online?".to_string(), false),
- ]
- })
- .await;
-
- channel_b
- .update(cx_b, |channel, cx| {
- channel.send_message("yep".to_string(), cx).unwrap()
- })
- .await
- .unwrap();
- channel_a
- .condition(cx_a, |channel, _| {
- channel_messages(channel)
- == [
- ("user_b".to_string(), "hello A, it's B.".to_string(), false),
- ("user_a".to_string(), "oh, hi B.".to_string(), false),
- ("user_a".to_string(), "sup".to_string(), false),
- ("user_b".to_string(), "can you see this?".to_string(), false),
- ("user_a".to_string(), "you online?".to_string(), false),
- ("user_b".to_string(), "yep".to_string(), false),
- ]
- })
- .await;
-}
-
#[gpui::test(iterations = 10)]
async fn test_contacts(
deterministic: Arc<Deterministic>,
@@ -4586,7 +4178,7 @@ async fn test_contacts(
cx_c: &mut TestAppContext,
) {
cx_a.foreground().forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
let client_c = server.create_client(cx_c, "user_c").await;
@@ -4912,7 +4504,7 @@ async fn test_contact_requests(
cx_a.foreground().forbid_parking();
// Connect to a server as 3 clients.
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_a2 = server.create_client(cx_a2, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
@@ -5093,7 +4685,7 @@ async fn test_following(
cx_a.update(editor::init);
cx_b.update(editor::init);
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
@@ -5367,7 +4959,7 @@ async fn test_peers_following_each_other(cx_a: &mut TestAppContext, cx_b: &mut T
cx_a.update(editor::init);
cx_b.update(editor::init);
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
@@ -5545,7 +5137,7 @@ async fn test_auto_unfollowing(cx_a: &mut TestAppContext, cx_b: &mut TestAppCont
cx_b.update(editor::init);
// 2 clients connect to a server.
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
@@ -5719,7 +5311,7 @@ async fn test_peers_simultaneously_following_each_other(
cx_a.update(editor::init);
cx_b.update(editor::init);
- let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let mut server = TestServer::start(cx_a.background()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
@@ -5789,7 +5381,7 @@ async fn test_random_collaboration(
.map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
.unwrap_or(10);
- let mut server = TestServer::start(cx.foreground(), cx.background()).await;
+ let mut server = TestServer::start(cx.background()).await;
let db = server.app_state.db.clone();
let mut available_guests = Vec::new();
@@ -6076,8 +5668,6 @@ struct TestServer {
peer: Arc<Peer>,
app_state: Arc<AppState>,
server: Arc<Server>,
- foreground: Rc<executor::Foreground>,
- notifications: mpsc::UnboundedReceiver<()>,
connection_killers: Arc<Mutex<HashMap<PeerId, Arc<AtomicBool>>>>,
forbid_connections: Arc<AtomicBool>,
_test_db: TestDb,
@@ -6085,13 +5675,10 @@ struct TestServer {
}
impl TestServer {
- async fn start(
- foreground: Rc<executor::Foreground>,
- background: Arc<executor::Background>,
- ) -> Self {
+ async fn start(background: Arc<executor::Background>) -> Self {
static NEXT_LIVE_KIT_SERVER_ID: AtomicUsize = AtomicUsize::new(0);
- let test_db = TestDb::fake(background.clone());
+ let test_db = TestDb::new(background.clone());
let live_kit_server_id = NEXT_LIVE_KIT_SERVER_ID.fetch_add(1, SeqCst);
let live_kit_server = live_kit_client::TestServer::create(
format!("http://livekit.{}.test", live_kit_server_id),
@@ -6102,14 +5689,11 @@ impl TestServer {
.unwrap();
let app_state = Self::build_app_state(&test_db, &live_kit_server).await;
let peer = Peer::new();
- let notifications = mpsc::unbounded();
- let server = Server::new(app_state.clone(), Some(notifications.0));
+ let server = Server::new(app_state.clone());
Self {
peer,
app_state,
server,
- foreground,
- notifications: notifications.1,
connection_killers: Default::default(),
forbid_connections: Default::default(),
_test_db: test_db,
@@ -6147,7 +5731,7 @@ impl TestServer {
},
)
.await
- .unwrap()
+ .expect("creating user failed")
.user_id
};
let client_name = name.to_string();
@@ -6187,7 +5771,11 @@ impl TestServer {
let (client_conn, server_conn, killed) =
Connection::in_memory(cx.background());
let (connection_id_tx, connection_id_rx) = oneshot::channel();
- let user = db.get_user_by_id(user_id).await.unwrap().unwrap();
+ let user = db
+ .get_user_by_id(user_id)
+ .await
+ .expect("retrieving user failed")
+ .unwrap();
cx.background()
.spawn(server.handle_connection(
server_conn,
@@ -6221,7 +5809,6 @@ impl TestServer {
default_item_factory: |_, _| unimplemented!(),
});
- Channel::init(&client);
Project::init(&client);
cx.update(|cx| {
workspace::init(app_state.clone(), cx);
@@ -6322,21 +5909,6 @@ impl TestServer {
config: Default::default(),
})
}
-
- async fn condition<F>(&mut self, mut predicate: F)
- where
- F: FnMut(&Store) -> bool,
- {
- assert!(
- self.foreground.parking_forbidden(),
- "you must call forbid_parking to use server conditions so we don't block indefinitely"
- );
- while !(predicate)(&*self.server.store.lock().await) {
- self.foreground.start_waiting();
- self.notifications.next().await;
- self.foreground.finish_waiting();
- }
- }
}
impl Deref for TestServer {
@@ -7052,20 +6624,6 @@ impl Executor for Arc<gpui::executor::Background> {
}
}
-fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> {
- channel
- .messages()
- .cursor::<()>()
- .map(|m| {
- (
- m.sender.github_login.clone(),
- m.body.clone(),
- m.is_pending(),
- )
- })
- .collect()
-}
-
#[derive(Debug, Eq, PartialEq)]
struct RoomParticipants {
remote: Vec<String>,
@@ -13,12 +13,12 @@ use crate::rpc::ResultExt as _;
use anyhow::anyhow;
use axum::{routing::get, Router};
use collab::{Error, Result};
-use db::{Db, PostgresDb};
+use db::DefaultDb as Db;
use serde::Deserialize;
use std::{
env::args,
net::{SocketAddr, TcpListener},
- path::PathBuf,
+ path::{Path, PathBuf},
sync::Arc,
time::Duration,
};
@@ -49,14 +49,14 @@ pub struct MigrateConfig {
}
pub struct AppState {
- db: Arc<dyn Db>,
+ db: Arc<Db>,
live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
config: Config,
}
impl AppState {
async fn new(config: Config) -> Result<Arc<Self>> {
- let db = PostgresDb::new(&config.database_url, 5).await?;
+ let db = Db::new(&config.database_url, 5).await?;
let live_kit_client = if let Some(((server, key), secret)) = config
.live_kit_server
.as_ref()
@@ -96,13 +96,12 @@ async fn main() -> Result<()> {
}
Some("migrate") => {
let config = envy::from_env::<MigrateConfig>().expect("error loading config");
- let db = PostgresDb::new(&config.database_url, 5).await?;
+ let db = Db::new(&config.database_url, 5).await?;
let migrations_path = config
.migrations_path
.as_deref()
- .or(db::DEFAULT_MIGRATIONS_PATH.map(|s| s.as_ref()))
- .ok_or_else(|| anyhow!("missing MIGRATIONS_PATH environment variable"))?;
+ .unwrap_or_else(|| Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")));
let migrations = db.migrate(&migrations_path, false).await?;
for (migration, duration) in migrations {
@@ -122,9 +121,7 @@ async fn main() -> Result<()> {
let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
.expect("failed to bind TCP listener");
- let rpc_server = rpc::Server::new(state.clone(), None);
- rpc_server
- .start_recording_project_activity(Duration::from_secs(5 * 60), rpc::RealExecutor);
+ let rpc_server = rpc::Server::new(state.clone());
let app = api::routes(rpc_server.clone(), state.clone())
.merge(rpc::routes(rpc_server.clone()))
@@ -2,7 +2,7 @@ mod store;
use crate::{
auth,
- db::{self, ChannelId, MessageId, ProjectId, User, UserId},
+ db::{self, ProjectId, User, UserId},
AppState, Result,
};
use anyhow::anyhow;
@@ -24,7 +24,7 @@ use axum::{
};
use collections::{HashMap, HashSet};
use futures::{
- channel::{mpsc, oneshot},
+ channel::oneshot,
future::{self, BoxFuture},
stream::FuturesUnordered,
FutureExt, SinkExt, StreamExt, TryStreamExt,
@@ -51,7 +51,6 @@ use std::{
time::Duration,
};
pub use store::{Store, Worktree};
-use time::OffsetDateTime;
use tokio::{
sync::{Mutex, MutexGuard},
time::Sleep,
@@ -62,10 +61,6 @@ use tracing::{info_span, instrument, Instrument};
lazy_static! {
static ref METRIC_CONNECTIONS: IntGauge =
register_int_gauge!("connections", "number of connections").unwrap();
- static ref METRIC_REGISTERED_PROJECTS: IntGauge =
- register_int_gauge!("registered_projects", "number of registered projects").unwrap();
- static ref METRIC_ACTIVE_PROJECTS: IntGauge =
- register_int_gauge!("active_projects", "number of active projects").unwrap();
static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
"shared_projects",
"number of open projects with one or more guests"
@@ -95,7 +90,6 @@ pub struct Server {
pub(crate) store: Mutex<Store>,
app_state: Arc<AppState>,
handlers: HashMap<TypeId, MessageHandler>,
- notifications: Option<mpsc::UnboundedSender<()>>,
}
pub trait Executor: Send + Clone {
@@ -107,9 +101,6 @@ pub trait Executor: Send + Clone {
#[derive(Clone)]
pub struct RealExecutor;
-const MESSAGE_COUNT_PER_PAGE: usize = 100;
-const MAX_MESSAGE_LEN: usize = 1024;
-
pub(crate) struct StoreGuard<'a> {
guard: MutexGuard<'a, Store>,
_not_send: PhantomData<Rc<()>>,
@@ -132,16 +123,12 @@ where
}
impl Server {
- pub fn new(
- app_state: Arc<AppState>,
- notifications: Option<mpsc::UnboundedSender<()>>,
- ) -> Arc<Self> {
+ pub fn new(app_state: Arc<AppState>) -> Arc<Self> {
let mut server = Self {
peer: Peer::new(),
app_state,
store: Default::default(),
handlers: Default::default(),
- notifications,
};
server
@@ -158,9 +145,7 @@ impl Server {
.add_request_handler(Server::join_project)
.add_message_handler(Server::leave_project)
.add_message_handler(Server::update_project)
- .add_message_handler(Server::register_project_activity)
.add_request_handler(Server::update_worktree)
- .add_message_handler(Server::update_worktree_extensions)
.add_message_handler(Server::start_language_server)
.add_message_handler(Server::update_language_server)
.add_message_handler(Server::update_diagnostic_summary)
@@ -194,19 +179,14 @@ impl Server {
.add_message_handler(Server::buffer_reloaded)
.add_message_handler(Server::buffer_saved)
.add_request_handler(Server::save_buffer)
- .add_request_handler(Server::get_channels)
.add_request_handler(Server::get_users)
.add_request_handler(Server::fuzzy_search_users)
.add_request_handler(Server::request_contact)
.add_request_handler(Server::remove_contact)
.add_request_handler(Server::respond_to_contact_request)
- .add_request_handler(Server::join_channel)
- .add_message_handler(Server::leave_channel)
- .add_request_handler(Server::send_channel_message)
.add_request_handler(Server::follow)
.add_message_handler(Server::unfollow)
.add_message_handler(Server::update_followers)
- .add_request_handler(Server::get_channel_messages)
.add_message_handler(Server::update_diff_base)
.add_request_handler(Server::get_private_user_info);
@@ -290,58 +270,6 @@ impl Server {
})
}
- /// Start a long lived task that records which users are active in which projects.
- pub fn start_recording_project_activity<E: 'static + Executor>(
- self: &Arc<Self>,
- interval: Duration,
- executor: E,
- ) {
- executor.spawn_detached({
- let this = Arc::downgrade(self);
- let executor = executor.clone();
- async move {
- let mut period_start = OffsetDateTime::now_utc();
- let mut active_projects = Vec::<(UserId, ProjectId)>::new();
- loop {
- let sleep = executor.sleep(interval);
- sleep.await;
- let this = if let Some(this) = this.upgrade() {
- this
- } else {
- break;
- };
-
- active_projects.clear();
- active_projects.extend(this.store().await.projects().flat_map(
- |(project_id, project)| {
- project.guests.values().chain([&project.host]).filter_map(
- |collaborator| {
- if !collaborator.admin
- && collaborator
- .last_activity
- .map_or(false, |activity| activity > period_start)
- {
- Some((collaborator.user_id, *project_id))
- } else {
- None
- }
- },
- )
- },
- ));
-
- let period_end = OffsetDateTime::now_utc();
- this.app_state
- .db
- .record_user_activity(period_start..period_end, &active_projects)
- .await
- .trace_err();
- period_start = period_end;
- }
- }
- });
- }
-
pub fn handle_connection<E: Executor>(
self: &Arc<Self>,
connection: Connection,
@@ -432,18 +360,11 @@ impl Server {
let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
let span_enter = span.enter();
if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
- let notifications = this.notifications.clone();
let is_background = message.is_background();
let handle_message = (handler)(this.clone(), message);
-
drop(span_enter);
- let handle_message = async move {
- handle_message.await;
- if let Some(mut notifications) = notifications {
- let _ = notifications.send(()).await;
- }
- }.instrument(span);
+ let handle_message = handle_message.instrument(span);
if is_background {
executor.spawn_detached(handle_message);
} else {
@@ -1172,17 +1093,6 @@ impl Server {
Ok(())
}
- async fn register_project_activity(
- self: Arc<Server>,
- request: TypedEnvelope<proto::RegisterProjectActivity>,
- ) -> Result<()> {
- self.store().await.register_project_activity(
- ProjectId::from_proto(request.payload.project_id),
- request.sender_id,
- )?;
- Ok(())
- }
-
async fn update_worktree(
self: Arc<Server>,
request: TypedEnvelope<proto::UpdateWorktree>,
@@ -1209,25 +1119,6 @@ impl Server {
Ok(())
}
- async fn update_worktree_extensions(
- self: Arc<Server>,
- request: TypedEnvelope<proto::UpdateWorktreeExtensions>,
- ) -> Result<()> {
- let project_id = ProjectId::from_proto(request.payload.project_id);
- let worktree_id = request.payload.worktree_id;
- let extensions = request
- .payload
- .extensions
- .into_iter()
- .zip(request.payload.counts)
- .collect();
- self.app_state
- .db
- .update_worktree_extensions(project_id, worktree_id, extensions)
- .await?;
- Ok(())
- }
-
async fn update_diagnostic_summary(
self: Arc<Server>,
request: TypedEnvelope<proto::UpdateDiagnosticSummary>,
@@ -1363,8 +1254,7 @@ impl Server {
) -> Result<()> {
let project_id = ProjectId::from_proto(request.payload.project_id);
let receiver_ids = {
- let mut store = self.store().await;
- store.register_project_activity(project_id, request.sender_id)?;
+ let store = self.store().await;
store.project_connection_ids(project_id, request.sender_id)?
};
@@ -1430,15 +1320,13 @@ impl Server {
let leader_id = ConnectionId(request.payload.leader_id);
let follower_id = request.sender_id;
{
- let mut store = self.store().await;
+ let store = self.store().await;
if !store
.project_connection_ids(project_id, follower_id)?
.contains(&leader_id)
{
Err(anyhow!("no such peer"))?;
}
-
- store.register_project_activity(project_id, follower_id)?;
}
let mut response_payload = self
@@ -1455,14 +1343,13 @@ impl Server {
async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> {
let project_id = ProjectId::from_proto(request.payload.project_id);
let leader_id = ConnectionId(request.payload.leader_id);
- let mut store = self.store().await;
+ let store = self.store().await;
if !store
.project_connection_ids(project_id, request.sender_id)?
.contains(&leader_id)
{
Err(anyhow!("no such peer"))?;
}
- store.register_project_activity(project_id, request.sender_id)?;
self.peer
.forward_send(request.sender_id, leader_id, request.payload)?;
Ok(())
@@ -1473,8 +1360,7 @@ impl Server {
request: TypedEnvelope<proto::UpdateFollowers>,
) -> Result<()> {
let project_id = ProjectId::from_proto(request.payload.project_id);
- let mut store = self.store().await;
- store.register_project_activity(project_id, request.sender_id)?;
+ let store = self.store().await;
let connection_ids = store.project_connection_ids(project_id, request.sender_id)?;
let leader_id = request
.payload
@@ -1495,28 +1381,6 @@ impl Server {
Ok(())
}
- async fn get_channels(
- self: Arc<Server>,
- request: TypedEnvelope<proto::GetChannels>,
- response: Response<proto::GetChannels>,
- ) -> Result<()> {
- let user_id = self
- .store()
- .await
- .user_id_for_connection(request.sender_id)?;
- let channels = self.app_state.db.get_accessible_channels(user_id).await?;
- response.send(proto::GetChannelsResponse {
- channels: channels
- .into_iter()
- .map(|chan| proto::Channel {
- id: chan.id.to_proto(),
- name: chan.name,
- })
- .collect(),
- })?;
- Ok(())
- }
-
async fn get_users(
self: Arc<Server>,
request: TypedEnvelope<proto::GetUsers>,
@@ -1712,175 +1576,6 @@ impl Server {
Ok(())
}
- async fn join_channel(
- self: Arc<Self>,
- request: TypedEnvelope<proto::JoinChannel>,
- response: Response<proto::JoinChannel>,
- ) -> Result<()> {
- let user_id = self
- .store()
- .await
- .user_id_for_connection(request.sender_id)?;
- let channel_id = ChannelId::from_proto(request.payload.channel_id);
- if !self
- .app_state
- .db
- .can_user_access_channel(user_id, channel_id)
- .await?
- {
- Err(anyhow!("access denied"))?;
- }
-
- self.store()
- .await
- .join_channel(request.sender_id, channel_id);
- let messages = self
- .app_state
- .db
- .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
- .await?
- .into_iter()
- .map(|msg| proto::ChannelMessage {
- id: msg.id.to_proto(),
- body: msg.body,
- timestamp: msg.sent_at.unix_timestamp() as u64,
- sender_id: msg.sender_id.to_proto(),
- nonce: Some(msg.nonce.as_u128().into()),
- })
- .collect::<Vec<_>>();
- response.send(proto::JoinChannelResponse {
- done: messages.len() < MESSAGE_COUNT_PER_PAGE,
- messages,
- })?;
- Ok(())
- }
-
- async fn leave_channel(
- self: Arc<Self>,
- request: TypedEnvelope<proto::LeaveChannel>,
- ) -> Result<()> {
- let user_id = self
- .store()
- .await
- .user_id_for_connection(request.sender_id)?;
- let channel_id = ChannelId::from_proto(request.payload.channel_id);
- if !self
- .app_state
- .db
- .can_user_access_channel(user_id, channel_id)
- .await?
- {
- Err(anyhow!("access denied"))?;
- }
-
- self.store()
- .await
- .leave_channel(request.sender_id, channel_id);
-
- Ok(())
- }
-
- async fn send_channel_message(
- self: Arc<Self>,
- request: TypedEnvelope<proto::SendChannelMessage>,
- response: Response<proto::SendChannelMessage>,
- ) -> Result<()> {
- let channel_id = ChannelId::from_proto(request.payload.channel_id);
- let user_id;
- let connection_ids;
- {
- let state = self.store().await;
- user_id = state.user_id_for_connection(request.sender_id)?;
- connection_ids = state.channel_connection_ids(channel_id)?;
- }
-
- // Validate the message body.
- let body = request.payload.body.trim().to_string();
- if body.len() > MAX_MESSAGE_LEN {
- return Err(anyhow!("message is too long"))?;
- }
- if body.is_empty() {
- return Err(anyhow!("message can't be blank"))?;
- }
-
- let timestamp = OffsetDateTime::now_utc();
- let nonce = request
- .payload
- .nonce
- .ok_or_else(|| anyhow!("nonce can't be blank"))?;
-
- let message_id = self
- .app_state
- .db
- .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
- .await?
- .to_proto();
- let message = proto::ChannelMessage {
- sender_id: user_id.to_proto(),
- id: message_id,
- body,
- timestamp: timestamp.unix_timestamp() as u64,
- nonce: Some(nonce),
- };
- broadcast(request.sender_id, connection_ids, |conn_id| {
- self.peer.send(
- conn_id,
- proto::ChannelMessageSent {
- channel_id: channel_id.to_proto(),
- message: Some(message.clone()),
- },
- )
- });
- response.send(proto::SendChannelMessageResponse {
- message: Some(message),
- })?;
- Ok(())
- }
-
- async fn get_channel_messages(
- self: Arc<Self>,
- request: TypedEnvelope<proto::GetChannelMessages>,
- response: Response<proto::GetChannelMessages>,
- ) -> Result<()> {
- let user_id = self
- .store()
- .await
- .user_id_for_connection(request.sender_id)?;
- let channel_id = ChannelId::from_proto(request.payload.channel_id);
- if !self
- .app_state
- .db
- .can_user_access_channel(user_id, channel_id)
- .await?
- {
- Err(anyhow!("access denied"))?;
- }
-
- let messages = self
- .app_state
- .db
- .get_channel_messages(
- channel_id,
- MESSAGE_COUNT_PER_PAGE,
- Some(MessageId::from_proto(request.payload.before_message_id)),
- )
- .await?
- .into_iter()
- .map(|msg| proto::ChannelMessage {
- id: msg.id.to_proto(),
- body: msg.body,
- timestamp: msg.sent_at.unix_timestamp() as u64,
- sender_id: msg.sender_id.to_proto(),
- nonce: Some(msg.nonce.as_u128().into()),
- })
- .collect::<Vec<_>>();
- response.send(proto::GetChannelMessagesResponse {
- done: messages.len() < MESSAGE_COUNT_PER_PAGE,
- messages,
- })?;
- Ok(())
- }
-
async fn update_diff_base(
self: Arc<Server>,
request: TypedEnvelope<proto::UpdateDiffBase>,
@@ -2061,11 +1756,8 @@ pub async fn handle_websocket_request(
}
pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> axum::response::Response {
- // We call `store_mut` here for its side effects of updating metrics.
let metrics = server.store().await.metrics();
METRIC_CONNECTIONS.set(metrics.connections as _);
- METRIC_REGISTERED_PROJECTS.set(metrics.registered_projects as _);
- METRIC_ACTIVE_PROJECTS.set(metrics.active_projects as _);
METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _);
let encoder = prometheus::TextEncoder::new();
@@ -1,11 +1,10 @@
-use crate::db::{self, ChannelId, ProjectId, UserId};
+use crate::db::{self, ProjectId, UserId};
use anyhow::{anyhow, Result};
use collections::{btree_map, BTreeMap, BTreeSet, HashMap, HashSet};
use nanoid::nanoid;
use rpc::{proto, ConnectionId};
use serde::Serialize;
-use std::{borrow::Cow, mem, path::PathBuf, str, time::Duration};
-use time::OffsetDateTime;
+use std::{borrow::Cow, mem, path::PathBuf, str};
use tracing::instrument;
use util::post_inc;
@@ -18,8 +17,6 @@ pub struct Store {
next_room_id: RoomId,
rooms: BTreeMap<RoomId, proto::Room>,
projects: BTreeMap<ProjectId, Project>,
- #[serde(skip)]
- channels: BTreeMap<ChannelId, Channel>,
}
#[derive(Default, Serialize)]
@@ -33,7 +30,6 @@ struct ConnectionState {
user_id: UserId,
admin: bool,
projects: BTreeSet<ProjectId>,
- channels: HashSet<ChannelId>,
}
#[derive(Copy, Clone, Eq, PartialEq, Serialize)]
@@ -60,8 +56,6 @@ pub struct Project {
pub struct Collaborator {
pub replica_id: ReplicaId,
pub user_id: UserId,
- #[serde(skip)]
- pub last_activity: Option<OffsetDateTime>,
pub admin: bool,
}
@@ -78,11 +72,6 @@ pub struct Worktree {
pub is_complete: bool,
}
-#[derive(Default)]
-pub struct Channel {
- pub connection_ids: HashSet<ConnectionId>,
-}
-
pub type ReplicaId = u16;
#[derive(Default)]
@@ -113,38 +102,23 @@ pub struct LeftRoom<'a> {
#[derive(Copy, Clone)]
pub struct Metrics {
pub connections: usize,
- pub registered_projects: usize,
- pub active_projects: usize,
pub shared_projects: usize,
}
impl Store {
pub fn metrics(&self) -> Metrics {
- const ACTIVE_PROJECT_TIMEOUT: Duration = Duration::from_secs(60);
- let active_window_start = OffsetDateTime::now_utc() - ACTIVE_PROJECT_TIMEOUT;
-
let connections = self.connections.values().filter(|c| !c.admin).count();
- let mut registered_projects = 0;
- let mut active_projects = 0;
let mut shared_projects = 0;
for project in self.projects.values() {
if let Some(connection) = self.connections.get(&project.host_connection_id) {
if !connection.admin {
- registered_projects += 1;
- if project.is_active_since(active_window_start) {
- active_projects += 1;
- if !project.guests.is_empty() {
- shared_projects += 1;
- }
- }
+ shared_projects += 1;
}
}
}
Metrics {
connections,
- registered_projects,
- active_projects,
shared_projects,
}
}
@@ -162,7 +136,6 @@ impl Store {
user_id,
admin,
projects: Default::default(),
- channels: Default::default(),
},
);
let connected_user = self.connected_users.entry(user_id).or_default();
@@ -201,18 +174,12 @@ impl Store {
.ok_or_else(|| anyhow!("no such connection"))?;
let user_id = connection.user_id;
- let connection_channels = mem::take(&mut connection.channels);
let mut result = RemovedConnectionState {
user_id,
..Default::default()
};
- // Leave all channels.
- for channel_id in connection_channels {
- self.leave_channel(connection_id, channel_id);
- }
-
let connected_user = self.connected_users.get(&user_id).unwrap();
if let Some(active_call) = connected_user.active_call.as_ref() {
let room_id = active_call.room_id;
@@ -238,34 +205,6 @@ impl Store {
Ok(result)
}
- #[cfg(test)]
- pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
- self.channels.get(&id)
- }
-
- pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
- if let Some(connection) = self.connections.get_mut(&connection_id) {
- connection.channels.insert(channel_id);
- self.channels
- .entry(channel_id)
- .or_default()
- .connection_ids
- .insert(connection_id);
- }
- }
-
- pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
- if let Some(connection) = self.connections.get_mut(&connection_id) {
- connection.channels.remove(&channel_id);
- if let btree_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
- entry.get_mut().connection_ids.remove(&connection_id);
- if entry.get_mut().connection_ids.is_empty() {
- entry.remove();
- }
- }
- }
- }
-
pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> Result<UserId> {
Ok(self
.connections
@@ -760,7 +699,6 @@ impl Store {
host: Collaborator {
user_id: connection.user_id,
replica_id: 0,
- last_activity: None,
admin: connection.admin,
},
guests: Default::default(),
@@ -959,12 +897,10 @@ impl Store {
Collaborator {
replica_id,
user_id: connection.user_id,
- last_activity: Some(OffsetDateTime::now_utc()),
admin: connection.admin,
},
);
- project.host.last_activity = Some(OffsetDateTime::now_utc());
Ok((project, replica_id))
}
@@ -1056,44 +992,12 @@ impl Store {
.connection_ids())
}
- pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Result<Vec<ConnectionId>> {
- Ok(self
- .channels
- .get(&channel_id)
- .ok_or_else(|| anyhow!("no such channel"))?
- .connection_ids())
- }
-
pub fn project(&self, project_id: ProjectId) -> Result<&Project> {
self.projects
.get(&project_id)
.ok_or_else(|| anyhow!("no such project"))
}
- pub fn register_project_activity(
- &mut self,
- project_id: ProjectId,
- connection_id: ConnectionId,
- ) -> Result<()> {
- let project = self
- .projects
- .get_mut(&project_id)
- .ok_or_else(|| anyhow!("no such project"))?;
- let collaborator = if connection_id == project.host_connection_id {
- &mut project.host
- } else if let Some(guest) = project.guests.get_mut(&connection_id) {
- guest
- } else {
- return Err(anyhow!("no such project"))?;
- };
- collaborator.last_activity = Some(OffsetDateTime::now_utc());
- Ok(())
- }
-
- pub fn projects(&self) -> impl Iterator<Item = (&ProjectId, &Project)> {
- self.projects.iter()
- }
-
pub fn read_project(
&self,
project_id: ProjectId,
@@ -1154,10 +1058,7 @@ impl Store {
}
}
}
- for channel_id in &connection.channels {
- let channel = self.channels.get(channel_id).unwrap();
- assert!(channel.connection_ids.contains(connection_id));
- }
+
assert!(self
.connected_users
.get(&connection.user_id)
@@ -1253,28 +1154,10 @@ impl Store {
"project was not shared in room"
);
}
-
- for (channel_id, channel) in &self.channels {
- for connection_id in &channel.connection_ids {
- let connection = self.connections.get(connection_id).unwrap();
- assert!(connection.channels.contains(channel_id));
- }
- }
}
}
impl Project {
- fn is_active_since(&self, start_time: OffsetDateTime) -> bool {
- self.guests
- .values()
- .chain([&self.host])
- .any(|collaborator| {
- collaborator
- .last_activity
- .map_or(false, |active_time| active_time > start_time)
- })
- }
-
pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
self.guests.keys().copied().collect()
}
@@ -1287,9 +1170,3 @@ impl Project {
.collect()
}
}
-
-impl Channel {
- fn connection_ids(&self) -> Vec<ConnectionId> {
- self.connection_ids.iter().copied().collect()
- }
-}
@@ -115,7 +115,6 @@ fn main() {
context_menu::init(cx);
project::Project::init(&client);
- client::Channel::init(&client);
client::init(client.clone(), cx);
command_palette::init(cx);
editor::init(cx);