1use anyhow::{anyhow, Context, Result};
2use async_trait::async_trait;
3use futures::StreamExt;
4use serde::Serialize;
5pub use sqlx::postgres::PgPoolOptions as DbOptions;
6use sqlx::{types::Uuid, FromRow};
7use time::OffsetDateTime;
8
9#[async_trait]
10pub trait Db: Send + Sync {
11 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
12 async fn get_all_users(&self) -> Result<Vec<User>>;
13 async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
14 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
15 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
16 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
17 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
18 async fn destroy_user(&self, id: UserId) -> Result<()>;
19
20 async fn get_contacts(&self, id: UserId) -> Result<Contacts>;
21 async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
22 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
23 async fn dismiss_contact_request(
24 &self,
25 responder_id: UserId,
26 requester_id: UserId,
27 ) -> Result<()>;
28 async fn respond_to_contact_request(
29 &self,
30 responder_id: UserId,
31 requester_id: UserId,
32 accept: bool,
33 ) -> Result<()>;
34
35 async fn create_access_token_hash(
36 &self,
37 user_id: UserId,
38 access_token_hash: &str,
39 max_access_token_count: usize,
40 ) -> Result<()>;
41 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
42 #[cfg(any(test, feature = "seed-support"))]
43
44 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
45 #[cfg(any(test, feature = "seed-support"))]
46 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
47 #[cfg(any(test, feature = "seed-support"))]
48 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
49 #[cfg(any(test, feature = "seed-support"))]
50 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
51 #[cfg(any(test, feature = "seed-support"))]
52
53 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
54 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
55 async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
56 -> Result<bool>;
57 #[cfg(any(test, feature = "seed-support"))]
58 async fn add_channel_member(
59 &self,
60 channel_id: ChannelId,
61 user_id: UserId,
62 is_admin: bool,
63 ) -> Result<()>;
64 async fn create_channel_message(
65 &self,
66 channel_id: ChannelId,
67 sender_id: UserId,
68 body: &str,
69 timestamp: OffsetDateTime,
70 nonce: u128,
71 ) -> Result<MessageId>;
72 async fn get_channel_messages(
73 &self,
74 channel_id: ChannelId,
75 count: usize,
76 before_id: Option<MessageId>,
77 ) -> Result<Vec<ChannelMessage>>;
78 #[cfg(test)]
79 async fn teardown(&self, url: &str);
80 #[cfg(test)]
81 fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
82}
83
84pub struct PostgresDb {
85 pool: sqlx::PgPool,
86}
87
88impl PostgresDb {
89 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
90 let pool = DbOptions::new()
91 .max_connections(max_connections)
92 .connect(&url)
93 .await
94 .context("failed to connect to postgres database")?;
95 Ok(Self { pool })
96 }
97}
98
99#[async_trait]
100impl Db for PostgresDb {
101 // users
102
103 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
104 let query = "
105 INSERT INTO users (github_login, admin)
106 VALUES ($1, $2)
107 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
108 RETURNING id
109 ";
110 Ok(sqlx::query_scalar(query)
111 .bind(github_login)
112 .bind(admin)
113 .fetch_one(&self.pool)
114 .await
115 .map(UserId)?)
116 }
117
118 async fn get_all_users(&self) -> Result<Vec<User>> {
119 let query = "SELECT * FROM users ORDER BY github_login ASC";
120 Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
121 }
122
123 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
124 let like_string = fuzzy_like_string(name_query);
125 let query = "
126 SELECT users.*
127 FROM users
128 WHERE github_login like $1
129 ORDER BY github_login <-> $2
130 LIMIT $3
131 ";
132 Ok(sqlx::query_as(query)
133 .bind(like_string)
134 .bind(name_query)
135 .bind(limit)
136 .fetch_all(&self.pool)
137 .await?)
138 }
139
140 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
141 let users = self.get_users_by_ids(vec![id]).await?;
142 Ok(users.into_iter().next())
143 }
144
145 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
146 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
147 let query = "
148 SELECT users.*
149 FROM users
150 WHERE users.id = ANY ($1)
151 ";
152 Ok(sqlx::query_as(query)
153 .bind(&ids)
154 .fetch_all(&self.pool)
155 .await?)
156 }
157
158 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
159 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
160 Ok(sqlx::query_as(query)
161 .bind(github_login)
162 .fetch_optional(&self.pool)
163 .await?)
164 }
165
166 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
167 let query = "UPDATE users SET admin = $1 WHERE id = $2";
168 Ok(sqlx::query(query)
169 .bind(is_admin)
170 .bind(id.0)
171 .execute(&self.pool)
172 .await
173 .map(drop)?)
174 }
175
176 async fn destroy_user(&self, id: UserId) -> Result<()> {
177 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
178 sqlx::query(query)
179 .bind(id.0)
180 .execute(&self.pool)
181 .await
182 .map(drop)?;
183 let query = "DELETE FROM users WHERE id = $1;";
184 Ok(sqlx::query(query)
185 .bind(id.0)
186 .execute(&self.pool)
187 .await
188 .map(drop)?)
189 }
190
191 // contacts
192
193 async fn get_contacts(&self, user_id: UserId) -> Result<Contacts> {
194 let query = "
195 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
196 FROM contacts
197 WHERE user_id_a = $1 OR user_id_b = $1;
198 ";
199
200 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
201 .bind(user_id)
202 .fetch(&self.pool);
203
204 let mut current = Vec::new();
205 let mut outgoing_requests = Vec::new();
206 let mut incoming_requests = Vec::new();
207 while let Some(row) = rows.next().await {
208 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
209
210 if user_id_a == user_id {
211 if accepted {
212 current.push(user_id_b);
213 } else if a_to_b {
214 outgoing_requests.push(user_id_b);
215 } else {
216 incoming_requests.push(IncomingContactRequest {
217 requester_id: user_id_b,
218 should_notify,
219 });
220 }
221 } else {
222 if accepted {
223 current.push(user_id_a);
224 } else if a_to_b {
225 incoming_requests.push(IncomingContactRequest {
226 requester_id: user_id_a,
227 should_notify,
228 });
229 } else {
230 outgoing_requests.push(user_id_a);
231 }
232 }
233 }
234
235 Ok(Contacts {
236 current,
237 outgoing_requests,
238 incoming_requests,
239 })
240 }
241
242 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
243 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
244 (sender_id, receiver_id, true)
245 } else {
246 (receiver_id, sender_id, false)
247 };
248 let query = "
249 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
250 VALUES ($1, $2, $3, 'f', 't')
251 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
252 SET
253 accepted = 't'
254 WHERE
255 NOT contacts.accepted AND
256 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
257 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
258 ";
259 let result = sqlx::query(query)
260 .bind(id_a.0)
261 .bind(id_b.0)
262 .bind(a_to_b)
263 .execute(&self.pool)
264 .await?;
265
266 if result.rows_affected() == 1 {
267 Ok(())
268 } else {
269 Err(anyhow!("contact already requested"))
270 }
271 }
272
273 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
274 let (id_a, id_b) = if responder_id < requester_id {
275 (responder_id, requester_id)
276 } else {
277 (requester_id, responder_id)
278 };
279 let query = "
280 DELETE FROM contacts
281 WHERE user_id_a = $1 AND user_id_b = $2;
282 ";
283 let result = sqlx::query(query)
284 .bind(id_a.0)
285 .bind(id_b.0)
286 .execute(&self.pool)
287 .await?;
288
289 if result.rows_affected() == 1 {
290 Ok(())
291 } else {
292 Err(anyhow!("no such contact"))
293 }
294 }
295
296 async fn dismiss_contact_request(
297 &self,
298 responder_id: UserId,
299 requester_id: UserId,
300 ) -> Result<()> {
301 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
302 (responder_id, requester_id, false)
303 } else {
304 (requester_id, responder_id, true)
305 };
306
307 let query = "
308 UPDATE contacts
309 SET should_notify = 'f'
310 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
311 ";
312
313 let result = sqlx::query(query)
314 .bind(id_a.0)
315 .bind(id_b.0)
316 .bind(a_to_b)
317 .execute(&self.pool)
318 .await?;
319
320 if result.rows_affected() == 0 {
321 Err(anyhow!("no such contact request"))?;
322 }
323
324 Ok(())
325 }
326
327 async fn respond_to_contact_request(
328 &self,
329 responder_id: UserId,
330 requester_id: UserId,
331 accept: bool,
332 ) -> Result<()> {
333 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
334 (responder_id, requester_id, false)
335 } else {
336 (requester_id, responder_id, true)
337 };
338 let result = if accept {
339 let query = "
340 UPDATE contacts
341 SET accepted = 't', should_notify = 'f'
342 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
343 ";
344 sqlx::query(query)
345 .bind(id_a.0)
346 .bind(id_b.0)
347 .bind(a_to_b)
348 .execute(&self.pool)
349 .await?
350 } else {
351 let query = "
352 DELETE FROM contacts
353 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
354 ";
355 sqlx::query(query)
356 .bind(id_a.0)
357 .bind(id_b.0)
358 .bind(a_to_b)
359 .execute(&self.pool)
360 .await?
361 };
362 if result.rows_affected() == 1 {
363 Ok(())
364 } else {
365 Err(anyhow!("no such contact request"))
366 }
367 }
368
369 // access tokens
370
371 async fn create_access_token_hash(
372 &self,
373 user_id: UserId,
374 access_token_hash: &str,
375 max_access_token_count: usize,
376 ) -> Result<()> {
377 let insert_query = "
378 INSERT INTO access_tokens (user_id, hash)
379 VALUES ($1, $2);
380 ";
381 let cleanup_query = "
382 DELETE FROM access_tokens
383 WHERE id IN (
384 SELECT id from access_tokens
385 WHERE user_id = $1
386 ORDER BY id DESC
387 OFFSET $3
388 )
389 ";
390
391 let mut tx = self.pool.begin().await?;
392 sqlx::query(insert_query)
393 .bind(user_id.0)
394 .bind(access_token_hash)
395 .execute(&mut tx)
396 .await?;
397 sqlx::query(cleanup_query)
398 .bind(user_id.0)
399 .bind(access_token_hash)
400 .bind(max_access_token_count as u32)
401 .execute(&mut tx)
402 .await?;
403 Ok(tx.commit().await?)
404 }
405
406 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
407 let query = "
408 SELECT hash
409 FROM access_tokens
410 WHERE user_id = $1
411 ORDER BY id DESC
412 ";
413 Ok(sqlx::query_scalar(query)
414 .bind(user_id.0)
415 .fetch_all(&self.pool)
416 .await?)
417 }
418
419 // orgs
420
421 #[allow(unused)] // Help rust-analyzer
422 #[cfg(any(test, feature = "seed-support"))]
423 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
424 let query = "
425 SELECT *
426 FROM orgs
427 WHERE slug = $1
428 ";
429 Ok(sqlx::query_as(query)
430 .bind(slug)
431 .fetch_optional(&self.pool)
432 .await?)
433 }
434
435 #[cfg(any(test, feature = "seed-support"))]
436 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
437 let query = "
438 INSERT INTO orgs (name, slug)
439 VALUES ($1, $2)
440 RETURNING id
441 ";
442 Ok(sqlx::query_scalar(query)
443 .bind(name)
444 .bind(slug)
445 .fetch_one(&self.pool)
446 .await
447 .map(OrgId)?)
448 }
449
450 #[cfg(any(test, feature = "seed-support"))]
451 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
452 let query = "
453 INSERT INTO org_memberships (org_id, user_id, admin)
454 VALUES ($1, $2, $3)
455 ON CONFLICT DO NOTHING
456 ";
457 Ok(sqlx::query(query)
458 .bind(org_id.0)
459 .bind(user_id.0)
460 .bind(is_admin)
461 .execute(&self.pool)
462 .await
463 .map(drop)?)
464 }
465
466 // channels
467
468 #[cfg(any(test, feature = "seed-support"))]
469 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
470 let query = "
471 INSERT INTO channels (owner_id, owner_is_user, name)
472 VALUES ($1, false, $2)
473 RETURNING id
474 ";
475 Ok(sqlx::query_scalar(query)
476 .bind(org_id.0)
477 .bind(name)
478 .fetch_one(&self.pool)
479 .await
480 .map(ChannelId)?)
481 }
482
483 #[allow(unused)] // Help rust-analyzer
484 #[cfg(any(test, feature = "seed-support"))]
485 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
486 let query = "
487 SELECT *
488 FROM channels
489 WHERE
490 channels.owner_is_user = false AND
491 channels.owner_id = $1
492 ";
493 Ok(sqlx::query_as(query)
494 .bind(org_id.0)
495 .fetch_all(&self.pool)
496 .await?)
497 }
498
499 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
500 let query = "
501 SELECT
502 channels.*
503 FROM
504 channel_memberships, channels
505 WHERE
506 channel_memberships.user_id = $1 AND
507 channel_memberships.channel_id = channels.id
508 ";
509 Ok(sqlx::query_as(query)
510 .bind(user_id.0)
511 .fetch_all(&self.pool)
512 .await?)
513 }
514
515 async fn can_user_access_channel(
516 &self,
517 user_id: UserId,
518 channel_id: ChannelId,
519 ) -> Result<bool> {
520 let query = "
521 SELECT id
522 FROM channel_memberships
523 WHERE user_id = $1 AND channel_id = $2
524 LIMIT 1
525 ";
526 Ok(sqlx::query_scalar::<_, i32>(query)
527 .bind(user_id.0)
528 .bind(channel_id.0)
529 .fetch_optional(&self.pool)
530 .await
531 .map(|e| e.is_some())?)
532 }
533
534 #[cfg(any(test, feature = "seed-support"))]
535 async fn add_channel_member(
536 &self,
537 channel_id: ChannelId,
538 user_id: UserId,
539 is_admin: bool,
540 ) -> Result<()> {
541 let query = "
542 INSERT INTO channel_memberships (channel_id, user_id, admin)
543 VALUES ($1, $2, $3)
544 ON CONFLICT DO NOTHING
545 ";
546 Ok(sqlx::query(query)
547 .bind(channel_id.0)
548 .bind(user_id.0)
549 .bind(is_admin)
550 .execute(&self.pool)
551 .await
552 .map(drop)?)
553 }
554
555 // messages
556
557 async fn create_channel_message(
558 &self,
559 channel_id: ChannelId,
560 sender_id: UserId,
561 body: &str,
562 timestamp: OffsetDateTime,
563 nonce: u128,
564 ) -> Result<MessageId> {
565 let query = "
566 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
567 VALUES ($1, $2, $3, $4, $5)
568 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
569 RETURNING id
570 ";
571 Ok(sqlx::query_scalar(query)
572 .bind(channel_id.0)
573 .bind(sender_id.0)
574 .bind(body)
575 .bind(timestamp)
576 .bind(Uuid::from_u128(nonce))
577 .fetch_one(&self.pool)
578 .await
579 .map(MessageId)?)
580 }
581
582 async fn get_channel_messages(
583 &self,
584 channel_id: ChannelId,
585 count: usize,
586 before_id: Option<MessageId>,
587 ) -> Result<Vec<ChannelMessage>> {
588 let query = r#"
589 SELECT * FROM (
590 SELECT
591 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
592 FROM
593 channel_messages
594 WHERE
595 channel_id = $1 AND
596 id < $2
597 ORDER BY id DESC
598 LIMIT $3
599 ) as recent_messages
600 ORDER BY id ASC
601 "#;
602 Ok(sqlx::query_as(query)
603 .bind(channel_id.0)
604 .bind(before_id.unwrap_or(MessageId::MAX))
605 .bind(count as i64)
606 .fetch_all(&self.pool)
607 .await?)
608 }
609
610 #[cfg(test)]
611 async fn teardown(&self, url: &str) {
612 use util::ResultExt;
613
614 let query = "
615 SELECT pg_terminate_backend(pg_stat_activity.pid)
616 FROM pg_stat_activity
617 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
618 ";
619 sqlx::query(query).execute(&self.pool).await.log_err();
620 self.pool.close().await;
621 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
622 .await
623 .log_err();
624 }
625
626 #[cfg(test)]
627 fn as_fake(&self) -> Option<&tests::FakeDb> {
628 None
629 }
630}
631
/// Declares a newtype wrapper around an `i32` database id.
///
/// The generated type is `Copy`, ordered and hashable. It is stored by sqlx
/// and serialized by serde exactly like its inner `i32`, displays like the
/// inner integer, and converts to and from the `u64` used for protobuf ids.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Builds an id from its protobuf form; values beyond `i32` are
            /// truncated by the `as` cast.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                Self(raw)
            }

            /// Returns the protobuf form of this id.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let Self(raw) = *self;
                raw as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                // Delegate so width/fill/sign flags apply exactly as for the raw integer.
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
663
// Id of a row in the `users` table.
id_type!(UserId);
/// A row of the `users` table.
#[derive(Clone, Debug, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    /// Admin flag, set by `create_user` and `set_user_is_admin`.
    pub admin: bool,
}
671
// Id of a row in the `orgs` table.
id_type!(OrgId);
/// A row of the `orgs` table.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    /// Short identifier that `find_org_by_slug` looks orgs up by.
    pub slug: String,
}
679
// Id of a row in the `channels` table.
id_type!(ChannelId);
/// A row of the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Id of the owner: an org id when `owner_is_user` is false (see
    /// `create_org_channel`), presumably a user id otherwise.
    pub owner_id: i32,
    /// Whether `owner_id` refers to a user rather than an org.
    pub owner_is_user: bool,
}
688
// Id of a row in the `channel_messages` table.
id_type!(MessageId);
/// A row of the `channel_messages` table.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Send time; `get_channel_messages` converts it to UTC in SQL.
    pub sent_at: OffsetDateTime,
    /// Client-supplied nonce used to deduplicate `create_channel_message` calls.
    pub nonce: Uuid,
}
699
/// A user's contacts, as returned by `Db::get_contacts`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Contacts {
    /// Users with an accepted contact relationship.
    pub current: Vec<UserId>,
    /// Pending requests other users have sent to this user.
    pub incoming_requests: Vec<IncomingContactRequest>,
    /// Users this user has sent still-pending requests to.
    pub outgoing_requests: Vec<UserId>,
}

