1use anyhow::{anyhow, Context, Result};
2use async_trait::async_trait;
3use futures::StreamExt;
4use serde::Serialize;
5pub use sqlx::postgres::PgPoolOptions as DbOptions;
6use sqlx::{types::Uuid, FromRow};
7use time::OffsetDateTime;
8
/// Persistence interface for the collaboration server.
///
/// [`PostgresDb`] is the production implementation; the test module provides
/// an in-memory `FakeDb`. Methods gated on
/// `cfg(any(test, feature = "seed-support"))` exist only in test and seeding
/// builds, and `teardown` exists only in tests.
#[async_trait]
pub trait Db: Send + Sync {
    /// Creates a user with the given GitHub login and returns its id. If a user
    /// with that login already exists, the existing user's id is returned.
    async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
    /// Returns all users.
    async fn get_all_users(&self) -> Result<Vec<User>>;
    /// Returns up to `limit` users whose GitHub login fuzzily matches `query`.
    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
    /// Returns the user with the given id, if it exists.
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    /// Returns the users whose ids appear in `ids`; ids with no user are skipped.
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Returns the user with the given GitHub login, if it exists.
    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
    /// Sets the user's `admin` flag.
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    /// Deletes the user; the Postgres implementation also deletes its access tokens.
    async fn destroy_user(&self, id: UserId) -> Result<()>;

    /// Returns the user's accepted contacts plus pending incoming and outgoing
    /// contact requests.
    async fn get_contacts(&self, id: UserId) -> Result<Contacts>;
    /// Sends a contact request from `requester_id` to `responder_id`. If
    /// `responder_id` already has a pending request to `requester_id`, that
    /// request is accepted instead. Fails if the request already exists or the
    /// users are already contacts.
    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Clears `should_notify` on the request `requester_id` sent to
    /// `responder_id`, without accepting or declining it. Fails if no matching
    /// request exists.
    async fn dismiss_contact_request(
        &self,
        responder_id: UserId,
        requester_id: UserId,
    ) -> Result<()>;
    /// Accepts (`accept == true`) or declines the request `requester_id` sent to
    /// `responder_id`. Declining removes the request. Fails if no matching
    /// request exists.
    async fn respond_to_contact_request(
        &self,
        responder_id: UserId,
        requester_id: UserId,
        accept: bool,
    ) -> Result<()>;

