1use anyhow::{anyhow, Context, Result};
2use async_trait::async_trait;
3use futures::StreamExt;
4use serde::Serialize;
5pub use sqlx::postgres::PgPoolOptions as DbOptions;
6use sqlx::{types::Uuid, FromRow};
7use time::OffsetDateTime;
8
9#[async_trait]
10pub trait Db: Send + Sync {
11 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
12 async fn get_all_users(&self) -> Result<Vec<User>>;
13 async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
14 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
15 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
16 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
17 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
18 async fn destroy_user(&self, id: UserId) -> Result<()>;
19
20 async fn get_contacts(&self, id: UserId) -> Result<Contacts>;
21 async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
22 async fn dismiss_contact_request(
23 &self,
24 responder_id: UserId,
25 requester_id: UserId,
26 ) -> Result<()>;
27 async fn respond_to_contact_request(
28 &self,
29 responder_id: UserId,
30 requester_id: UserId,
31 accept: bool,
32 ) -> Result<()>;
33
34 async fn create_access_token_hash(
35 &self,
36 user_id: UserId,
37 access_token_hash: &str,
38 max_access_token_count: usize,
39 ) -> Result<()>;
40 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
41 #[cfg(any(test, feature = "seed-support"))]
42
43 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
44 #[cfg(any(test, feature = "seed-support"))]
45 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
46 #[cfg(any(test, feature = "seed-support"))]
47 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
48 #[cfg(any(test, feature = "seed-support"))]
49 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
50 #[cfg(any(test, feature = "seed-support"))]
51
52 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
53 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
54 async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
55 -> Result<bool>;
56 #[cfg(any(test, feature = "seed-support"))]
57 async fn add_channel_member(
58 &self,
59 channel_id: ChannelId,
60 user_id: UserId,
61 is_admin: bool,
62 ) -> Result<()>;
63 async fn create_channel_message(
64 &self,
65 channel_id: ChannelId,
66 sender_id: UserId,
67 body: &str,
68 timestamp: OffsetDateTime,
69 nonce: u128,
70 ) -> Result<MessageId>;
71 async fn get_channel_messages(
72 &self,
73 channel_id: ChannelId,
74 count: usize,
75 before_id: Option<MessageId>,
76 ) -> Result<Vec<ChannelMessage>>;
77 #[cfg(test)]
78 async fn teardown(&self, url: &str);
79}
80
81pub struct PostgresDb {
82 pool: sqlx::PgPool,
83}
84
85impl PostgresDb {
86 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
87 let pool = DbOptions::new()
88 .max_connections(max_connections)
89 .connect(&url)
90 .await
91 .context("failed to connect to postgres database")?;
92 Ok(Self { pool })
93 }
94}
95
96#[async_trait]
97impl Db for PostgresDb {
98 // users
99
100 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
101 let query = "
102 INSERT INTO users (github_login, admin)
103 VALUES ($1, $2)
104 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
105 RETURNING id
106 ";
107 Ok(sqlx::query_scalar(query)
108 .bind(github_login)
109 .bind(admin)
110 .fetch_one(&self.pool)
111 .await
112 .map(UserId)?)
113 }
114
115 async fn get_all_users(&self) -> Result<Vec<User>> {
116 let query = "SELECT * FROM users ORDER BY github_login ASC";
117 Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
118 }
119
120 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
121 let like_string = fuzzy_like_string(name_query);
122 let query = "
123 SELECT users.*
124 FROM users
125 WHERE github_login like $1
126 ORDER BY github_login <-> $2
127 LIMIT $3
128 ";
129 Ok(sqlx::query_as(query)
130 .bind(like_string)
131 .bind(name_query)
132 .bind(limit)
133 .fetch_all(&self.pool)
134 .await?)
135 }
136
137 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
138 let users = self.get_users_by_ids(vec![id]).await?;
139 Ok(users.into_iter().next())
140 }
141
142 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
143 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
144 let query = "
145 SELECT users.*
146 FROM users
147 WHERE users.id = ANY ($1)
148 ";
149 Ok(sqlx::query_as(query)
150 .bind(&ids)
151 .fetch_all(&self.pool)
152 .await?)
153 }
154
155 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
156 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
157 Ok(sqlx::query_as(query)
158 .bind(github_login)
159 .fetch_optional(&self.pool)
160 .await?)
161 }
162
163 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
164 let query = "UPDATE users SET admin = $1 WHERE id = $2";
165 Ok(sqlx::query(query)
166 .bind(is_admin)
167 .bind(id.0)
168 .execute(&self.pool)
169 .await
170 .map(drop)?)
171 }
172
173 async fn destroy_user(&self, id: UserId) -> Result<()> {
174 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
175 sqlx::query(query)
176 .bind(id.0)
177 .execute(&self.pool)
178 .await
179 .map(drop)?;
180 let query = "DELETE FROM users WHERE id = $1;";
181 Ok(sqlx::query(query)
182 .bind(id.0)
183 .execute(&self.pool)
184 .await
185 .map(drop)?)
186 }
187
188 // contacts
189
190 async fn get_contacts(&self, user_id: UserId) -> Result<Contacts> {
191 let query = "
192 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
193 FROM contacts
194 WHERE user_id_a = $1 OR user_id_b = $1;
195 ";
196
197 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
198 .bind(user_id)
199 .fetch(&self.pool);
200
201 let mut current = Vec::new();
202 let mut requests_sent = Vec::new();
203 let mut requests_received = Vec::new();
204 while let Some(row) = rows.next().await {
205 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
206
207 if user_id_a == user_id {
208 if accepted {
209 current.push(user_id_b);
210 } else if a_to_b {
211 requests_sent.push(user_id_b);
212 } else {
213 requests_received.push(IncomingContactRequest {
214 requesting_user_id: user_id_b,
215 should_notify,
216 });
217 }
218 } else {
219 if accepted {
220 current.push(user_id_a);
221 } else if a_to_b {
222 requests_received.push(IncomingContactRequest {
223 requesting_user_id: user_id_a,
224 should_notify,
225 });
226 } else {
227 requests_sent.push(user_id_a);
228 }
229 }
230 }
231
232 Ok(Contacts {
233 current,
234 requests_sent,
235 requests_received,
236 })
237 }
238
239 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
240 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
241 (sender_id, receiver_id, true)
242 } else {
243 (receiver_id, sender_id, false)
244 };
245 let query = "
246 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
247 VALUES ($1, $2, $3, 'f', 't')
248 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
249 SET
250 accepted = 't'
251 WHERE
252 NOT contacts.accepted AND
253 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
254 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
255 ";
256 let result = sqlx::query(query)
257 .bind(id_a.0)
258 .bind(id_b.0)
259 .bind(a_to_b)
260 .execute(&self.pool)
261 .await?;
262
263 if result.rows_affected() == 1 {
264 Ok(())
265 } else {
266 Err(anyhow!("contact already requested"))
267 }
268 }
269
270 async fn respond_to_contact_request(
271 &self,
272 responder_id: UserId,
273 requester_id: UserId,
274 accept: bool,
275 ) -> Result<()> {
276 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
277 (responder_id, requester_id, false)
278 } else {
279 (requester_id, responder_id, true)
280 };
281 let result = if accept {
282 let query = "
283 UPDATE contacts
284 SET accepted = 't', should_notify = 'f'
285 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
286 ";
287 sqlx::query(query)
288 .bind(id_a.0)
289 .bind(id_b.0)
290 .bind(a_to_b)
291 .execute(&self.pool)
292 .await?
293 } else {
294 let query = "
295 DELETE FROM contacts
296 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
297 ";
298 sqlx::query(query)
299 .bind(id_a.0)
300 .bind(id_b.0)
301 .bind(a_to_b)
302 .execute(&self.pool)
303 .await?
304 };
305 if result.rows_affected() == 1 {
306 Ok(())
307 } else {
308 Err(anyhow!("no such contact request"))
309 }
310 }
311
312 async fn dismiss_contact_request(
313 &self,
314 responder_id: UserId,
315 requester_id: UserId,
316 ) -> Result<()> {
317 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
318 (responder_id, requester_id, false)
319 } else {
320 (requester_id, responder_id, true)
321 };
322
323 let query = "
324 UPDATE contacts
325 SET should_notify = 'f'
326 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
327 ";
328
329 let result = sqlx::query(query)
330 .bind(id_a.0)
331 .bind(id_b.0)
332 .bind(a_to_b)
333 .execute(&self.pool)
334 .await?;
335
336 if result.rows_affected() == 0 {
337 Err(anyhow!("no such contact request"))?;
338 }
339
340 Ok(())
341 }
342
343 // access tokens
344
345 async fn create_access_token_hash(
346 &self,
347 user_id: UserId,
348 access_token_hash: &str,
349 max_access_token_count: usize,
350 ) -> Result<()> {
351 let insert_query = "
352 INSERT INTO access_tokens (user_id, hash)
353 VALUES ($1, $2);
354 ";
355 let cleanup_query = "
356 DELETE FROM access_tokens
357 WHERE id IN (
358 SELECT id from access_tokens
359 WHERE user_id = $1
360 ORDER BY id DESC
361 OFFSET $3
362 )
363 ";
364
365 let mut tx = self.pool.begin().await?;
366 sqlx::query(insert_query)
367 .bind(user_id.0)
368 .bind(access_token_hash)
369 .execute(&mut tx)
370 .await?;
371 sqlx::query(cleanup_query)
372 .bind(user_id.0)
373 .bind(access_token_hash)
374 .bind(max_access_token_count as u32)
375 .execute(&mut tx)
376 .await?;
377 Ok(tx.commit().await?)
378 }
379
380 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
381 let query = "
382 SELECT hash
383 FROM access_tokens
384 WHERE user_id = $1
385 ORDER BY id DESC
386 ";
387 Ok(sqlx::query_scalar(query)
388 .bind(user_id.0)
389 .fetch_all(&self.pool)
390 .await?)
391 }
392
393 // orgs
394
395 #[allow(unused)] // Help rust-analyzer
396 #[cfg(any(test, feature = "seed-support"))]
397 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
398 let query = "
399 SELECT *
400 FROM orgs
401 WHERE slug = $1
402 ";
403 Ok(sqlx::query_as(query)
404 .bind(slug)
405 .fetch_optional(&self.pool)
406 .await?)
407 }
408
409 #[cfg(any(test, feature = "seed-support"))]
410 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
411 let query = "
412 INSERT INTO orgs (name, slug)
413 VALUES ($1, $2)
414 RETURNING id
415 ";
416 Ok(sqlx::query_scalar(query)
417 .bind(name)
418 .bind(slug)
419 .fetch_one(&self.pool)
420 .await
421 .map(OrgId)?)
422 }
423
424 #[cfg(any(test, feature = "seed-support"))]
425 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
426 let query = "
427 INSERT INTO org_memberships (org_id, user_id, admin)
428 VALUES ($1, $2, $3)
429 ON CONFLICT DO NOTHING
430 ";
431 Ok(sqlx::query(query)
432 .bind(org_id.0)
433 .bind(user_id.0)
434 .bind(is_admin)
435 .execute(&self.pool)
436 .await
437 .map(drop)?)
438 }
439
440 // channels
441
442 #[cfg(any(test, feature = "seed-support"))]
443 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
444 let query = "
445 INSERT INTO channels (owner_id, owner_is_user, name)
446 VALUES ($1, false, $2)
447 RETURNING id
448 ";
449 Ok(sqlx::query_scalar(query)
450 .bind(org_id.0)
451 .bind(name)
452 .fetch_one(&self.pool)
453 .await
454 .map(ChannelId)?)
455 }
456
457 #[allow(unused)] // Help rust-analyzer
458 #[cfg(any(test, feature = "seed-support"))]
459 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
460 let query = "
461 SELECT *
462 FROM channels
463 WHERE
464 channels.owner_is_user = false AND
465 channels.owner_id = $1
466 ";
467 Ok(sqlx::query_as(query)
468 .bind(org_id.0)
469 .fetch_all(&self.pool)
470 .await?)
471 }
472
473 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
474 let query = "
475 SELECT
476 channels.*
477 FROM
478 channel_memberships, channels
479 WHERE
480 channel_memberships.user_id = $1 AND
481 channel_memberships.channel_id = channels.id
482 ";
483 Ok(sqlx::query_as(query)
484 .bind(user_id.0)
485 .fetch_all(&self.pool)
486 .await?)
487 }
488
489 async fn can_user_access_channel(
490 &self,
491 user_id: UserId,
492 channel_id: ChannelId,
493 ) -> Result<bool> {
494 let query = "
495 SELECT id
496 FROM channel_memberships
497 WHERE user_id = $1 AND channel_id = $2
498 LIMIT 1
499 ";
500 Ok(sqlx::query_scalar::<_, i32>(query)
501 .bind(user_id.0)
502 .bind(channel_id.0)
503 .fetch_optional(&self.pool)
504 .await
505 .map(|e| e.is_some())?)
506 }
507
508 #[cfg(any(test, feature = "seed-support"))]
509 async fn add_channel_member(
510 &self,
511 channel_id: ChannelId,
512 user_id: UserId,
513 is_admin: bool,
514 ) -> Result<()> {
515 let query = "
516 INSERT INTO channel_memberships (channel_id, user_id, admin)
517 VALUES ($1, $2, $3)
518 ON CONFLICT DO NOTHING
519 ";
520 Ok(sqlx::query(query)
521 .bind(channel_id.0)
522 .bind(user_id.0)
523 .bind(is_admin)
524 .execute(&self.pool)
525 .await
526 .map(drop)?)
527 }
528
529 // messages
530
531 async fn create_channel_message(
532 &self,
533 channel_id: ChannelId,
534 sender_id: UserId,
535 body: &str,
536 timestamp: OffsetDateTime,
537 nonce: u128,
538 ) -> Result<MessageId> {
539 let query = "
540 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
541 VALUES ($1, $2, $3, $4, $5)
542 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
543 RETURNING id
544 ";
545 Ok(sqlx::query_scalar(query)
546 .bind(channel_id.0)
547 .bind(sender_id.0)
548 .bind(body)
549 .bind(timestamp)
550 .bind(Uuid::from_u128(nonce))
551 .fetch_one(&self.pool)
552 .await
553 .map(MessageId)?)
554 }
555
556 async fn get_channel_messages(
557 &self,
558 channel_id: ChannelId,
559 count: usize,
560 before_id: Option<MessageId>,
561 ) -> Result<Vec<ChannelMessage>> {
562 let query = r#"
563 SELECT * FROM (
564 SELECT
565 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
566 FROM
567 channel_messages
568 WHERE
569 channel_id = $1 AND
570 id < $2
571 ORDER BY id DESC
572 LIMIT $3
573 ) as recent_messages
574 ORDER BY id ASC
575 "#;
576 Ok(sqlx::query_as(query)
577 .bind(channel_id.0)
578 .bind(before_id.unwrap_or(MessageId::MAX))
579 .bind(count as i64)
580 .fetch_all(&self.pool)
581 .await?)
582 }
583
584 #[cfg(test)]
585 async fn teardown(&self, url: &str) {
586 use util::ResultExt;
587
588 let query = "
589 SELECT pg_terminate_backend(pg_stat_activity.pid)
590 FROM pg_stat_activity
591 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
592 ";
593 sqlx::query(query).execute(&self.pool).await.log_err();
594 self.pool.close().await;
595 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
596 .await
597 .log_err();
598 }
599}
600
/// Declares an `i32`-backed id newtype that maps transparently to its SQL
/// column (via `sqlx`) and to JSON (via `serde`), with proto conversions and
/// `Display`.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(Clone, Copy, Debug, Hash)]
        #[derive(PartialEq, Eq, PartialOrd, Ord)]
        #[derive(sqlx::Type, Serialize)]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Converts a wire id. NOTE: this is an `as` cast, so values above
            /// `i32::MAX` wrap.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                Self(raw)
            }

            /// Converts to a wire id. NOTE: negative values wrap.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let raw = self.0;
                raw as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                // Delegate so formatter flags (width, fill, ...) are honored.
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
632
633id_type!(UserId);
634#[derive(Clone, Debug, FromRow, Serialize, PartialEq)]
635pub struct User {
636 pub id: UserId,
637 pub github_login: String,
638 pub admin: bool,
639}
640
641id_type!(OrgId);
642#[derive(FromRow)]
643pub struct Org {
644 pub id: OrgId,
645 pub name: String,
646 pub slug: String,
647}
648
649id_type!(ChannelId);
650#[derive(Clone, Debug, FromRow, Serialize)]
651pub struct Channel {
652 pub id: ChannelId,
653 pub name: String,
654 pub owner_id: i32,
655 pub owner_is_user: bool,
656}
657
658id_type!(MessageId);
659#[derive(Clone, Debug, FromRow)]
660pub struct ChannelMessage {
661 pub id: MessageId,
662 pub channel_id: ChannelId,
663 pub sender_id: UserId,
664 pub body: String,
665 pub sent_at: OffsetDateTime,
666 pub nonce: Uuid,
667}
668
669#[derive(Clone, Debug, PartialEq, Eq)]
670pub struct Contacts {
671 pub current: Vec<UserId>,
672 pub requests_sent: Vec<UserId>,
673 pub requests_received: Vec<IncomingContactRequest>,
674}
675
676#[derive(Clone, Debug, PartialEq, Eq)]
677pub struct IncomingContactRequest {
678 pub requesting_user_id: UserId,
679 pub should_notify: bool,
680}
681
/// Builds a SQL `LIKE` pattern that matches any string containing the
/// alphanumeric characters of `string` in order (e.g. `"x y"` -> `"%x%y%"`).
///
/// Non-alphanumeric characters, including the `LIKE` wildcards `%` and `_`,
/// are dropped, so user input cannot inject wildcards.
fn fuzzy_like_string(string: &str) -> String {
    let mut pattern = String::with_capacity(string.len() * 2 + 1);
    pattern.push('%');
    for ch in string.chars().filter(|ch| ch.is_alphanumeric()) {
        pattern.push(ch);
        pattern.push('%');
    }
    pattern
}
693
694#[cfg(test)]
695pub mod tests {
696 use super::*;
697 use anyhow::anyhow;
698 use collections::BTreeMap;
699 use gpui::executor::Background;
700 use lazy_static::lazy_static;
701 use parking_lot::Mutex;
702 use rand::prelude::*;
703 use sqlx::{
704 migrate::{MigrateDatabase, Migrator},
705 Postgres,
706 };
707 use std::{path::Path, sync::Arc};
708 use util::post_inc;
709
710 #[tokio::test(flavor = "multi_thread")]
711 async fn test_get_users_by_ids() {
712 for test_db in [
713 TestDb::postgres().await,
714 TestDb::fake(Arc::new(gpui::executor::Background::new())),
715 ] {
716 let db = test_db.db();
717
718 let user = db.create_user("user", false).await.unwrap();
719 let friend1 = db.create_user("friend-1", false).await.unwrap();
720 let friend2 = db.create_user("friend-2", false).await.unwrap();
721 let friend3 = db.create_user("friend-3", false).await.unwrap();
722
723 assert_eq!(
724 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
725 .await
726 .unwrap(),
727 vec![
728 User {
729 id: user,
730 github_login: "user".to_string(),
731 admin: false,
732 },
733 User {
734 id: friend1,
735 github_login: "friend-1".to_string(),
736 admin: false,
737 },
738 User {
739 id: friend2,
740 github_login: "friend-2".to_string(),
741 admin: false,
742 },
743 User {
744 id: friend3,
745 github_login: "friend-3".to_string(),
746 admin: false,
747 }
748 ]
749 );
750 }
751 }
752
753 #[tokio::test(flavor = "multi_thread")]
754 async fn test_recent_channel_messages() {
755 for test_db in [
756 TestDb::postgres().await,
757 TestDb::fake(Arc::new(gpui::executor::Background::new())),
758 ] {
759 let db = test_db.db();
760 let user = db.create_user("user", false).await.unwrap();
761 let org = db.create_org("org", "org").await.unwrap();
762 let channel = db.create_org_channel(org, "channel").await.unwrap();
763 for i in 0..10 {
764 db.create_channel_message(
765 channel,
766 user,
767 &i.to_string(),
768 OffsetDateTime::now_utc(),
769 i,
770 )
771 .await
772 .unwrap();
773 }
774
775 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
776 assert_eq!(
777 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
778 ["5", "6", "7", "8", "9"]
779 );
780
781 let prev_messages = db
782 .get_channel_messages(channel, 4, Some(messages[0].id))
783 .await
784 .unwrap();
785 assert_eq!(
786 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
787 ["1", "2", "3", "4"]
788 );
789 }
790 }
791
792 #[tokio::test(flavor = "multi_thread")]
793 async fn test_channel_message_nonces() {
794 for test_db in [
795 TestDb::postgres().await,
796 TestDb::fake(Arc::new(gpui::executor::Background::new())),
797 ] {
798 let db = test_db.db();
799 let user = db.create_user("user", false).await.unwrap();
800 let org = db.create_org("org", "org").await.unwrap();
801 let channel = db.create_org_channel(org, "channel").await.unwrap();
802
803 let msg1_id = db
804 .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
805 .await
806 .unwrap();
807 let msg2_id = db
808 .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
809 .await
810 .unwrap();
811 let msg3_id = db
812 .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
813 .await
814 .unwrap();
815 let msg4_id = db
816 .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
817 .await
818 .unwrap();
819
820 assert_ne!(msg1_id, msg2_id);
821 assert_eq!(msg1_id, msg3_id);
822 assert_eq!(msg2_id, msg4_id);
823 }
824 }
825
826 #[tokio::test(flavor = "multi_thread")]
827 async fn test_create_access_tokens() {
828 let test_db = TestDb::postgres().await;
829 let db = test_db.db();
830 let user = db.create_user("the-user", false).await.unwrap();
831
832 db.create_access_token_hash(user, "h1", 3).await.unwrap();
833 db.create_access_token_hash(user, "h2", 3).await.unwrap();
834 assert_eq!(
835 db.get_access_token_hashes(user).await.unwrap(),
836 &["h2".to_string(), "h1".to_string()]
837 );
838
839 db.create_access_token_hash(user, "h3", 3).await.unwrap();
840 assert_eq!(
841 db.get_access_token_hashes(user).await.unwrap(),
842 &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
843 );
844
845 db.create_access_token_hash(user, "h4", 3).await.unwrap();
846 assert_eq!(
847 db.get_access_token_hashes(user).await.unwrap(),
848 &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
849 );
850
851 db.create_access_token_hash(user, "h5", 3).await.unwrap();
852 assert_eq!(
853 db.get_access_token_hashes(user).await.unwrap(),
854 &["h5".to_string(), "h4".to_string(), "h3".to_string()]
855 );
856 }
857
858 #[test]
859 fn test_fuzzy_like_string() {
860 assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
861 assert_eq!(fuzzy_like_string("x y"), "%x%y%");
862 assert_eq!(fuzzy_like_string(" z "), "%z%");
863 }
864
865 #[tokio::test(flavor = "multi_thread")]
866 async fn test_fuzzy_search_users() {
867 let test_db = TestDb::postgres().await;
868 let db = test_db.db();
869 for github_login in [
870 "california",
871 "colorado",
872 "oregon",
873 "washington",
874 "florida",
875 "delaware",
876 "rhode-island",
877 ] {
878 db.create_user(github_login, false).await.unwrap();
879 }
880
881 assert_eq!(
882 fuzzy_search_user_names(db, "clr").await,
883 &["colorado", "california"]
884 );
885 assert_eq!(
886 fuzzy_search_user_names(db, "ro").await,
887 &["rhode-island", "colorado", "oregon"],
888 );
889
890 async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
891 db.fuzzy_search_users(query, 10)
892 .await
893 .unwrap()
894 .into_iter()
895 .map(|user| user.github_login)
896 .collect::<Vec<_>>()
897 }
898 }
899
900 #[tokio::test(flavor = "multi_thread")]
901 async fn test_add_contacts() {
902 for test_db in [
903 TestDb::postgres().await,
904 TestDb::fake(Arc::new(gpui::executor::Background::new())),
905 ] {
906 let db = test_db.db();
907
908 let user_1 = db.create_user("user1", false).await.unwrap();
909 let user_2 = db.create_user("user2", false).await.unwrap();
910 let user_3 = db.create_user("user3", false).await.unwrap();
911
912 // User starts with no contacts
913 assert_eq!(
914 db.get_contacts(user_1).await.unwrap(),
915 Contacts {
916 current: vec![],
917 requests_sent: vec![],
918 requests_received: vec![],
919 },
920 );
921
922 // User requests a contact. Both users see the pending request.
923 db.send_contact_request(user_1, user_2).await.unwrap();
924 assert_eq!(
925 db.get_contacts(user_1).await.unwrap(),
926 Contacts {
927 current: vec![],
928 requests_sent: vec![user_2],
929 requests_received: vec![],
930 },
931 );
932 assert_eq!(
933 db.get_contacts(user_2).await.unwrap(),
934 Contacts {
935 current: vec![],
936 requests_sent: vec![],
937 requests_received: vec![IncomingContactRequest {
938 requesting_user_id: user_1,
939 should_notify: true
940 }],
941 },
942 );
943
944 // User 2 dismisses the contact request notification without accepting or rejecting.
945 // We shouldn't notify them again.
946 db.dismiss_contact_request(user_1, user_2)
947 .await
948 .unwrap_err();
949 db.dismiss_contact_request(user_2, user_1).await.unwrap();
950 assert_eq!(
951 db.get_contacts(user_2).await.unwrap(),
952 Contacts {
953 current: vec![],
954 requests_sent: vec![],
955 requests_received: vec![IncomingContactRequest {
956 requesting_user_id: user_1,
957 should_notify: false
958 }],
959 },
960 );
961
962 // User can't accept their own contact request
963 db.respond_to_contact_request(user_1, user_2, true)
964 .await
965 .unwrap_err();
966
967 // User accepts a contact request. Both users see the contact.
968 db.respond_to_contact_request(user_2, user_1, true)
969 .await
970 .unwrap();
971 assert_eq!(
972 db.get_contacts(user_1).await.unwrap(),
973 Contacts {
974 current: vec![user_2],
975 requests_sent: vec![],
976 requests_received: vec![],
977 },
978 );
979 assert_eq!(
980 db.get_contacts(user_2).await.unwrap(),
981 Contacts {
982 current: vec![user_1],
983 requests_sent: vec![],
984 requests_received: vec![],
985 },
986 );
987
988 // Users cannot re-request existing contacts.
989 db.send_contact_request(user_1, user_2).await.unwrap_err();
990 db.send_contact_request(user_2, user_1).await.unwrap_err();
991
992 // Users send each other concurrent contact requests and
993 // see that they are immediately accepted.
994 db.send_contact_request(user_1, user_3).await.unwrap();
995 db.send_contact_request(user_3, user_1).await.unwrap();
996 assert_eq!(
997 db.get_contacts(user_1).await.unwrap(),
998 Contacts {
999 current: vec![user_2, user_3],
1000 requests_sent: vec![],
1001 requests_received: vec![],
1002 },
1003 );
1004 assert_eq!(
1005 db.get_contacts(user_3).await.unwrap(),
1006 Contacts {
1007 current: vec![user_1],
1008 requests_sent: vec![],
1009 requests_received: vec![],
1010 },
1011 );
1012
1013 // User declines a contact request. Both users see that it is gone.
1014 db.send_contact_request(user_2, user_3).await.unwrap();
1015 db.respond_to_contact_request(user_3, user_2, false)
1016 .await
1017 .unwrap();
1018 assert_eq!(
1019 db.get_contacts(user_2).await.unwrap(),
1020 Contacts {
1021 current: vec![user_1],
1022 requests_sent: vec![],
1023 requests_received: vec![],
1024 },
1025 );
1026 assert_eq!(
1027 db.get_contacts(user_3).await.unwrap(),
1028 Contacts {
1029 current: vec![user_1],
1030 requests_sent: vec![],
1031 requests_received: vec![],
1032 },
1033 );
1034 }
1035 }
1036
1037 pub struct TestDb {
1038 pub db: Option<Arc<dyn Db>>,
1039 pub url: String,
1040 }
1041
1042 impl TestDb {
1043 pub async fn postgres() -> Self {
1044 lazy_static! {
1045 static ref LOCK: Mutex<()> = Mutex::new(());
1046 }
1047
1048 let _guard = LOCK.lock();
1049 let mut rng = StdRng::from_entropy();
1050 let name = format!("zed-test-{}", rng.gen::<u128>());
1051 let url = format!("postgres://postgres@localhost/{}", name);
1052 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
1053 Postgres::create_database(&url)
1054 .await
1055 .expect("failed to create test db");
1056 let db = PostgresDb::new(&url, 5).await.unwrap();
1057 let migrator = Migrator::new(migrations_path).await.unwrap();
1058 migrator.run(&db.pool).await.unwrap();
1059 Self {
1060 db: Some(Arc::new(db)),
1061 url,
1062 }
1063 }
1064
1065 pub fn fake(background: Arc<Background>) -> Self {
1066 Self {
1067 db: Some(Arc::new(FakeDb::new(background))),
1068 url: Default::default(),
1069 }
1070 }
1071
1072 pub fn db(&self) -> &Arc<dyn Db> {
1073 self.db.as_ref().unwrap()
1074 }
1075 }
1076
1077 impl Drop for TestDb {
1078 fn drop(&mut self) {
1079 if let Some(db) = self.db.take() {
1080 futures::executor::block_on(db.teardown(&self.url));
1081 }
1082 }
1083 }
1084
1085 pub struct FakeDb {
1086 background: Arc<Background>,
1087 users: Mutex<BTreeMap<UserId, User>>,
1088 next_user_id: Mutex<i32>,
1089 orgs: Mutex<BTreeMap<OrgId, Org>>,
1090 next_org_id: Mutex<i32>,
1091 org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
1092 channels: Mutex<BTreeMap<ChannelId, Channel>>,
1093 next_channel_id: Mutex<i32>,
1094 channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
1095 channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
1096 next_channel_message_id: Mutex<i32>,
1097 contacts: Mutex<Vec<FakeContact>>,
1098 }
1099
1100 struct FakeContact {
1101 requester_id: UserId,
1102 responder_id: UserId,
1103 accepted: bool,
1104 should_notify: bool,
1105 }
1106
1107 impl FakeDb {
1108 pub fn new(background: Arc<Background>) -> Self {
1109 Self {
1110 background,
1111 users: Default::default(),
1112 next_user_id: Mutex::new(1),
1113 orgs: Default::default(),
1114 next_org_id: Mutex::new(1),
1115 org_memberships: Default::default(),
1116 channels: Default::default(),
1117 next_channel_id: Mutex::new(1),
1118 channel_memberships: Default::default(),
1119 channel_messages: Default::default(),
1120 next_channel_message_id: Mutex::new(1),
1121 contacts: Default::default(),
1122 }
1123 }
1124 }
1125
1126 #[async_trait]
1127 impl Db for FakeDb {
1128 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
1129 self.background.simulate_random_delay().await;
1130
1131 let mut users = self.users.lock();
1132 if let Some(user) = users
1133 .values()
1134 .find(|user| user.github_login == github_login)
1135 {
1136 Ok(user.id)
1137 } else {
1138 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
1139 users.insert(
1140 user_id,
1141 User {
1142 id: user_id,
1143 github_login: github_login.to_string(),
1144 admin,
1145 },
1146 );
1147 Ok(user_id)
1148 }
1149 }
1150
1151 async fn get_all_users(&self) -> Result<Vec<User>> {
1152 unimplemented!()
1153 }
1154
1155 async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
1156 unimplemented!()
1157 }
1158
1159 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1160 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1161 }
1162
1163 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1164 self.background.simulate_random_delay().await;
1165 let users = self.users.lock();
1166 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1167 }
1168
1169 async fn get_user_by_github_login(&self, _github_login: &str) -> Result<Option<User>> {
1170 unimplemented!()
1171 }
1172
1173 async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
1174 unimplemented!()
1175 }
1176
1177 async fn destroy_user(&self, _id: UserId) -> Result<()> {
1178 unimplemented!()
1179 }
1180
1181 async fn get_contacts(&self, id: UserId) -> Result<Contacts> {
1182 self.background.simulate_random_delay().await;
1183 let mut current = Vec::new();
1184 let mut requests_sent = Vec::new();
1185 let mut requests_received = Vec::new();
1186 for contact in self.contacts.lock().iter() {
1187 if contact.requester_id == id {
1188 if contact.accepted {
1189 current.push(contact.responder_id);
1190 } else {
1191 requests_sent.push(contact.responder_id);
1192 }
1193 } else if contact.responder_id == id {
1194 if contact.accepted {
1195 current.push(contact.requester_id);
1196 } else {
1197 requests_received.push(IncomingContactRequest {
1198 requesting_user_id: contact.requester_id,
1199 should_notify: contact.should_notify,
1200 });
1201 }
1202 }
1203 }
1204 Ok(Contacts {
1205 current,
1206 requests_sent,
1207 requests_received,
1208 })
1209 }
1210
1211 async fn send_contact_request(
1212 &self,
1213 requester_id: UserId,
1214 responder_id: UserId,
1215 ) -> Result<()> {
1216 let mut contacts = self.contacts.lock();
1217 for contact in contacts.iter_mut() {
1218 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1219 if contact.accepted {
1220 Err(anyhow!("contact already exists"))?;
1221 } else {
1222 Err(anyhow!("contact already requested"))?;
1223 }
1224 }
1225 if contact.responder_id == requester_id && contact.requester_id == responder_id {
1226 if contact.accepted {
1227 Err(anyhow!("contact already exists"))?;
1228 } else {
1229 contact.accepted = true;
1230 return Ok(());
1231 }
1232 }
1233 }
1234 contacts.push(FakeContact {
1235 requester_id,
1236 responder_id,
1237 accepted: false,
1238 should_notify: true,
1239 });
1240 Ok(())
1241 }
1242
1243 async fn dismiss_contact_request(
1244 &self,
1245 responder_id: UserId,
1246 requester_id: UserId,
1247 ) -> Result<()> {
1248 let mut contacts = self.contacts.lock();
1249 for contact in contacts.iter_mut() {
1250 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1251 if contact.accepted {
1252 return Err(anyhow!("contact already confirmed"));
1253 }
1254 contact.should_notify = false;
1255 return Ok(());
1256 }
1257 }
1258 Err(anyhow!("no such contact request"))
1259 }
1260
1261 async fn respond_to_contact_request(
1262 &self,
1263 responder_id: UserId,
1264 requester_id: UserId,
1265 accept: bool,
1266 ) -> Result<()> {
1267 let mut contacts = self.contacts.lock();
1268 for (ix, contact) in contacts.iter_mut().enumerate() {
1269 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1270 if contact.accepted {
1271 return Err(anyhow!("contact already confirmed"));
1272 }
1273 if accept {
1274 contact.accepted = true;
1275 } else {
1276 contacts.remove(ix);
1277 }
1278 return Ok(());
1279 }
1280 }
1281 Err(anyhow!("no such contact request"))
1282 }
1283
/// Not supported by the in-memory fake: panics via `unimplemented!()` if called.
async fn create_access_token_hash(
    &self,
    _user_id: UserId,
    _access_token_hash: &str,
    _max_access_token_count: usize,
) -> Result<()> {
    unimplemented!()
}