/// A pending contact request addressed to the user whose contacts were fetched.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct IncomingContactRequest {
    /// User who sent the request.
    pub requester_id: UserId,
    /// Becomes `false` once the recipient calls `dismiss_contact_request`.
    pub should_notify: bool,
}
712
/// Builds a SQL `LIKE` pattern matching any string that contains the
/// alphanumeric characters of `string`, in order, with arbitrary text around
/// and between them (e.g. `"clr"` becomes `"%c%l%r%"`).
///
/// Non-alphanumeric characters are dropped, so `LIKE` wildcards (`%`, `_`)
/// and escape characters in the input never reach the pattern.
fn fuzzy_like_string(string: &str) -> String {
    // Each kept char contributes "%c"; a single trailing '%' closes the pattern.
    let mut pattern = String::with_capacity(string.len() * 2 + 1);
    pattern.extend(
        string
            .chars()
            .filter(|ch| ch.is_alphanumeric())
            .flat_map(|ch| ['%', ch]),
    );
    pattern.push('%');
    pattern
}
724
725#[cfg(test)]
726pub mod tests {
727 use super::*;
728 use anyhow::anyhow;
729 use collections::BTreeMap;
730 use gpui::executor::Background;
731 use lazy_static::lazy_static;
732 use parking_lot::Mutex;
733 use rand::prelude::*;
734 use sqlx::{
735 migrate::{MigrateDatabase, Migrator},
736 Postgres,
737 };
738 use std::{path::Path, sync::Arc};
739 use util::post_inc;
740
741 #[tokio::test(flavor = "multi_thread")]
742 async fn test_get_users_by_ids() {
743 for test_db in [
744 TestDb::postgres().await,
745 TestDb::fake(Arc::new(gpui::executor::Background::new())),
746 ] {
747 let db = test_db.db();
748
749 let user = db.create_user("user", false).await.unwrap();
750 let friend1 = db.create_user("friend-1", false).await.unwrap();
751 let friend2 = db.create_user("friend-2", false).await.unwrap();
752 let friend3 = db.create_user("friend-3", false).await.unwrap();
753
754 assert_eq!(
755 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
756 .await
757 .unwrap(),
758 vec![
759 User {
760 id: user,
761 github_login: "user".to_string(),
762 admin: false,
763 },
764 User {
765 id: friend1,
766 github_login: "friend-1".to_string(),
767 admin: false,
768 },
769 User {
770 id: friend2,
771 github_login: "friend-2".to_string(),
772 admin: false,
773 },
774 User {
775 id: friend3,
776 github_login: "friend-3".to_string(),
777 admin: false,
778 }
779 ]
780 );
781 }
782 }
783
784 #[tokio::test(flavor = "multi_thread")]
785 async fn test_recent_channel_messages() {
786 for test_db in [
787 TestDb::postgres().await,
788 TestDb::fake(Arc::new(gpui::executor::Background::new())),
789 ] {
790 let db = test_db.db();
791 let user = db.create_user("user", false).await.unwrap();
792 let org = db.create_org("org", "org").await.unwrap();
793 let channel = db.create_org_channel(org, "channel").await.unwrap();
794 for i in 0..10 {
795 db.create_channel_message(
796 channel,
797 user,
798 &i.to_string(),
799 OffsetDateTime::now_utc(),
800 i,
801 )
802 .await
803 .unwrap();
804 }
805
806 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
807 assert_eq!(
808 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
809 ["5", "6", "7", "8", "9"]
810 );
811
812 let prev_messages = db
813 .get_channel_messages(channel, 4, Some(messages[0].id))
814 .await
815 .unwrap();
816 assert_eq!(
817 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
818 ["1", "2", "3", "4"]
819 );
820 }
821 }
822
823 #[tokio::test(flavor = "multi_thread")]
824 async fn test_channel_message_nonces() {
825 for test_db in [
826 TestDb::postgres().await,
827 TestDb::fake(Arc::new(gpui::executor::Background::new())),
828 ] {
829 let db = test_db.db();
830 let user = db.create_user("user", false).await.unwrap();
831 let org = db.create_org("org", "org").await.unwrap();
832 let channel = db.create_org_channel(org, "channel").await.unwrap();
833
834 let msg1_id = db
835 .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
836 .await
837 .unwrap();
838 let msg2_id = db
839 .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
840 .await
841 .unwrap();
842 let msg3_id = db
843 .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
844 .await
845 .unwrap();
846 let msg4_id = db
847 .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
848 .await
849 .unwrap();
850
851 assert_ne!(msg1_id, msg2_id);
852 assert_eq!(msg1_id, msg3_id);
853 assert_eq!(msg2_id, msg4_id);
854 }
855 }
856
857 #[tokio::test(flavor = "multi_thread")]
858 async fn test_create_access_tokens() {
859 let test_db = TestDb::postgres().await;
860 let db = test_db.db();
861 let user = db.create_user("the-user", false).await.unwrap();
862
863 db.create_access_token_hash(user, "h1", 3).await.unwrap();
864 db.create_access_token_hash(user, "h2", 3).await.unwrap();
865 assert_eq!(
866 db.get_access_token_hashes(user).await.unwrap(),
867 &["h2".to_string(), "h1".to_string()]
868 );
869
870 db.create_access_token_hash(user, "h3", 3).await.unwrap();
871 assert_eq!(
872 db.get_access_token_hashes(user).await.unwrap(),
873 &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
874 );
875
876 db.create_access_token_hash(user, "h4", 3).await.unwrap();
877 assert_eq!(
878 db.get_access_token_hashes(user).await.unwrap(),
879 &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
880 );
881
882 db.create_access_token_hash(user, "h5", 3).await.unwrap();
883 assert_eq!(
884 db.get_access_token_hashes(user).await.unwrap(),
885 &["h5".to_string(), "h4".to_string(), "h3".to_string()]
886 );
887 }
888
#[test]
fn test_fuzzy_like_string() {
    assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
    assert_eq!(fuzzy_like_string("x y"), "%x%y%");
    assert_eq!(fuzzy_like_string(" z "), "%z%");
    // An empty or all-punctuation query becomes a match-everything pattern.
    assert_eq!(fuzzy_like_string(""), "%");
    assert_eq!(fuzzy_like_string("-_-"), "%");
    // LIKE wildcards and escapes in the input are stripped, not passed through.
    assert_eq!(fuzzy_like_string("a%b_c\\"), "%a%b%c%");
}
895
896 #[tokio::test(flavor = "multi_thread")]
897 async fn test_fuzzy_search_users() {
898 let test_db = TestDb::postgres().await;
899 let db = test_db.db();
900 for github_login in [
901 "california",
902 "colorado",
903 "oregon",
904 "washington",
905 "florida",
906 "delaware",
907 "rhode-island",
908 ] {
909 db.create_user(github_login, false).await.unwrap();
910 }
911
912 assert_eq!(
913 fuzzy_search_user_names(db, "clr").await,
914 &["colorado", "california"]
915 );
916 assert_eq!(
917 fuzzy_search_user_names(db, "ro").await,
918 &["rhode-island", "colorado", "oregon"],
919 );
920
921 async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
922 db.fuzzy_search_users(query, 10)
923 .await
924 .unwrap()
925 .into_iter()
926 .map(|user| user.github_login)
927 .collect::<Vec<_>>()
928 }
929 }
930
/// End-to-end contact-request scenario: request, dismiss, accept, mutual
/// requests, and decline. Run against both the Postgres and fake backends.
#[tokio::test(flavor = "multi_thread")]
async fn test_add_contacts() {
    for test_db in [
        TestDb::postgres().await,
        TestDb::fake(Arc::new(gpui::executor::Background::new())),
    ] {
        let db = test_db.db();

        let user_1 = db.create_user("user1", false).await.unwrap();
        let user_2 = db.create_user("user2", false).await.unwrap();
        let user_3 = db.create_user("user3", false).await.unwrap();

        // User starts with no contacts
        assert_eq!(
            db.get_contacts(user_1).await.unwrap(),
            Contacts {
                current: vec![],
                outgoing_requests: vec![],
                incoming_requests: vec![],
            },
        );

        // User requests a contact. Both users see the pending request.
        db.send_contact_request(user_1, user_2).await.unwrap();
        assert_eq!(
            db.get_contacts(user_1).await.unwrap(),
            Contacts {
                current: vec![],
                outgoing_requests: vec![user_2],
                incoming_requests: vec![],
            },
        );
        assert_eq!(
            db.get_contacts(user_2).await.unwrap(),
            Contacts {
                current: vec![],
                outgoing_requests: vec![],
                incoming_requests: vec![IncomingContactRequest {
                    requester_id: user_1,
                    should_notify: true
                }],
            },
        );

        // User 2 dismisses the contact request notification without accepting or rejecting.
        // We shouldn't notify them again.
        // The requester (user 1) cannot dismiss their own outgoing request.
        db.dismiss_contact_request(user_1, user_2)
            .await
            .unwrap_err();
        db.dismiss_contact_request(user_2, user_1).await.unwrap();
        assert_eq!(
            db.get_contacts(user_2).await.unwrap(),
            Contacts {
                current: vec![],
                outgoing_requests: vec![],
                incoming_requests: vec![IncomingContactRequest {
                    requester_id: user_1,
                    should_notify: false
                }],
            },
        );

        // User can't accept their own contact request
        db.respond_to_contact_request(user_1, user_2, true)
            .await
            .unwrap_err();

        // User accepts a contact request. Both users see the contact.
        db.respond_to_contact_request(user_2, user_1, true)
            .await
            .unwrap();
        assert_eq!(
            db.get_contacts(user_1).await.unwrap(),
            Contacts {
                current: vec![user_2],
                outgoing_requests: vec![],
                incoming_requests: vec![],
            },
        );
        assert_eq!(
            db.get_contacts(user_2).await.unwrap(),
            Contacts {
                current: vec![user_1],
                outgoing_requests: vec![],
                incoming_requests: vec![],
            },
        );

        // Users cannot re-request existing contacts.
        db.send_contact_request(user_1, user_2).await.unwrap_err();
        db.send_contact_request(user_2, user_1).await.unwrap_err();

        // Users send each other concurrent contact requests and
        // see that they are immediately accepted.
        db.send_contact_request(user_1, user_3).await.unwrap();
        db.send_contact_request(user_3, user_1).await.unwrap();
        // NOTE(review): this assertion depends on `current` being returned in
        // user-id order (user_2 before user_3).
        assert_eq!(
            db.get_contacts(user_1).await.unwrap(),
            Contacts {
                current: vec![user_2, user_3],
                outgoing_requests: vec![],
                incoming_requests: vec![],
            },
        );
        assert_eq!(
            db.get_contacts(user_3).await.unwrap(),
            Contacts {
                current: vec![user_1],
                outgoing_requests: vec![],
                incoming_requests: vec![],
            },
        );

        // User declines a contact request. Both users see that it is gone.
        db.send_contact_request(user_2, user_3).await.unwrap();
        db.respond_to_contact_request(user_3, user_2, false)
            .await
            .unwrap();
        assert_eq!(
            db.get_contacts(user_2).await.unwrap(),
            Contacts {
                current: vec![user_1],
                outgoing_requests: vec![],
                incoming_requests: vec![],
            },
        );
        assert_eq!(
            db.get_contacts(user_3).await.unwrap(),
            Contacts {
                current: vec![user_1],
                outgoing_requests: vec![],
                incoming_requests: vec![],
            },
        );
    }
}
1067
1068 pub struct TestDb {
1069 pub db: Option<Arc<dyn Db>>,
1070 pub url: String,
1071 }
1072
1073 impl TestDb {
1074 pub async fn postgres() -> Self {
1075 lazy_static! {
1076 static ref LOCK: Mutex<()> = Mutex::new(());
1077 }
1078
1079 let _guard = LOCK.lock();
1080 let mut rng = StdRng::from_entropy();
1081 let name = format!("zed-test-{}", rng.gen::<u128>());
1082 let url = format!("postgres://postgres@localhost/{}", name);
1083 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
1084 Postgres::create_database(&url)
1085 .await
1086 .expect("failed to create test db");
1087 let db = PostgresDb::new(&url, 5).await.unwrap();
1088 let migrator = Migrator::new(migrations_path).await.unwrap();
1089 migrator.run(&db.pool).await.unwrap();
1090 Self {
1091 db: Some(Arc::new(db)),
1092 url,
1093 }
1094 }
1095
1096 pub fn fake(background: Arc<Background>) -> Self {
1097 Self {
1098 db: Some(Arc::new(FakeDb::new(background))),
1099 url: Default::default(),
1100 }
1101 }
1102
1103 pub fn db(&self) -> &Arc<dyn Db> {
1104 self.db.as_ref().unwrap()
1105 }
1106 }
1107
1108 impl Drop for TestDb {
1109 fn drop(&mut self) {
1110 if let Some(db) = self.db.take() {
1111 futures::executor::block_on(db.teardown(&self.url));
1112 }
1113 }
1114 }
1115
/// In-memory [`Db`] implementation for tests. Each table is a mutex-guarded
/// map (a list for contacts). Ids for users, orgs, channels and messages come
/// from per-table counters.
pub struct FakeDb {
    // Executor whose `simulate_random_delay` is awaited by async methods.
    background: Arc<Background>,
    pub users: Mutex<BTreeMap<UserId, User>>,
    pub orgs: Mutex<BTreeMap<OrgId, Org>>,
    // Keyed by (org, user); the value is presumably the admin flag — confirm.
    pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
    pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
    // Keyed by (channel, user); the value is presumably the admin flag — confirm.
    pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
    pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
    pub contacts: Mutex<Vec<FakeContact>>,
    // Next id to hand out for each table; all start at 1 (see `FakeDb::new`).
    next_channel_message_id: Mutex<i32>,
    next_user_id: Mutex<i32>,
    next_org_id: Mutex<i32>,
    next_channel_id: Mutex<i32>,
}