    /// Stores `access_token_hash` for the user, then prunes the user's
    /// lowest-id (oldest) hashes so that at most `max_access_token_count` remain.
    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    /// Returns the user's access-token hashes, highest id (most recent) first.
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
    /// Looks up an org by its slug.
    #[cfg(any(test, feature = "seed-support"))]
    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    /// Creates an org and returns its id.
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    /// Adds the user to the org, as an admin if `is_admin`.
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
    /// Creates a channel owned by the org and returns its id.
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
    /// Returns the channels owned by the org.
    #[cfg(any(test, feature = "seed-support"))]
    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    /// Returns the channels the user is a member of.
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    /// Returns whether the user is a member of the channel.
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;
    /// Adds the user to the channel, as an admin if `is_admin`.
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;
    /// Stores a message and returns its id. Re-sending a `nonce` returns the id
    /// of the message already stored with it instead of inserting a duplicate.
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    /// Returns up to `count` of the channel's newest messages with ids below
    /// `before_id` (no bound when `None`), in ascending id order.
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;
    /// Releases the test database identified by `url`.
    #[cfg(test)]
    async fn teardown(&self, url: &str);
}
80
/// [`Db`] implementation backed by a Postgres connection pool.
pub struct PostgresDb {
    /// Pool that every query in this implementation runs against.
    pool: sqlx::PgPool,
}
84
85impl PostgresDb {
86 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
87 let pool = DbOptions::new()
88 .max_connections(max_connections)
89 .connect(&url)
90 .await
91 .context("failed to connect to postgres database")?;
92 Ok(Self { pool })
93 }
94}
95
96#[async_trait]
97impl Db for PostgresDb {
98 // users
99
100 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
101 let query = "
102 INSERT INTO users (github_login, admin)
103 VALUES ($1, $2)
104 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
105 RETURNING id
106 ";
107 Ok(sqlx::query_scalar(query)
108 .bind(github_login)
109 .bind(admin)
110 .fetch_one(&self.pool)
111 .await
112 .map(UserId)?)
113 }
114
115 async fn get_all_users(&self) -> Result<Vec<User>> {
116 let query = "SELECT * FROM users ORDER BY github_login ASC";
117 Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
118 }
119
120 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
121 let like_string = fuzzy_like_string(name_query);
122 let query = "
123 SELECT users.*
124 FROM users
125 WHERE github_login like $1
126 ORDER BY github_login <-> $2
127 LIMIT $3
128 ";
129 Ok(sqlx::query_as(query)
130 .bind(like_string)
131 .bind(name_query)
132 .bind(limit)
133 .fetch_all(&self.pool)
134 .await?)
135 }
136
137 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
138 let users = self.get_users_by_ids(vec![id]).await?;
139 Ok(users.into_iter().next())
140 }
141
142 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
143 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
144 let query = "
145 SELECT users.*
146 FROM users
147 WHERE users.id = ANY ($1)
148 ";
149 Ok(sqlx::query_as(query)
150 .bind(&ids)
151 .fetch_all(&self.pool)
152 .await?)
153 }
154
155 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
156 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
157 Ok(sqlx::query_as(query)
158 .bind(github_login)
159 .fetch_optional(&self.pool)
160 .await?)
161 }
162
163 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
164 let query = "UPDATE users SET admin = $1 WHERE id = $2";
165 Ok(sqlx::query(query)
166 .bind(is_admin)
167 .bind(id.0)
168 .execute(&self.pool)
169 .await
170 .map(drop)?)
171 }
172
173 async fn destroy_user(&self, id: UserId) -> Result<()> {
174 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
175 sqlx::query(query)
176 .bind(id.0)
177 .execute(&self.pool)
178 .await
179 .map(drop)?;
180 let query = "DELETE FROM users WHERE id = $1;";
181 Ok(sqlx::query(query)
182 .bind(id.0)
183 .execute(&self.pool)
184 .await
185 .map(drop)?)
186 }
187
188 // contacts
189
190 async fn get_contacts(&self, user_id: UserId) -> Result<Contacts> {
191 let query = "
192 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
193 FROM contacts
194 WHERE user_id_a = $1 OR user_id_b = $1;
195 ";
196
197 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
198 .bind(user_id)
199 .fetch(&self.pool);
200
201 let mut current = Vec::new();
202 let mut outgoing_requests = Vec::new();
203 let mut incoming_requests = Vec::new();
204 while let Some(row) = rows.next().await {
205 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
206
207 if user_id_a == user_id {
208 if accepted {
209 current.push(user_id_b);
210 } else if a_to_b {
211 outgoing_requests.push(user_id_b);
212 } else {
213 incoming_requests.push(IncomingContactRequest {
214 requester_id: user_id_b,
215 should_notify,
216 });
217 }
218 } else {
219 if accepted {
220 current.push(user_id_a);
221 } else if a_to_b {
222 incoming_requests.push(IncomingContactRequest {
223 requester_id: user_id_a,
224 should_notify,
225 });
226 } else {
227 outgoing_requests.push(user_id_a);
228 }
229 }
230 }
231
232 Ok(Contacts {
233 current,
234 outgoing_requests,
235 incoming_requests,
236 })
237 }
238
239 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
240 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
241 (sender_id, receiver_id, true)
242 } else {
243 (receiver_id, sender_id, false)
244 };
245 let query = "
246 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
247 VALUES ($1, $2, $3, 'f', 't')
248 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
249 SET
250 accepted = 't'
251 WHERE
252 NOT contacts.accepted AND
253 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
254 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
255 ";
256 let result = sqlx::query(query)
257 .bind(id_a.0)
258 .bind(id_b.0)
259 .bind(a_to_b)
260 .execute(&self.pool)
261 .await?;
262
263 if result.rows_affected() == 1 {
264 Ok(())
265 } else {
266 Err(anyhow!("contact already requested"))
267 }
268 }
269
270 async fn respond_to_contact_request(
271 &self,
272 responder_id: UserId,
273 requester_id: UserId,
274 accept: bool,
275 ) -> Result<()> {
276 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
277 (responder_id, requester_id, false)
278 } else {
279 (requester_id, responder_id, true)
280 };
281 let result = if accept {
282 let query = "
283 UPDATE contacts
284 SET accepted = 't', should_notify = 'f'
285 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
286 ";
287 sqlx::query(query)
288 .bind(id_a.0)
289 .bind(id_b.0)
290 .bind(a_to_b)
291 .execute(&self.pool)
292 .await?
293 } else {
294 let query = "
295 DELETE FROM contacts
296 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
297 ";
298 sqlx::query(query)
299 .bind(id_a.0)
300 .bind(id_b.0)
301 .bind(a_to_b)
302 .execute(&self.pool)
303 .await?
304 };
305 if result.rows_affected() == 1 {
306 Ok(())
307 } else {
308 Err(anyhow!("no such contact request"))
309 }
310 }
311
312 async fn dismiss_contact_request(
313 &self,
314 responder_id: UserId,
315 requester_id: UserId,
316 ) -> Result<()> {
317 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
318 (responder_id, requester_id, false)
319 } else {
320 (requester_id, responder_id, true)
321 };
322
323 let query = "
324 UPDATE contacts
325 SET should_notify = 'f'
326 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
327 ";
328
329 let result = sqlx::query(query)
330 .bind(id_a.0)
331 .bind(id_b.0)
332 .bind(a_to_b)
333 .execute(&self.pool)
334 .await?;
335
336 if result.rows_affected() == 0 {
337 Err(anyhow!("no such contact request"))?;
338 }
339
340 Ok(())
341 }
342
343 // access tokens
344
345 async fn create_access_token_hash(
346 &self,
347 user_id: UserId,
348 access_token_hash: &str,
349 max_access_token_count: usize,
350 ) -> Result<()> {
351 let insert_query = "
352 INSERT INTO access_tokens (user_id, hash)
353 VALUES ($1, $2);
354 ";
355 let cleanup_query = "
356 DELETE FROM access_tokens
357 WHERE id IN (
358 SELECT id from access_tokens
359 WHERE user_id = $1
360 ORDER BY id DESC
361 OFFSET $3
362 )
363 ";
364
365 let mut tx = self.pool.begin().await?;
366 sqlx::query(insert_query)
367 .bind(user_id.0)
368 .bind(access_token_hash)
369 .execute(&mut tx)
370 .await?;
371 sqlx::query(cleanup_query)
372 .bind(user_id.0)
373 .bind(access_token_hash)
374 .bind(max_access_token_count as u32)
375 .execute(&mut tx)
376 .await?;
377 Ok(tx.commit().await?)
378 }
379
380 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
381 let query = "
382 SELECT hash
383 FROM access_tokens
384 WHERE user_id = $1
385 ORDER BY id DESC
386 ";
387 Ok(sqlx::query_scalar(query)
388 .bind(user_id.0)
389 .fetch_all(&self.pool)
390 .await?)
391 }
392
393 // orgs
394
395 #[allow(unused)] // Help rust-analyzer
396 #[cfg(any(test, feature = "seed-support"))]
397 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
398 let query = "
399 SELECT *
400 FROM orgs
401 WHERE slug = $1
402 ";
403 Ok(sqlx::query_as(query)
404 .bind(slug)
405 .fetch_optional(&self.pool)
406 .await?)
407 }
408
409 #[cfg(any(test, feature = "seed-support"))]
410 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
411 let query = "
412 INSERT INTO orgs (name, slug)
413 VALUES ($1, $2)
414 RETURNING id
415 ";
416 Ok(sqlx::query_scalar(query)
417 .bind(name)
418 .bind(slug)
419 .fetch_one(&self.pool)
420 .await
421 .map(OrgId)?)
422 }
423
424 #[cfg(any(test, feature = "seed-support"))]
425 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
426 let query = "
427 INSERT INTO org_memberships (org_id, user_id, admin)
428 VALUES ($1, $2, $3)
429 ON CONFLICT DO NOTHING
430 ";
431 Ok(sqlx::query(query)
432 .bind(org_id.0)
433 .bind(user_id.0)
434 .bind(is_admin)
435 .execute(&self.pool)
436 .await
437 .map(drop)?)
438 }
439
440 // channels
441
442 #[cfg(any(test, feature = "seed-support"))]
443 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
444 let query = "
445 INSERT INTO channels (owner_id, owner_is_user, name)
446 VALUES ($1, false, $2)
447 RETURNING id
448 ";
449 Ok(sqlx::query_scalar(query)
450 .bind(org_id.0)
451 .bind(name)
452 .fetch_one(&self.pool)
453 .await
454 .map(ChannelId)?)
455 }
456
457 #[allow(unused)] // Help rust-analyzer
458 #[cfg(any(test, feature = "seed-support"))]
459 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
460 let query = "
461 SELECT *
462 FROM channels
463 WHERE
464 channels.owner_is_user = false AND
465 channels.owner_id = $1
466 ";
467 Ok(sqlx::query_as(query)
468 .bind(org_id.0)
469 .fetch_all(&self.pool)
470 .await?)
471 }
472
473 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
474 let query = "
475 SELECT
476 channels.*
477 FROM
478 channel_memberships, channels
479 WHERE
480 channel_memberships.user_id = $1 AND
481 channel_memberships.channel_id = channels.id
482 ";
483 Ok(sqlx::query_as(query)
484 .bind(user_id.0)
485 .fetch_all(&self.pool)
486 .await?)
487 }
488
489 async fn can_user_access_channel(
490 &self,
491 user_id: UserId,
492 channel_id: ChannelId,
493 ) -> Result<bool> {
494 let query = "
495 SELECT id
496 FROM channel_memberships
497 WHERE user_id = $1 AND channel_id = $2
498 LIMIT 1
499 ";
500 Ok(sqlx::query_scalar::<_, i32>(query)
501 .bind(user_id.0)
502 .bind(channel_id.0)
503 .fetch_optional(&self.pool)
504 .await
505 .map(|e| e.is_some())?)
506 }
507
508 #[cfg(any(test, feature = "seed-support"))]
509 async fn add_channel_member(
510 &self,
511 channel_id: ChannelId,
512 user_id: UserId,
513 is_admin: bool,
514 ) -> Result<()> {
515 let query = "
516 INSERT INTO channel_memberships (channel_id, user_id, admin)
517 VALUES ($1, $2, $3)
518 ON CONFLICT DO NOTHING
519 ";
520 Ok(sqlx::query(query)
521 .bind(channel_id.0)
522 .bind(user_id.0)
523 .bind(is_admin)
524 .execute(&self.pool)
525 .await
526 .map(drop)?)
527 }
528
529 // messages
530
531 async fn create_channel_message(
532 &self,
533 channel_id: ChannelId,
534 sender_id: UserId,
535 body: &str,
536 timestamp: OffsetDateTime,
537 nonce: u128,
538 ) -> Result<MessageId> {
539 let query = "
540 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
541 VALUES ($1, $2, $3, $4, $5)
542 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
543 RETURNING id
544 ";
545 Ok(sqlx::query_scalar(query)
546 .bind(channel_id.0)
547 .bind(sender_id.0)
548 .bind(body)
549 .bind(timestamp)
550 .bind(Uuid::from_u128(nonce))
551 .fetch_one(&self.pool)
552 .await
553 .map(MessageId)?)
554 }
555
556 async fn get_channel_messages(
557 &self,
558 channel_id: ChannelId,
559 count: usize,
560 before_id: Option<MessageId>,
561 ) -> Result<Vec<ChannelMessage>> {
562 let query = r#"
563 SELECT * FROM (
564 SELECT
565 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
566 FROM
567 channel_messages
568 WHERE
569 channel_id = $1 AND
570 id < $2
571 ORDER BY id DESC
572 LIMIT $3
573 ) as recent_messages
574 ORDER BY id ASC
575 "#;
576 Ok(sqlx::query_as(query)
577 .bind(channel_id.0)
578 .bind(before_id.unwrap_or(MessageId::MAX))
579 .bind(count as i64)
580 .fetch_all(&self.pool)
581 .await?)
582 }
583
584 #[cfg(test)]
585 async fn teardown(&self, url: &str) {
586 use util::ResultExt;
587
588 let query = "
589 SELECT pg_terminate_backend(pg_stat_activity.pid)
590 FROM pg_stat_activity
591 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
592 ";
593 sqlx::query(query).execute(&self.pool).await.log_err();
594 self.pool.close().await;
595 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
596 .await
597 .log_err();
598 }
599}
600
/// Defines a newtype id wrapping an `i32` that is transparent to both sqlx and
/// serde. It gets a `MAX` constant, protobuf (`u64`) conversions, and a
/// `Display` impl that forwards to the inner integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            // Largest representable id; `get_channel_messages` uses it as the
            // upper bound when no cursor is given.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            // NOTE: `as` truncates silently; values above `i32::MAX` do not
            // round-trip through `to_proto`.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            // Negative ids sign-extend into very large `u64` values.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }
    };
}
632
id_type!(UserId);
/// A row of the `users` table.
#[derive(Clone, Debug, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    /// GitHub login; `create_user` upserts on it, so it is unique per user.
    pub github_login: String,
    /// The `admin` flag (see `set_user_is_admin`).
    pub admin: bool,
}
640
641id_type!(OrgId);
642#[derive(FromRow)]
643pub struct Org {
644 pub id: OrgId,
645 pub name: String,
646 pub slug: String,
647}
648
id_type!(ChannelId);
/// A row of the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Owning org's id when `owner_is_user` is false (see `create_org_channel`);
    /// presumably a user id otherwise — confirm against the schema.
    pub owner_id: i32,
    /// Whether `owner_id` refers to a user rather than an org.
    pub owner_is_user: bool,
}
657
id_type!(MessageId);
/// A row of the `channel_messages` table.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Send time; `get_channel_messages` selects it `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Deduplication nonce; `create_channel_message` stores the caller's `u128`
    /// as a UUID.
    pub nonce: Uuid,
}
668
/// A user's contact relationships, as returned by `Db::get_contacts`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Contacts {
    /// Users with an accepted contact relationship.
    pub current: Vec<UserId>,
    /// Pending requests other users sent to this user.
    pub incoming_requests: Vec<IncomingContactRequest>,
    /// Users this user has sent a still-pending request to.
    pub outgoing_requests: Vec<UserId>,
}
675
/// A pending contact request received by a user.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct IncomingContactRequest {
    /// The user who sent the request.
    pub requester_id: UserId,
    /// Whether the recipient should still be notified; cleared by
    /// `dismiss_contact_request`.
    pub should_notify: bool,
}
681
/// Builds a SQL `LIKE` pattern that matches any string containing the
/// alphanumeric characters of `string`, in order, with anything in between.
///
/// Non-alphanumeric characters (including the LIKE metacharacters `%` and `_`)
/// are dropped, so user input cannot inject wildcards. For example,
/// `"x y"` becomes `"%x%y%"`.
fn fuzzy_like_string(string: &str) -> String {
    // Each kept char contributes two chars ('%' + c), plus the trailing '%'.
    let mut pattern = String::with_capacity(string.len() * 2 + 1);
    string
        .chars()
        .filter(|ch| ch.is_alphanumeric())
        .for_each(|ch| {
            pattern.push('%');
            pattern.push(ch);
        });
    pattern.push('%');
    pattern
}
693
694#[cfg(test)]
695pub mod tests {
696 use super::*;
697 use anyhow::anyhow;
698 use collections::BTreeMap;
699 use gpui::executor::Background;
700 use lazy_static::lazy_static;
701 use parking_lot::Mutex;
702 use rand::prelude::*;
703 use sqlx::{
704 migrate::{MigrateDatabase, Migrator},
705 Postgres,
706 };
707 use std::{path::Path, sync::Arc};
708 use util::post_inc;
709
    /// `get_users_by_ids` returns every requested user, against both a real
    /// Postgres database and the in-memory fake.
    // NOTE(review): the Postgres query has no ORDER BY, so this exact-order
    // assertion relies on rows coming back in insertion order.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_get_users_by_ids() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();

            let user = db.create_user("user", false).await.unwrap();
            let friend1 = db.create_user("friend-1", false).await.unwrap();
            let friend2 = db.create_user("friend-2", false).await.unwrap();
            let friend3 = db.create_user("friend-3", false).await.unwrap();

            assert_eq!(
                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
                    .await
                    .unwrap(),
                vec![
                    User {
                        id: user,
                        github_login: "user".to_string(),
                        admin: false,
                    },
                    User {
                        id: friend1,
                        github_login: "friend-1".to_string(),
                        admin: false,
                    },
                    User {
                        id: friend2,
                        github_login: "friend-2".to_string(),
                        admin: false,
                    },
                    User {
                        id: friend3,
                        github_login: "friend-3".to_string(),
                        admin: false,
                    }
                ]
            );
        }
    }