/// Not supported by the in-memory fake: panics via `unimplemented!()` if called.
async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
    unimplemented!()
}

/// Not supported by the in-memory fake: panics via `unimplemented!()` if called.
async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
    unimplemented!()
}
1300
1301 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1302 self.background.simulate_random_delay().await;
1303 let mut orgs = self.orgs.lock();
1304 if orgs.values().any(|org| org.slug == slug) {
1305 Err(anyhow!("org already exists"))
1306 } else {
1307 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
1308 orgs.insert(
1309 org_id,
1310 Org {
1311 id: org_id,
1312 name: name.to_string(),
1313 slug: slug.to_string(),
1314 },
1315 );
1316 Ok(org_id)
1317 }
1318 }
1319
1320 async fn add_org_member(
1321 &self,
1322 org_id: OrgId,
1323 user_id: UserId,
1324 is_admin: bool,
1325 ) -> Result<()> {
1326 self.background.simulate_random_delay().await;
1327 if !self.orgs.lock().contains_key(&org_id) {
1328 return Err(anyhow!("org does not exist"));
1329 }
1330 if !self.users.lock().contains_key(&user_id) {
1331 return Err(anyhow!("user does not exist"));
1332 }
1333
1334 self.org_memberships
1335 .lock()
1336 .entry((org_id, user_id))
1337 .or_insert(is_admin);
1338 Ok(())
1339 }
1340
1341 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1342 self.background.simulate_random_delay().await;
1343 if !self.orgs.lock().contains_key(&org_id) {
1344 return Err(anyhow!("org does not exist"));
1345 }
1346
1347 let mut channels = self.channels.lock();
1348 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
1349 channels.insert(
1350 channel_id,
1351 Channel {
1352 id: channel_id,
1353 name: name.to_string(),
1354 owner_id: org_id.0,
1355 owner_is_user: false,
1356 },
1357 );
1358 Ok(channel_id)
1359 }
1360
1361 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1362 self.background.simulate_random_delay().await;
1363 Ok(self
1364 .channels
1365 .lock()
1366 .values()
1367 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
1368 .cloned()
1369 .collect())
1370 }
1371
1372 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1373 self.background.simulate_random_delay().await;
1374 let channels = self.channels.lock();
1375 let memberships = self.channel_memberships.lock();
1376 Ok(channels
1377 .values()
1378 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
1379 .cloned()
1380 .collect())
1381 }
1382
1383 async fn can_user_access_channel(
1384 &self,
1385 user_id: UserId,
1386 channel_id: ChannelId,
1387 ) -> Result<bool> {
1388 self.background.simulate_random_delay().await;
1389 Ok(self
1390 .channel_memberships
1391 .lock()
1392 .contains_key(&(channel_id, user_id)))
1393 }
1394
1395 async fn add_channel_member(
1396 &self,
1397 channel_id: ChannelId,
1398 user_id: UserId,
1399 is_admin: bool,
1400 ) -> Result<()> {
1401 self.background.simulate_random_delay().await;
1402 if !self.channels.lock().contains_key(&channel_id) {
1403 return Err(anyhow!("channel does not exist"));
1404 }
1405 if !self.users.lock().contains_key(&user_id) {
1406 return Err(anyhow!("user does not exist"));
1407 }
1408
1409 self.channel_memberships
1410 .lock()
1411 .entry((channel_id, user_id))
1412 .or_insert(is_admin);
1413 Ok(())
1414 }
1415
1416 async fn create_channel_message(
1417 &self,
1418 channel_id: ChannelId,
1419 sender_id: UserId,
1420 body: &str,
1421 timestamp: OffsetDateTime,
1422 nonce: u128,
1423 ) -> Result<MessageId> {
1424 self.background.simulate_random_delay().await;
1425 if !self.channels.lock().contains_key(&channel_id) {
1426 return Err(anyhow!("channel does not exist"));
1427 }
1428 if !self.users.lock().contains_key(&sender_id) {
1429 return Err(anyhow!("user does not exist"));
1430 }
1431
1432 let mut messages = self.channel_messages.lock();
1433 if let Some(message) = messages
1434 .values()
1435 .find(|message| message.nonce.as_u128() == nonce)
1436 {
1437 Ok(message.id)
1438 } else {
1439 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
1440 messages.insert(
1441 message_id,
1442 ChannelMessage {
1443 id: message_id,
1444 channel_id,
1445 sender_id,
1446 body: body.to_string(),
1447 sent_at: timestamp,
1448 nonce: Uuid::from_u128(nonce),
1449 },
1450 );
1451 Ok(message_id)
1452 }
1453 }
1454
1455 async fn get_channel_messages(
1456 &self,
1457 channel_id: ChannelId,
1458 count: usize,
1459 before_id: Option<MessageId>,
1460 ) -> Result<Vec<ChannelMessage>> {
1461 let mut messages = self
1462 .channel_messages
1463 .lock()
1464 .values()
1465 .rev()
1466 .filter(|message| {
1467 message.channel_id == channel_id
1468 && message.id < before_id.unwrap_or(MessageId::MAX)
1469 })
1470 .take(count)
1471 .cloned()
1472 .collect::<Vec<_>>();
1473 messages.sort_unstable_by_key(|message| message.id);
1474 Ok(messages)
1475 }
1476
/// No-op for the in-memory fake; the string argument is ignored.
async fn teardown(&self, _: &str) {}
1478 }
1479}