/// One contact, or pending contact request, between two users in [`FakeDb`].
#[derive(Debug)]
pub struct FakeContact {
    /// User who sent the request.
    pub requester_id: UserId,
    /// User the request was sent to.
    pub responder_id: UserId,
    /// Whether the request has been accepted.
    pub accepted: bool,
    /// Reported to the responder as `IncomingContactRequest::should_notify`
    /// while the request is pending.
    pub should_notify: bool,
}
1138
1139 impl FakeDb {
1140 pub fn new(background: Arc<Background>) -> Self {
1141 Self {
1142 background,
1143 users: Default::default(),
1144 next_user_id: Mutex::new(1),
1145 orgs: Default::default(),
1146 next_org_id: Mutex::new(1),
1147 org_memberships: Default::default(),
1148 channels: Default::default(),
1149 next_channel_id: Mutex::new(1),
1150 channel_memberships: Default::default(),
1151 channel_messages: Default::default(),
1152 next_channel_message_id: Mutex::new(1),
1153 contacts: Default::default(),
1154 }
1155 }
1156 }
1157
1158 #[async_trait]
1159 impl Db for FakeDb {
1160 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
1161 self.background.simulate_random_delay().await;
1162
1163 let mut users = self.users.lock();
1164 if let Some(user) = users
1165 .values()
1166 .find(|user| user.github_login == github_login)
1167 {
1168 Ok(user.id)
1169 } else {
1170 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
1171 users.insert(
1172 user_id,
1173 User {
1174 id: user_id,
1175 github_login: github_login.to_string(),
1176 admin,
1177 },
1178 );
1179 Ok(user_id)
1180 }
1181 }
1182
// Not supported by the fake; panics if called.
async fn get_all_users(&self) -> Result<Vec<User>> {
    unimplemented!()
}