752
    /// `get_channel_messages` returns the newest `count` messages in ascending
    /// order, and paginates backwards from a `before_id` cursor.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_recent_channel_messages() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();
            let user = db.create_user("user", false).await.unwrap();
            let org = db.create_org("org", "org").await.unwrap();
            let channel = db.create_org_channel(org, "channel").await.unwrap();
            // Ten messages with bodies "0".."9"; the loop index doubles as the nonce.
            for i in 0..10 {
                db.create_channel_message(
                    channel,
                    user,
                    &i.to_string(),
                    OffsetDateTime::now_utc(),
                    i,
                )
                .await
                .unwrap();
            }

            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
            assert_eq!(
                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
                ["5", "6", "7", "8", "9"]
            );

            // Page backwards: the 4 messages strictly before message "5".
            let prev_messages = db
                .get_channel_messages(channel, 4, Some(messages[0].id))
                .await
                .unwrap();
            assert_eq!(
                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
                ["1", "2", "3", "4"]
            );
        }
    }
791
    /// Re-sending a message with an already-used nonce returns the original
    /// message's id instead of creating a new message.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_channel_message_nonces() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();
            let user = db.create_user("user", false).await.unwrap();
            let org = db.create_org("org", "org").await.unwrap();
            let channel = db.create_org_channel(org, "channel").await.unwrap();

            let msg1_id = db
                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
                .await
                .unwrap();
            let msg2_id = db
                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
                .await
                .unwrap();
            // Same nonces as above, different bodies.
            let msg3_id = db
                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
                .await
                .unwrap();
            let msg4_id = db
                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
                .await
                .unwrap();

            assert_ne!(msg1_id, msg2_id);
            assert_eq!(msg1_id, msg3_id);
            assert_eq!(msg2_id, msg4_id);
        }
    }
825
    /// Creating access tokens keeps only the newest `max_access_token_count`
    /// hashes (3 here), returned newest first. Postgres only.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_create_access_tokens() {
        let test_db = TestDb::postgres().await;
        let db = test_db.db();
        let user = db.create_user("the-user", false).await.unwrap();

        db.create_access_token_hash(user, "h1", 3).await.unwrap();
        db.create_access_token_hash(user, "h2", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h2".to_string(), "h1".to_string()]
        );

        db.create_access_token_hash(user, "h3", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
        );

        // From here on, each new token evicts the oldest one.
        db.create_access_token_hash(user, "h4", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
        );

        db.create_access_token_hash(user, "h5", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
        );
    }
857
858 #[test]
859 fn test_fuzzy_like_string() {
860 assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
861 assert_eq!(fuzzy_like_string("x y"), "%x%y%");
862 assert_eq!(fuzzy_like_string(" z "), "%z%");
863 }
864
    /// `fuzzy_search_users` filters by in-order character matches and ranks the
    /// results by closeness to the query. Postgres only.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_fuzzy_search_users() {
        let test_db = TestDb::postgres().await;
        let db = test_db.db();
        for github_login in [
            "california",
            "colorado",
            "oregon",
            "washington",
            "florida",
            "delaware",
            "rhode-island",
        ] {
            db.create_user(github_login, false).await.unwrap();
        }

        assert_eq!(
            fuzzy_search_user_names(db, "clr").await,
            &["colorado", "california"]
        );
        assert_eq!(
            fuzzy_search_user_names(db, "ro").await,
            &["rhode-island", "colorado", "oregon"],
        );

        // Runs a search (limit 10) and returns just the matching logins, in order.
        async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
            db.fuzzy_search_users(query, 10)
                .await
                .unwrap()
                .into_iter()
                .map(|user| user.github_login)
                .collect::<Vec<_>>()
        }
    }