// Not supported by the fake (the fuzzy-search test runs against Postgres only);
// panics if called.
async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
    unimplemented!()
}
1190
1191 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1192 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1193 }
1194
1195 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1196 self.background.simulate_random_delay().await;
1197 let users = self.users.lock();
1198 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1199 }
1200
1201 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
1202 Ok(self
1203 .users
1204 .lock()
1205 .values()
1206 .find(|user| user.github_login == github_login)
1207 .cloned())
1208 }
1209
1210 async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
1211 unimplemented!()
1212 }
1213
1214 async fn destroy_user(&self, _id: UserId) -> Result<()> {
1215 unimplemented!()
1216 }
1217
1218 async fn get_contacts(&self, id: UserId) -> Result<Contacts> {
1219 self.background.simulate_random_delay().await;
1220 let mut current = Vec::new();
1221 let mut outgoing_requests = Vec::new();
1222 let mut incoming_requests = Vec::new();
1223
1224 for contact in self.contacts.lock().iter() {
1225 if contact.requester_id == id {
1226 if contact.accepted {
1227 current.push(contact.responder_id);
1228 } else {
1229 outgoing_requests.push(contact.responder_id);
1230 }
1231 } else if contact.responder_id == id {
1232 if contact.accepted {
1233 current.push(contact.requester_id);
1234 } else {
1235 incoming_requests.push(IncomingContactRequest {
1236 requester_id: contact.requester_id,
1237 should_notify: contact.should_notify,
1238 });
1239 }
1240 }
1241 }
1242
1243 Ok(Contacts {
1244 current,
1245 outgoing_requests,
1246 incoming_requests,
1247 })
1248 }
1249
1250 async fn send_contact_request(
1251 &self,
1252 requester_id: UserId,
1253 responder_id: UserId,
1254 ) -> Result<()> {
1255 let mut contacts = self.contacts.lock();
1256 for contact in contacts.iter_mut() {
1257 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1258 if contact.accepted {
1259 Err(anyhow!("contact already exists"))?;
1260 } else {
1261 Err(anyhow!("contact already requested"))?;
1262 }
1263 }
1264 if contact.responder_id == requester_id && contact.requester_id == responder_id {
1265 if contact.accepted {
1266 Err(anyhow!("contact already exists"))?;
1267 } else {
1268 contact.accepted = true;
1269 return Ok(());
1270 }
1271 }
1272 }
1273 contacts.push(FakeContact {
1274 requester_id,
1275 responder_id,
1276 accepted: false,
1277 should_notify: true,
1278 });
1279 Ok(())
1280 }
1281
1282 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1283 self.contacts.lock().retain(|contact| {
1284 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
1285 });
1286 Ok(())
1287 }
1288
1289 async fn dismiss_contact_request(
1290 &self,
1291 responder_id: UserId,
1292 requester_id: UserId,
1293 ) -> Result<()> {
1294 let mut contacts = self.contacts.lock();
1295 for contact in contacts.iter_mut() {
1296 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1297 if contact.accepted {
1298 return Err(anyhow!("contact already confirmed"));
1299 }
1300 contact.should_notify = false;
1301 return Ok(());
1302 }
1303 }
1304 Err(anyhow!("no such contact request"))
1305 }
1306
1307 async fn respond_to_contact_request(
1308 &self,
1309 responder_id: UserId,
1310 requester_id: UserId,
1311 accept: bool,
1312 ) -> Result<()> {
1313 let mut contacts = self.contacts.lock();
1314 for (ix, contact) in contacts.iter_mut().enumerate() {
1315 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1316 if contact.accepted {
1317 return Err(anyhow!("contact already confirmed"));
1318 }
1319 if accept {
1320 contact.accepted = true;
1321 } else {
1322 contacts.remove(ix);
1323 }
1324 return Ok(());
1325 }
1326 }
1327 Err(anyhow!("no such contact request"))
1328 }
1329
1330 async fn create_access_token_hash(
1331 &self,
1332 _user_id: UserId,
1333 _access_token_hash: &str,
1334 _max_access_token_count: usize,
1335 ) -> Result<()> {
1336 unimplemented!()
1337 }
1338
1339 async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
1340 unimplemented!()
1341 }
1342
1343 async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
1344 unimplemented!()
1345 }
1346
1347 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1348 self.background.simulate_random_delay().await;
1349 let mut orgs = self.orgs.lock();
1350 if orgs.values().any(|org| org.slug == slug) {
1351 Err(anyhow!("org already exists"))
1352 } else {
1353 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
1354 orgs.insert(
1355 org_id,
1356 Org {
1357 id: org_id,
1358 name: name.to_string(),
1359 slug: slug.to_string(),
1360 },
1361 );
1362 Ok(org_id)
1363 }
1364 }
1365
1366 async fn add_org_member(
1367 &self,
1368 org_id: OrgId,
1369 user_id: UserId,
1370 is_admin: bool,
1371 ) -> Result<()> {
1372 self.background.simulate_random_delay().await;
1373 if !self.orgs.lock().contains_key(&org_id) {
1374 return Err(anyhow!("org does not exist"));
1375 }
1376 if !self.users.lock().contains_key(&user_id) {
1377 return Err(anyhow!("user does not exist"));
1378 }
1379
1380 self.org_memberships
1381 .lock()
1382 .entry((org_id, user_id))
1383 .or_insert(is_admin);
1384 Ok(())
1385 }
1386
1387 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1388 self.background.simulate_random_delay().await;
1389 if !self.orgs.lock().contains_key(&org_id) {
1390 return Err(anyhow!("org does not exist"));
1391 }
1392
1393 let mut channels = self.channels.lock();
1394 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
1395 channels.insert(
1396 channel_id,
1397 Channel {
1398 id: channel_id,
1399 name: name.to_string(),
1400 owner_id: org_id.0,
1401 owner_is_user: false,
1402 },
1403 );
1404 Ok(channel_id)
1405 }
1406
1407 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1408 self.background.simulate_random_delay().await;
1409 Ok(self
1410 .channels
1411 .lock()
1412 .values()
1413 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
1414 .cloned()
1415 .collect())
1416 }
1417
1418 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1419 self.background.simulate_random_delay().await;
1420 let channels = self.channels.lock();
1421 let memberships = self.channel_memberships.lock();
1422 Ok(channels
1423 .values()
1424 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
1425 .cloned()
1426 .collect())
1427 }
1428
1429 async fn can_user_access_channel(
1430 &self,
1431 user_id: UserId,
1432 channel_id: ChannelId,
1433 ) -> Result<bool> {
1434 self.background.simulate_random_delay().await;
1435 Ok(self
1436 .channel_memberships
1437 .lock()
1438 .contains_key(&(channel_id, user_id)))
1439 }
1440
1441 async fn add_channel_member(
1442 &self,
1443 channel_id: ChannelId,
1444 user_id: UserId,
1445 is_admin: bool,
1446 ) -> Result<()> {
1447 self.background.simulate_random_delay().await;
1448 if !self.channels.lock().contains_key(&channel_id) {
1449 return Err(anyhow!("channel does not exist"));
1450 }
1451 if !self.users.lock().contains_key(&user_id) {
1452 return Err(anyhow!("user does not exist"));
1453 }
1454
1455 self.channel_memberships
1456 .lock()
1457 .entry((channel_id, user_id))
1458 .or_insert(is_admin);
1459 Ok(())
1460 }
1461
1462 async fn create_channel_message(
1463 &self,
1464 channel_id: ChannelId,
1465 sender_id: UserId,
1466 body: &str,
1467 timestamp: OffsetDateTime,
1468 nonce: u128,
1469 ) -> Result<MessageId> {
1470 self.background.simulate_random_delay().await;
1471 if !self.channels.lock().contains_key(&channel_id) {
1472 return Err(anyhow!("channel does not exist"));
1473 }
1474 if !self.users.lock().contains_key(&sender_id) {
1475 return Err(anyhow!("user does not exist"));
1476 }
1477
1478 let mut messages = self.channel_messages.lock();
1479 if let Some(message) = messages
1480 .values()
1481 .find(|message| message.nonce.as_u128() == nonce)
1482 {
1483 Ok(message.id)
1484 } else {
1485 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
1486 messages.insert(
1487 message_id,
1488 ChannelMessage {
1489 id: message_id,
1490 channel_id,
1491 sender_id,
1492 body: body.to_string(),
1493 sent_at: timestamp,
1494 nonce: Uuid::from_u128(nonce),
1495 },
1496 );
1497 Ok(message_id)
1498 }
1499 }
1500
1501 async fn get_channel_messages(
1502 &self,
1503 channel_id: ChannelId,
1504 count: usize,
1505 before_id: Option<MessageId>,
1506 ) -> Result<Vec<ChannelMessage>> {
1507 let mut messages = self
1508 .channel_messages
1509 .lock()
1510 .values()
1511 .rev()
1512 .filter(|message| {
1513 message.channel_id == channel_id
1514 && message.id < before_id.unwrap_or(MessageId::MAX)
1515 })
1516 .take(count)
1517 .cloned()
1518 .collect::<Vec<_>>();
1519 messages.sort_unstable_by_key(|message| message.id);
1520 Ok(messages)
1521 }
1522
1523 async fn teardown(&self, _: &str) {}
1524
1525 #[cfg(test)]
1526 fn as_fake(&self) -> Option<&FakeDb> {
1527 Some(self)
1528 }
1529 }
1530}