899
    /// Walks the full contact-request lifecycle against both Postgres and the
    /// fake: request, dismiss, accept, mutual requests, and decline. Each step
    /// depends on the state left by the previous one.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_contacts() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();

            let user_1 = db.create_user("user1", false).await.unwrap();
            let user_2 = db.create_user("user2", false).await.unwrap();
            let user_3 = db.create_user("user3", false).await.unwrap();

            // User starts with no contacts
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                Contacts {
                    current: vec![],
                    outgoing_requests: vec![],
                    incoming_requests: vec![],
                },
            );

            // User requests a contact. Both users see the pending request.
            db.send_contact_request(user_1, user_2).await.unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                Contacts {
                    current: vec![],
                    outgoing_requests: vec![user_2],
                    incoming_requests: vec![],
                },
            );
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                Contacts {
                    current: vec![],
                    outgoing_requests: vec![],
                    incoming_requests: vec![IncomingContactRequest {
                        requester_id: user_1,
                        should_notify: true
                    }],
                },
            );

            // User 2 dismisses the contact request notification without accepting or rejecting.
            // We shouldn't notify them again.
            // (The requester cannot dismiss their own request.)
            db.dismiss_contact_request(user_1, user_2)
                .await
                .unwrap_err();
            db.dismiss_contact_request(user_2, user_1).await.unwrap();
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                Contacts {
                    current: vec![],
                    outgoing_requests: vec![],
                    incoming_requests: vec![IncomingContactRequest {
                        requester_id: user_1,
                        should_notify: false
                    }],
                },
            );

            // User can't accept their own contact request
            db.respond_to_contact_request(user_1, user_2, true)
                .await
                .unwrap_err();

            // User accepts a contact request. Both users see the contact.
            db.respond_to_contact_request(user_2, user_1, true)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                Contacts {
                    current: vec![user_2],
                    outgoing_requests: vec![],
                    incoming_requests: vec![],
                },
            );
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                Contacts {
                    current: vec![user_1],
                    outgoing_requests: vec![],
                    incoming_requests: vec![],
                },
            );

            // Users cannot re-request existing contacts.
            db.send_contact_request(user_1, user_2).await.unwrap_err();
            db.send_contact_request(user_2, user_1).await.unwrap_err();

            // Users send each other concurrent contact requests and
            // see that they are immediately accepted.
            db.send_contact_request(user_1, user_3).await.unwrap();
            db.send_contact_request(user_3, user_1).await.unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                Contacts {
                    current: vec![user_2, user_3],
                    outgoing_requests: vec![],
                    incoming_requests: vec![],
                },
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                Contacts {
                    current: vec![user_1],
                    outgoing_requests: vec![],
                    incoming_requests: vec![],
                },
            );

            // User declines a contact request. Both users see that it is gone.
            db.send_contact_request(user_2, user_3).await.unwrap();
            db.respond_to_contact_request(user_3, user_2, false)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                Contacts {
                    current: vec![user_1],
                    outgoing_requests: vec![],
                    incoming_requests: vec![],
                },
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                Contacts {
                    current: vec![user_1],
                    outgoing_requests: vec![],
                    incoming_requests: vec![],
                },
            );
        }
    }
1036
    /// A database owned by a single test; it is torn down when dropped.
    pub struct TestDb {
        /// `Some` until `Drop` takes it for teardown.
        pub db: Option<Arc<dyn Db>>,
        /// URL of the Postgres test database; empty for the fake.
        pub url: String,
    }
1041
    impl TestDb {
        /// Creates a fresh, randomly named Postgres database on localhost,
        /// runs this crate's `migrations` directory against it, and wraps it in
        /// a [`PostgresDb`] with a pool of 5 connections.
        ///
        /// # Panics
        /// Panics if the database cannot be created, connected to, or migrated.
        pub async fn postgres() -> Self {
            lazy_static! {
                // Serializes database setup across tests running in parallel.
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            // NOTE(review): this is a blocking parking_lot guard held across the
            // awaits below, so a test waiting for it blocks its thread.
            let _guard = LOCK.lock();
            // A random suffix keeps concurrently created databases distinct.
            let mut rng = StdRng::from_entropy();
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            Postgres::create_database(&url)
                .await
                .expect("failed to create test db");
            let db = PostgresDb::new(&url, 5).await.unwrap();
            let migrator = Migrator::new(migrations_path).await.unwrap();
            migrator.run(&db.pool).await.unwrap();
            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Wraps an in-memory `FakeDb` driven by `background`; `url` stays empty.
        pub fn fake(background: Arc<Background>) -> Self {
            Self {
                db: Some(Arc::new(FakeDb::new(background))),
                url: Default::default(),
            }
        }

        /// Returns the wrapped database.
        ///
        /// # Panics
        /// Panics if the database has already been taken (which `Drop` does).
        pub fn db(&self) -> &Arc<dyn Db> {
            self.db.as_ref().unwrap()
        }
    }
1076
1077 impl Drop for TestDb {
1078 fn drop(&mut self) {
1079 if let Some(db) = self.db.take() {
1080 futures::executor::block_on(db.teardown(&self.url));
1081 }
1082 }
1083 }
1084
    /// In-memory [`Db`] implementation for tests. Each table is a mutex-guarded
    /// map keyed by id, with a separate counter for the next id to assign.
    pub struct FakeDb {
        /// Used to inject random delays (`simulate_random_delay`) into some calls.
        background: Arc<Background>,
        users: Mutex<BTreeMap<UserId, User>>,
        next_user_id: Mutex<i32>,
        orgs: Mutex<BTreeMap<OrgId, Org>>,
        next_org_id: Mutex<i32>,
        // Value is presumably the membership's admin flag — confirm in add_org_member.
        org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        channels: Mutex<BTreeMap<ChannelId, Channel>>,
        next_channel_id: Mutex<i32>,
        // Value is presumably the membership's admin flag — confirm in add_channel_member.
        channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        next_channel_message_id: Mutex<i32>,
        /// Contact requests and accepted contacts, in insertion order.
        contacts: Mutex<Vec<FakeContact>>,
    }
1099
    /// A contact request, or an accepted contact, between two users in `FakeDb`.
    #[derive(Debug)]
    struct FakeContact {
        /// User who sent the request.
        requester_id: UserId,
        /// User the request was sent to.
        responder_id: UserId,
        /// Whether the request has been accepted.
        accepted: bool,
        /// Whether the responder should still be notified about the request.
        should_notify: bool,
    }
1107
1108 impl FakeDb {
1109 pub fn new(background: Arc<Background>) -> Self {
1110 Self {
1111 background,
1112 users: Default::default(),
1113 next_user_id: Mutex::new(1),
1114 orgs: Default::default(),
1115 next_org_id: Mutex::new(1),
1116 org_memberships: Default::default(),
1117 channels: Default::default(),
1118 next_channel_id: Mutex::new(1),
1119 channel_memberships: Default::default(),
1120 channel_messages: Default::default(),
1121 next_channel_message_id: Mutex::new(1),
1122 contacts: Default::default(),
1123 }
1124 }
1125 }
1126
1127 #[async_trait]
1128 impl Db for FakeDb {
1129 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
1130 self.background.simulate_random_delay().await;
1131
1132 let mut users = self.users.lock();
1133 if let Some(user) = users
1134 .values()
1135 .find(|user| user.github_login == github_login)
1136 {
1137 Ok(user.id)
1138 } else {
1139 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
1140 users.insert(
1141 user_id,
1142 User {
1143 id: user_id,
1144 github_login: github_login.to_string(),
1145 admin,
1146 },
1147 );
1148 Ok(user_id)
1149 }
1150 }
1151
        // Not supported by the fake; panics if called.
        async fn get_all_users(&self) -> Result<Vec<User>> {
            unimplemented!()
        }
1155
        // Not supported by the fake; panics if called.
        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
1159
1160 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1161 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1162 }
1163
1164 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1165 self.background.simulate_random_delay().await;
1166 let users = self.users.lock();
1167 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1168 }
1169
1170 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
1171 Ok(self
1172 .users
1173 .lock()
1174 .values()
1175 .find(|user| user.github_login == github_login)
1176 .cloned())
1177 }
1178
        // Not supported by the fake; panics if called.
        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }
1182
        // Not supported by the fake; panics if called.
        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }
1186
1187 async fn get_contacts(&self, id: UserId) -> Result<Contacts> {
1188 self.background.simulate_random_delay().await;
1189 let mut current = Vec::new();
1190 let mut outgoing_requests = Vec::new();
1191 let mut incoming_requests = Vec::new();
1192
1193 for contact in self.contacts.lock().iter() {
1194 if contact.requester_id == id {
1195 if contact.accepted {
1196 current.push(contact.responder_id);
1197 } else {
1198 outgoing_requests.push(contact.responder_id);
1199 }
1200 } else if contact.responder_id == id {
1201 if contact.accepted {
1202 current.push(contact.requester_id);
1203 } else {
1204 incoming_requests.push(IncomingContactRequest {
1205 requester_id: contact.requester_id,
1206 should_notify: contact.should_notify,
1207 });
1208 }
1209 }
1210 }
1211
1212 Ok(Contacts {
1213 current,
1214 outgoing_requests,
1215 incoming_requests,
1216 })
1217 }
1218
1219 async fn send_contact_request(
1220 &self,
1221 requester_id: UserId,
1222 responder_id: UserId,
1223 ) -> Result<()> {
1224 let mut contacts = self.contacts.lock();
1225 for contact in contacts.iter_mut() {
1226 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1227 if contact.accepted {
1228 Err(anyhow!("contact already exists"))?;
1229 } else {
1230 Err(anyhow!("contact already requested"))?;
1231 }
1232 }
1233 if contact.responder_id == requester_id && contact.requester_id == responder_id {
1234 if contact.accepted {
1235 Err(anyhow!("contact already exists"))?;
1236 } else {
1237 contact.accepted = true;
1238 return Ok(());
1239 }
1240 }
1241 }
1242 contacts.push(FakeContact {
1243 requester_id,
1244 responder_id,
1245 accepted: false,
1246 should_notify: true,
1247 });
1248 Ok(())
1249 }
1250
1251 async fn dismiss_contact_request(
1252 &self,
1253 responder_id: UserId,
1254 requester_id: UserId,
1255 ) -> Result<()> {
1256 let mut contacts = self.contacts.lock();
1257 for contact in contacts.iter_mut() {
1258 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1259 if contact.accepted {
1260 return Err(anyhow!("contact already confirmed"));
1261 }
1262 contact.should_notify = false;
1263 return Ok(());
1264 }
1265 }
1266 Err(anyhow!("no such contact request"))
1267 }
1268
1269 async fn respond_to_contact_request(
1270 &self,
1271 responder_id: UserId,
1272 requester_id: UserId,
1273 accept: bool,
1274 ) -> Result<()> {
1275 let mut contacts = self.contacts.lock();
1276 for (ix, contact) in contacts.iter_mut().enumerate() {
1277 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1278 if contact.accepted {
1279 return Err(anyhow!("contact already confirmed"));
1280 }
1281 if accept {
1282 contact.accepted = true;
1283 } else {
1284 contacts.remove(ix);
1285 }
1286 return Ok(());
1287 }
1288 }
1289 Err(anyhow!("no such contact request"))
1290 }
1291
1292 async fn create_access_token_hash(
1293 &self,
1294 _user_id: UserId,
1295 _access_token_hash: &str,
1296 _max_access_token_count: usize,
1297 ) -> Result<()> {
1298 unimplemented!()
1299 }
1300
1301 async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
1302 unimplemented!()
1303 }
1304
1305 async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
1306 unimplemented!()
1307 }
1308
1309 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1310 self.background.simulate_random_delay().await;
1311 let mut orgs = self.orgs.lock();
1312 if orgs.values().any(|org| org.slug == slug) {
1313 Err(anyhow!("org already exists"))
1314 } else {
1315 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
1316 orgs.insert(
1317 org_id,
1318 Org {
1319 id: org_id,
1320 name: name.to_string(),
1321 slug: slug.to_string(),
1322 },
1323 );
1324 Ok(org_id)
1325 }
1326 }
1327
1328 async fn add_org_member(
1329 &self,
1330 org_id: OrgId,
1331 user_id: UserId,
1332 is_admin: bool,
1333 ) -> Result<()> {
1334 self.background.simulate_random_delay().await;
1335 if !self.orgs.lock().contains_key(&org_id) {
1336 return Err(anyhow!("org does not exist"));
1337 }
1338 if !self.users.lock().contains_key(&user_id) {
1339 return Err(anyhow!("user does not exist"));
1340 }
1341
1342 self.org_memberships
1343 .lock()
1344 .entry((org_id, user_id))
1345 .or_insert(is_admin);
1346 Ok(())
1347 }
1348
1349 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1350 self.background.simulate_random_delay().await;
1351 if !self.orgs.lock().contains_key(&org_id) {
1352 return Err(anyhow!("org does not exist"));
1353 }
1354
1355 let mut channels = self.channels.lock();
1356 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
1357 channels.insert(
1358 channel_id,
1359 Channel {
1360 id: channel_id,
1361 name: name.to_string(),
1362 owner_id: org_id.0,
1363 owner_is_user: false,
1364 },
1365 );
1366 Ok(channel_id)
1367 }
1368
1369 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1370 self.background.simulate_random_delay().await;
1371 Ok(self
1372 .channels
1373 .lock()
1374 .values()
1375 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
1376 .cloned()
1377 .collect())
1378 }
1379
1380 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1381 self.background.simulate_random_delay().await;
1382 let channels = self.channels.lock();
1383 let memberships = self.channel_memberships.lock();
1384 Ok(channels
1385 .values()
1386 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
1387 .cloned()
1388 .collect())
1389 }
1390
1391 async fn can_user_access_channel(
1392 &self,
1393 user_id: UserId,
1394 channel_id: ChannelId,
1395 ) -> Result<bool> {
1396 self.background.simulate_random_delay().await;
1397 Ok(self
1398 .channel_memberships
1399 .lock()
1400 .contains_key(&(channel_id, user_id)))
1401 }
1402
1403 async fn add_channel_member(
1404 &self,
1405 channel_id: ChannelId,
1406 user_id: UserId,
1407 is_admin: bool,
1408 ) -> Result<()> {
1409 self.background.simulate_random_delay().await;
1410 if !self.channels.lock().contains_key(&channel_id) {
1411 return Err(anyhow!("channel does not exist"));
1412 }
1413 if !self.users.lock().contains_key(&user_id) {
1414 return Err(anyhow!("user does not exist"));
1415 }
1416
1417 self.channel_memberships
1418 .lock()
1419 .entry((channel_id, user_id))
1420 .or_insert(is_admin);
1421 Ok(())
1422 }
1423
1424 async fn create_channel_message(
1425 &self,
1426 channel_id: ChannelId,
1427 sender_id: UserId,
1428 body: &str,
1429 timestamp: OffsetDateTime,
1430 nonce: u128,
1431 ) -> Result<MessageId> {
1432 self.background.simulate_random_delay().await;
1433 if !self.channels.lock().contains_key(&channel_id) {
1434 return Err(anyhow!("channel does not exist"));
1435 }
1436 if !self.users.lock().contains_key(&sender_id) {
1437 return Err(anyhow!("user does not exist"));
1438 }
1439
1440 let mut messages = self.channel_messages.lock();
1441 if let Some(message) = messages
1442 .values()
1443 .find(|message| message.nonce.as_u128() == nonce)
1444 {
1445 Ok(message.id)
1446 } else {
1447 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
1448 messages.insert(
1449 message_id,
1450 ChannelMessage {
1451 id: message_id,
1452 channel_id,
1453 sender_id,
1454 body: body.to_string(),
1455 sent_at: timestamp,
1456 nonce: Uuid::from_u128(nonce),
1457 },
1458 );
1459 Ok(message_id)
1460 }
1461 }
1462
1463 async fn get_channel_messages(
1464 &self,
1465 channel_id: ChannelId,
1466 count: usize,
1467 before_id: Option<MessageId>,
1468 ) -> Result<Vec<ChannelMessage>> {
1469 let mut messages = self
1470 .channel_messages
1471 .lock()
1472 .values()
1473 .rev()
1474 .filter(|message| {
1475 message.channel_id == channel_id
1476 && message.id < before_id.unwrap_or(MessageId::MAX)
1477 })
1478 .take(count)
1479 .cloned()
1480 .collect::<Vec<_>>();
1481 messages.sort_unstable_by_key(|message| message.id);
1482 Ok(messages)
1483 }
1484
1485 async fn teardown(&self, _: &str) {}
1486 }
1487}