1use anyhow::{anyhow, Context, Result};
2use async_trait::async_trait;
3use futures::StreamExt;
4use serde::Serialize;
5pub use sqlx::postgres::PgPoolOptions as DbOptions;
6use sqlx::{types::Uuid, FromRow};
7use time::OffsetDateTime;
8
9#[async_trait]
10pub trait Db: Send + Sync {
11 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
12 async fn get_all_users(&self) -> Result<Vec<User>>;
13 async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
14 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
15 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
16 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
17 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
18 async fn destroy_user(&self, id: UserId) -> Result<()>;
19
20 async fn get_contacts(&self, id: UserId) -> Result<Contacts>;
21 async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
22 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
23 async fn dismiss_contact_request(
24 &self,
25 responder_id: UserId,
26 requester_id: UserId,
27 ) -> Result<()>;
28 async fn respond_to_contact_request(
29 &self,
30 responder_id: UserId,
31 requester_id: UserId,
32 accept: bool,
33 ) -> Result<()>;
34
35 async fn create_access_token_hash(
36 &self,
37 user_id: UserId,
38 access_token_hash: &str,
39 max_access_token_count: usize,
40 ) -> Result<()>;
41 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
42 #[cfg(any(test, feature = "seed-support"))]
43
44 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
45 #[cfg(any(test, feature = "seed-support"))]
46 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
47 #[cfg(any(test, feature = "seed-support"))]
48 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
49 #[cfg(any(test, feature = "seed-support"))]
50 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
51 #[cfg(any(test, feature = "seed-support"))]
52
53 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
54 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
55 async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
56 -> Result<bool>;
57 #[cfg(any(test, feature = "seed-support"))]
58 async fn add_channel_member(
59 &self,
60 channel_id: ChannelId,
61 user_id: UserId,
62 is_admin: bool,
63 ) -> Result<()>;
64 async fn create_channel_message(
65 &self,
66 channel_id: ChannelId,
67 sender_id: UserId,
68 body: &str,
69 timestamp: OffsetDateTime,
70 nonce: u128,
71 ) -> Result<MessageId>;
72 async fn get_channel_messages(
73 &self,
74 channel_id: ChannelId,
75 count: usize,
76 before_id: Option<MessageId>,
77 ) -> Result<Vec<ChannelMessage>>;
78 #[cfg(test)]
79 async fn teardown(&self, url: &str);
80}
81
82pub struct PostgresDb {
83 pool: sqlx::PgPool,
84}
85
86impl PostgresDb {
87 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
88 let pool = DbOptions::new()
89 .max_connections(max_connections)
90 .connect(&url)
91 .await
92 .context("failed to connect to postgres database")?;
93 Ok(Self { pool })
94 }
95}
96
97#[async_trait]
98impl Db for PostgresDb {
99 // users
100
101 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
102 let query = "
103 INSERT INTO users (github_login, admin)
104 VALUES ($1, $2)
105 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
106 RETURNING id
107 ";
108 Ok(sqlx::query_scalar(query)
109 .bind(github_login)
110 .bind(admin)
111 .fetch_one(&self.pool)
112 .await
113 .map(UserId)?)
114 }
115
116 async fn get_all_users(&self) -> Result<Vec<User>> {
117 let query = "SELECT * FROM users ORDER BY github_login ASC";
118 Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
119 }
120
121 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
122 let like_string = fuzzy_like_string(name_query);
123 let query = "
124 SELECT users.*
125 FROM users
126 WHERE github_login like $1
127 ORDER BY github_login <-> $2
128 LIMIT $3
129 ";
130 Ok(sqlx::query_as(query)
131 .bind(like_string)
132 .bind(name_query)
133 .bind(limit)
134 .fetch_all(&self.pool)
135 .await?)
136 }
137
138 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
139 let users = self.get_users_by_ids(vec![id]).await?;
140 Ok(users.into_iter().next())
141 }
142
143 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
144 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
145 let query = "
146 SELECT users.*
147 FROM users
148 WHERE users.id = ANY ($1)
149 ";
150 Ok(sqlx::query_as(query)
151 .bind(&ids)
152 .fetch_all(&self.pool)
153 .await?)
154 }
155
156 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
157 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
158 Ok(sqlx::query_as(query)
159 .bind(github_login)
160 .fetch_optional(&self.pool)
161 .await?)
162 }
163
164 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
165 let query = "UPDATE users SET admin = $1 WHERE id = $2";
166 Ok(sqlx::query(query)
167 .bind(is_admin)
168 .bind(id.0)
169 .execute(&self.pool)
170 .await
171 .map(drop)?)
172 }
173
174 async fn destroy_user(&self, id: UserId) -> Result<()> {
175 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
176 sqlx::query(query)
177 .bind(id.0)
178 .execute(&self.pool)
179 .await
180 .map(drop)?;
181 let query = "DELETE FROM users WHERE id = $1;";
182 Ok(sqlx::query(query)
183 .bind(id.0)
184 .execute(&self.pool)
185 .await
186 .map(drop)?)
187 }
188
189 // contacts
190
191 async fn get_contacts(&self, user_id: UserId) -> Result<Contacts> {
192 let query = "
193 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
194 FROM contacts
195 WHERE user_id_a = $1 OR user_id_b = $1;
196 ";
197
198 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
199 .bind(user_id)
200 .fetch(&self.pool);
201
202 let mut current = Vec::new();
203 let mut outgoing_requests = Vec::new();
204 let mut incoming_requests = Vec::new();
205 while let Some(row) = rows.next().await {
206 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
207
208 if user_id_a == user_id {
209 if accepted {
210 current.push(user_id_b);
211 } else if a_to_b {
212 outgoing_requests.push(user_id_b);
213 } else {
214 incoming_requests.push(IncomingContactRequest {
215 requester_id: user_id_b,
216 should_notify,
217 });
218 }
219 } else {
220 if accepted {
221 current.push(user_id_a);
222 } else if a_to_b {
223 incoming_requests.push(IncomingContactRequest {
224 requester_id: user_id_a,
225 should_notify,
226 });
227 } else {
228 outgoing_requests.push(user_id_a);
229 }
230 }
231 }
232
233 Ok(Contacts {
234 current,
235 outgoing_requests,
236 incoming_requests,
237 })
238 }
239
240 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
241 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
242 (sender_id, receiver_id, true)
243 } else {
244 (receiver_id, sender_id, false)
245 };
246 let query = "
247 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
248 VALUES ($1, $2, $3, 'f', 't')
249 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
250 SET
251 accepted = 't'
252 WHERE
253 NOT contacts.accepted AND
254 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
255 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
256 ";
257 let result = sqlx::query(query)
258 .bind(id_a.0)
259 .bind(id_b.0)
260 .bind(a_to_b)
261 .execute(&self.pool)
262 .await?;
263
264 if result.rows_affected() == 1 {
265 Ok(())
266 } else {
267 Err(anyhow!("contact already requested"))
268 }
269 }
270
271 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
272 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
273 (responder_id, requester_id, false)
274 } else {
275 (requester_id, responder_id, true)
276 };
277 let query = "
278 DELETE FROM contacts
279 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
280 ";
281 let result = sqlx::query(query)
282 .bind(id_a.0)
283 .bind(id_b.0)
284 .bind(a_to_b)
285 .execute(&self.pool)
286 .await?;
287
288 if result.rows_affected() == 1 {
289 Ok(())
290 } else {
291 Err(anyhow!("no such contact"))
292 }
293 }
294
295 async fn respond_to_contact_request(
296 &self,
297 responder_id: UserId,
298 requester_id: UserId,
299 accept: bool,
300 ) -> Result<()> {
301 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
302 (responder_id, requester_id, false)
303 } else {
304 (requester_id, responder_id, true)
305 };
306 let result = if accept {
307 let query = "
308 UPDATE contacts
309 SET accepted = 't', should_notify = 'f'
310 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
311 ";
312 sqlx::query(query)
313 .bind(id_a.0)
314 .bind(id_b.0)
315 .bind(a_to_b)
316 .execute(&self.pool)
317 .await?
318 } else {
319 let query = "
320 DELETE FROM contacts
321 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
322 ";
323 sqlx::query(query)
324 .bind(id_a.0)
325 .bind(id_b.0)
326 .bind(a_to_b)
327 .execute(&self.pool)
328 .await?
329 };
330 if result.rows_affected() == 1 {
331 Ok(())
332 } else {
333 Err(anyhow!("no such contact request"))
334 }
335 }
336
337 async fn dismiss_contact_request(
338 &self,
339 responder_id: UserId,
340 requester_id: UserId,
341 ) -> Result<()> {
342 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
343 (responder_id, requester_id, false)
344 } else {
345 (requester_id, responder_id, true)
346 };
347
348 let query = "
349 UPDATE contacts
350 SET should_notify = 'f'
351 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
352 ";
353
354 let result = sqlx::query(query)
355 .bind(id_a.0)
356 .bind(id_b.0)
357 .bind(a_to_b)
358 .execute(&self.pool)
359 .await?;
360
361 if result.rows_affected() == 0 {
362 Err(anyhow!("no such contact request"))?;
363 }
364
365 Ok(())
366 }
367
368 // access tokens
369
370 async fn create_access_token_hash(
371 &self,
372 user_id: UserId,
373 access_token_hash: &str,
374 max_access_token_count: usize,
375 ) -> Result<()> {
376 let insert_query = "
377 INSERT INTO access_tokens (user_id, hash)
378 VALUES ($1, $2);
379 ";
380 let cleanup_query = "
381 DELETE FROM access_tokens
382 WHERE id IN (
383 SELECT id from access_tokens
384 WHERE user_id = $1
385 ORDER BY id DESC
386 OFFSET $3
387 )
388 ";
389
390 let mut tx = self.pool.begin().await?;
391 sqlx::query(insert_query)
392 .bind(user_id.0)
393 .bind(access_token_hash)
394 .execute(&mut tx)
395 .await?;
396 sqlx::query(cleanup_query)
397 .bind(user_id.0)
398 .bind(access_token_hash)
399 .bind(max_access_token_count as u32)
400 .execute(&mut tx)
401 .await?;
402 Ok(tx.commit().await?)
403 }
404
405 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
406 let query = "
407 SELECT hash
408 FROM access_tokens
409 WHERE user_id = $1
410 ORDER BY id DESC
411 ";
412 Ok(sqlx::query_scalar(query)
413 .bind(user_id.0)
414 .fetch_all(&self.pool)
415 .await?)
416 }
417
418 // orgs
419
420 #[allow(unused)] // Help rust-analyzer
421 #[cfg(any(test, feature = "seed-support"))]
422 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
423 let query = "
424 SELECT *
425 FROM orgs
426 WHERE slug = $1
427 ";
428 Ok(sqlx::query_as(query)
429 .bind(slug)
430 .fetch_optional(&self.pool)
431 .await?)
432 }
433
434 #[cfg(any(test, feature = "seed-support"))]
435 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
436 let query = "
437 INSERT INTO orgs (name, slug)
438 VALUES ($1, $2)
439 RETURNING id
440 ";
441 Ok(sqlx::query_scalar(query)
442 .bind(name)
443 .bind(slug)
444 .fetch_one(&self.pool)
445 .await
446 .map(OrgId)?)
447 }
448
449 #[cfg(any(test, feature = "seed-support"))]
450 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
451 let query = "
452 INSERT INTO org_memberships (org_id, user_id, admin)
453 VALUES ($1, $2, $3)
454 ON CONFLICT DO NOTHING
455 ";
456 Ok(sqlx::query(query)
457 .bind(org_id.0)
458 .bind(user_id.0)
459 .bind(is_admin)
460 .execute(&self.pool)
461 .await
462 .map(drop)?)
463 }
464
465 // channels
466
467 #[cfg(any(test, feature = "seed-support"))]
468 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
469 let query = "
470 INSERT INTO channels (owner_id, owner_is_user, name)
471 VALUES ($1, false, $2)
472 RETURNING id
473 ";
474 Ok(sqlx::query_scalar(query)
475 .bind(org_id.0)
476 .bind(name)
477 .fetch_one(&self.pool)
478 .await
479 .map(ChannelId)?)
480 }
481
482 #[allow(unused)] // Help rust-analyzer
483 #[cfg(any(test, feature = "seed-support"))]
484 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
485 let query = "
486 SELECT *
487 FROM channels
488 WHERE
489 channels.owner_is_user = false AND
490 channels.owner_id = $1
491 ";
492 Ok(sqlx::query_as(query)
493 .bind(org_id.0)
494 .fetch_all(&self.pool)
495 .await?)
496 }
497
498 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
499 let query = "
500 SELECT
501 channels.*
502 FROM
503 channel_memberships, channels
504 WHERE
505 channel_memberships.user_id = $1 AND
506 channel_memberships.channel_id = channels.id
507 ";
508 Ok(sqlx::query_as(query)
509 .bind(user_id.0)
510 .fetch_all(&self.pool)
511 .await?)
512 }
513
514 async fn can_user_access_channel(
515 &self,
516 user_id: UserId,
517 channel_id: ChannelId,
518 ) -> Result<bool> {
519 let query = "
520 SELECT id
521 FROM channel_memberships
522 WHERE user_id = $1 AND channel_id = $2
523 LIMIT 1
524 ";
525 Ok(sqlx::query_scalar::<_, i32>(query)
526 .bind(user_id.0)
527 .bind(channel_id.0)
528 .fetch_optional(&self.pool)
529 .await
530 .map(|e| e.is_some())?)
531 }
532
533 #[cfg(any(test, feature = "seed-support"))]
534 async fn add_channel_member(
535 &self,
536 channel_id: ChannelId,
537 user_id: UserId,
538 is_admin: bool,
539 ) -> Result<()> {
540 let query = "
541 INSERT INTO channel_memberships (channel_id, user_id, admin)
542 VALUES ($1, $2, $3)
543 ON CONFLICT DO NOTHING
544 ";
545 Ok(sqlx::query(query)
546 .bind(channel_id.0)
547 .bind(user_id.0)
548 .bind(is_admin)
549 .execute(&self.pool)
550 .await
551 .map(drop)?)
552 }
553
554 // messages
555
556 async fn create_channel_message(
557 &self,
558 channel_id: ChannelId,
559 sender_id: UserId,
560 body: &str,
561 timestamp: OffsetDateTime,
562 nonce: u128,
563 ) -> Result<MessageId> {
564 let query = "
565 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
566 VALUES ($1, $2, $3, $4, $5)
567 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
568 RETURNING id
569 ";
570 Ok(sqlx::query_scalar(query)
571 .bind(channel_id.0)
572 .bind(sender_id.0)
573 .bind(body)
574 .bind(timestamp)
575 .bind(Uuid::from_u128(nonce))
576 .fetch_one(&self.pool)
577 .await
578 .map(MessageId)?)
579 }
580
581 async fn get_channel_messages(
582 &self,
583 channel_id: ChannelId,
584 count: usize,
585 before_id: Option<MessageId>,
586 ) -> Result<Vec<ChannelMessage>> {
587 let query = r#"
588 SELECT * FROM (
589 SELECT
590 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
591 FROM
592 channel_messages
593 WHERE
594 channel_id = $1 AND
595 id < $2
596 ORDER BY id DESC
597 LIMIT $3
598 ) as recent_messages
599 ORDER BY id ASC
600 "#;
601 Ok(sqlx::query_as(query)
602 .bind(channel_id.0)
603 .bind(before_id.unwrap_or(MessageId::MAX))
604 .bind(count as i64)
605 .fetch_all(&self.pool)
606 .await?)
607 }
608
609 #[cfg(test)]
610 async fn teardown(&self, url: &str) {
611 use util::ResultExt;
612
613 let query = "
614 SELECT pg_terminate_backend(pg_stat_activity.pid)
615 FROM pg_stat_activity
616 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
617 ";
618 sqlx::query(query).execute(&self.pool).await.log_err();
619 self.pool.close().await;
620 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
621 .await
622 .log_err();
623 }
624}
625
/// Defines a `Copy` newtype id wrapping an `i32` database id.
///
/// The generated type is ordered and hashable, is stored in SQL as its inner
/// `i32` (`sqlx(transparent)`), serializes as a bare number
/// (`serde(transparent)`), and displays as the inner integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// Largest representable id, usable as an "unbounded" upper limit.
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Converts a protocol id into this id type.
            /// NOTE(review): the `as` cast wraps values above `i32::MAX`.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                $name(value as i32)
            }

            /// Converts this id into its protocol representation.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let raw: i32 = self.0;
                raw as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                // Delegate so formatter flags behave exactly as for a plain i32.
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
657
658id_type!(UserId);
659#[derive(Clone, Debug, FromRow, Serialize, PartialEq)]
660pub struct User {
661 pub id: UserId,
662 pub github_login: String,
663 pub admin: bool,
664}
665
666id_type!(OrgId);
667#[derive(FromRow)]
668pub struct Org {
669 pub id: OrgId,
670 pub name: String,
671 pub slug: String,
672}
673
674id_type!(ChannelId);
675#[derive(Clone, Debug, FromRow, Serialize)]
676pub struct Channel {
677 pub id: ChannelId,
678 pub name: String,
679 pub owner_id: i32,
680 pub owner_is_user: bool,
681}
682
683id_type!(MessageId);
684#[derive(Clone, Debug, FromRow)]
685pub struct ChannelMessage {
686 pub id: MessageId,
687 pub channel_id: ChannelId,
688 pub sender_id: UserId,
689 pub body: String,
690 pub sent_at: OffsetDateTime,
691 pub nonce: Uuid,
692}
693
694#[derive(Clone, Debug, PartialEq, Eq)]
695pub struct Contacts {
696 pub current: Vec<UserId>,
697 pub incoming_requests: Vec<IncomingContactRequest>,
698 pub outgoing_requests: Vec<UserId>,
699}
700
701#[derive(Clone, Debug, PartialEq, Eq)]
702pub struct IncomingContactRequest {
703 pub requester_id: UserId,
704 pub should_notify: bool,
705}
706
/// Builds a SQL `LIKE` pattern matching any string that contains the
/// alphanumeric characters of `string` in order, e.g. `"a b"` -> `"%a%b%"`.
///
/// Non-alphanumeric characters are dropped, including the `LIKE` wildcards `%`
/// and `_`, so the input cannot inject its own wildcards.
fn fuzzy_like_string(string: &str) -> String {
    let mut pattern = String::with_capacity(string.len() * 2 + 1);
    pattern.extend(
        string
            .chars()
            .filter(|ch| ch.is_alphanumeric())
            .flat_map(|ch| ['%', ch]),
    );
    // Trailing wildcard (also the whole pattern for empty input).
    pattern.push('%');
    pattern
}
718
719#[cfg(test)]
720pub mod tests {
721 use super::*;
722 use anyhow::anyhow;
723 use collections::BTreeMap;
724 use gpui::executor::Background;
725 use lazy_static::lazy_static;
726 use parking_lot::Mutex;
727 use rand::prelude::*;
728 use sqlx::{
729 migrate::{MigrateDatabase, Migrator},
730 Postgres,
731 };
732 use std::{path::Path, sync::Arc};
733 use util::post_inc;
734
735 #[tokio::test(flavor = "multi_thread")]
736 async fn test_get_users_by_ids() {
737 for test_db in [
738 TestDb::postgres().await,
739 TestDb::fake(Arc::new(gpui::executor::Background::new())),
740 ] {
741 let db = test_db.db();
742
743 let user = db.create_user("user", false).await.unwrap();
744 let friend1 = db.create_user("friend-1", false).await.unwrap();
745 let friend2 = db.create_user("friend-2", false).await.unwrap();
746 let friend3 = db.create_user("friend-3", false).await.unwrap();
747
748 assert_eq!(
749 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
750 .await
751 .unwrap(),
752 vec![
753 User {
754 id: user,
755 github_login: "user".to_string(),
756 admin: false,
757 },
758 User {
759 id: friend1,
760 github_login: "friend-1".to_string(),
761 admin: false,
762 },
763 User {
764 id: friend2,
765 github_login: "friend-2".to_string(),
766 admin: false,
767 },
768 User {
769 id: friend3,
770 github_login: "friend-3".to_string(),
771 admin: false,
772 }
773 ]
774 );
775 }
776 }
777
778 #[tokio::test(flavor = "multi_thread")]
779 async fn test_recent_channel_messages() {
780 for test_db in [
781 TestDb::postgres().await,
782 TestDb::fake(Arc::new(gpui::executor::Background::new())),
783 ] {
784 let db = test_db.db();
785 let user = db.create_user("user", false).await.unwrap();
786 let org = db.create_org("org", "org").await.unwrap();
787 let channel = db.create_org_channel(org, "channel").await.unwrap();
788 for i in 0..10 {
789 db.create_channel_message(
790 channel,
791 user,
792 &i.to_string(),
793 OffsetDateTime::now_utc(),
794 i,
795 )
796 .await
797 .unwrap();
798 }
799
800 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
801 assert_eq!(
802 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
803 ["5", "6", "7", "8", "9"]
804 );
805
806 let prev_messages = db
807 .get_channel_messages(channel, 4, Some(messages[0].id))
808 .await
809 .unwrap();
810 assert_eq!(
811 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
812 ["1", "2", "3", "4"]
813 );
814 }
815 }
816
817 #[tokio::test(flavor = "multi_thread")]
818 async fn test_channel_message_nonces() {
819 for test_db in [
820 TestDb::postgres().await,
821 TestDb::fake(Arc::new(gpui::executor::Background::new())),
822 ] {
823 let db = test_db.db();
824 let user = db.create_user("user", false).await.unwrap();
825 let org = db.create_org("org", "org").await.unwrap();
826 let channel = db.create_org_channel(org, "channel").await.unwrap();
827
828 let msg1_id = db
829 .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
830 .await
831 .unwrap();
832 let msg2_id = db
833 .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
834 .await
835 .unwrap();
836 let msg3_id = db
837 .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
838 .await
839 .unwrap();
840 let msg4_id = db
841 .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
842 .await
843 .unwrap();
844
845 assert_ne!(msg1_id, msg2_id);
846 assert_eq!(msg1_id, msg3_id);
847 assert_eq!(msg2_id, msg4_id);
848 }
849 }
850
851 #[tokio::test(flavor = "multi_thread")]
852 async fn test_create_access_tokens() {
853 let test_db = TestDb::postgres().await;
854 let db = test_db.db();
855 let user = db.create_user("the-user", false).await.unwrap();
856
857 db.create_access_token_hash(user, "h1", 3).await.unwrap();
858 db.create_access_token_hash(user, "h2", 3).await.unwrap();
859 assert_eq!(
860 db.get_access_token_hashes(user).await.unwrap(),
861 &["h2".to_string(), "h1".to_string()]
862 );
863
864 db.create_access_token_hash(user, "h3", 3).await.unwrap();
865 assert_eq!(
866 db.get_access_token_hashes(user).await.unwrap(),
867 &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
868 );
869
870 db.create_access_token_hash(user, "h4", 3).await.unwrap();
871 assert_eq!(
872 db.get_access_token_hashes(user).await.unwrap(),
873 &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
874 );
875
876 db.create_access_token_hash(user, "h5", 3).await.unwrap();
877 assert_eq!(
878 db.get_access_token_hashes(user).await.unwrap(),
879 &["h5".to_string(), "h4".to_string(), "h3".to_string()]
880 );
881 }
882
883 #[test]
884 fn test_fuzzy_like_string() {
885 assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
886 assert_eq!(fuzzy_like_string("x y"), "%x%y%");
887 assert_eq!(fuzzy_like_string(" z "), "%z%");
888 }
889
890 #[tokio::test(flavor = "multi_thread")]
891 async fn test_fuzzy_search_users() {
892 let test_db = TestDb::postgres().await;
893 let db = test_db.db();
894 for github_login in [
895 "california",
896 "colorado",
897 "oregon",
898 "washington",
899 "florida",
900 "delaware",
901 "rhode-island",
902 ] {
903 db.create_user(github_login, false).await.unwrap();
904 }
905
906 assert_eq!(
907 fuzzy_search_user_names(db, "clr").await,
908 &["colorado", "california"]
909 );
910 assert_eq!(
911 fuzzy_search_user_names(db, "ro").await,
912 &["rhode-island", "colorado", "oregon"],
913 );
914
915 async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
916 db.fuzzy_search_users(query, 10)
917 .await
918 .unwrap()
919 .into_iter()
920 .map(|user| user.github_login)
921 .collect::<Vec<_>>()
922 }
923 }
924
925 #[tokio::test(flavor = "multi_thread")]
926 async fn test_add_contacts() {
927 for test_db in [
928 TestDb::postgres().await,
929 TestDb::fake(Arc::new(gpui::executor::Background::new())),
930 ] {
931 let db = test_db.db();
932
933 let user_1 = db.create_user("user1", false).await.unwrap();
934 let user_2 = db.create_user("user2", false).await.unwrap();
935 let user_3 = db.create_user("user3", false).await.unwrap();
936
937 // User starts with no contacts
938 assert_eq!(
939 db.get_contacts(user_1).await.unwrap(),
940 Contacts {
941 current: vec![],
942 outgoing_requests: vec![],
943 incoming_requests: vec![],
944 },
945 );
946
947 // User requests a contact. Both users see the pending request.
948 db.send_contact_request(user_1, user_2).await.unwrap();
949 assert_eq!(
950 db.get_contacts(user_1).await.unwrap(),
951 Contacts {
952 current: vec![],
953 outgoing_requests: vec![user_2],
954 incoming_requests: vec![],
955 },
956 );
957 assert_eq!(
958 db.get_contacts(user_2).await.unwrap(),
959 Contacts {
960 current: vec![],
961 outgoing_requests: vec![],
962 incoming_requests: vec![IncomingContactRequest {
963 requester_id: user_1,
964 should_notify: true
965 }],
966 },
967 );
968
969 // User 2 dismisses the contact request notification without accepting or rejecting.
970 // We shouldn't notify them again.
971 db.dismiss_contact_request(user_1, user_2)
972 .await
973 .unwrap_err();
974 db.dismiss_contact_request(user_2, user_1).await.unwrap();
975 assert_eq!(
976 db.get_contacts(user_2).await.unwrap(),
977 Contacts {
978 current: vec![],
979 outgoing_requests: vec![],
980 incoming_requests: vec![IncomingContactRequest {
981 requester_id: user_1,
982 should_notify: false
983 }],
984 },
985 );
986
987 // User can't accept their own contact request
988 db.respond_to_contact_request(user_1, user_2, true)
989 .await
990 .unwrap_err();
991
992 // User accepts a contact request. Both users see the contact.
993 db.respond_to_contact_request(user_2, user_1, true)
994 .await
995 .unwrap();
996 assert_eq!(
997 db.get_contacts(user_1).await.unwrap(),
998 Contacts {
999 current: vec![user_2],
1000 outgoing_requests: vec![],
1001 incoming_requests: vec![],
1002 },
1003 );
1004 assert_eq!(
1005 db.get_contacts(user_2).await.unwrap(),
1006 Contacts {
1007 current: vec![user_1],
1008 outgoing_requests: vec![],
1009 incoming_requests: vec![],
1010 },
1011 );
1012
1013 // Users cannot re-request existing contacts.
1014 db.send_contact_request(user_1, user_2).await.unwrap_err();
1015 db.send_contact_request(user_2, user_1).await.unwrap_err();
1016
1017 // Users send each other concurrent contact requests and
1018 // see that they are immediately accepted.
1019 db.send_contact_request(user_1, user_3).await.unwrap();
1020 db.send_contact_request(user_3, user_1).await.unwrap();
1021 assert_eq!(
1022 db.get_contacts(user_1).await.unwrap(),
1023 Contacts {
1024 current: vec![user_2, user_3],
1025 outgoing_requests: vec![],
1026 incoming_requests: vec![],
1027 },
1028 );
1029 assert_eq!(
1030 db.get_contacts(user_3).await.unwrap(),
1031 Contacts {
1032 current: vec![user_1],
1033 outgoing_requests: vec![],
1034 incoming_requests: vec![],
1035 },
1036 );
1037
1038 // User declines a contact request. Both users see that it is gone.
1039 db.send_contact_request(user_2, user_3).await.unwrap();
1040 db.respond_to_contact_request(user_3, user_2, false)
1041 .await
1042 .unwrap();
1043 assert_eq!(
1044 db.get_contacts(user_2).await.unwrap(),
1045 Contacts {
1046 current: vec![user_1],
1047 outgoing_requests: vec![],
1048 incoming_requests: vec![],
1049 },
1050 );
1051 assert_eq!(
1052 db.get_contacts(user_3).await.unwrap(),
1053 Contacts {
1054 current: vec![user_1],
1055 outgoing_requests: vec![],
1056 incoming_requests: vec![],
1057 },
1058 );
1059 }
1060 }
1061
1062 pub struct TestDb {
1063 pub db: Option<Arc<dyn Db>>,
1064 pub url: String,
1065 }
1066
1067 impl TestDb {
1068 pub async fn postgres() -> Self {
1069 lazy_static! {
1070 static ref LOCK: Mutex<()> = Mutex::new(());
1071 }
1072
1073 let _guard = LOCK.lock();
1074 let mut rng = StdRng::from_entropy();
1075 let name = format!("zed-test-{}", rng.gen::<u128>());
1076 let url = format!("postgres://postgres@localhost/{}", name);
1077 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
1078 Postgres::create_database(&url)
1079 .await
1080 .expect("failed to create test db");
1081 let db = PostgresDb::new(&url, 5).await.unwrap();
1082 let migrator = Migrator::new(migrations_path).await.unwrap();
1083 migrator.run(&db.pool).await.unwrap();
1084 Self {
1085 db: Some(Arc::new(db)),
1086 url,
1087 }
1088 }
1089
1090 pub fn fake(background: Arc<Background>) -> Self {
1091 Self {
1092 db: Some(Arc::new(FakeDb::new(background))),
1093 url: Default::default(),
1094 }
1095 }
1096
1097 pub fn db(&self) -> &Arc<dyn Db> {
1098 self.db.as_ref().unwrap()
1099 }
1100 }
1101
1102 impl Drop for TestDb {
1103 fn drop(&mut self) {
1104 if let Some(db) = self.db.take() {
1105 futures::executor::block_on(db.teardown(&self.url));
1106 }
1107 }
1108 }
1109
1110 pub struct FakeDb {
1111 background: Arc<Background>,
1112 users: Mutex<BTreeMap<UserId, User>>,
1113 next_user_id: Mutex<i32>,
1114 orgs: Mutex<BTreeMap<OrgId, Org>>,
1115 next_org_id: Mutex<i32>,
1116 org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
1117 channels: Mutex<BTreeMap<ChannelId, Channel>>,
1118 next_channel_id: Mutex<i32>,
1119 channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
1120 channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
1121 next_channel_message_id: Mutex<i32>,
1122 contacts: Mutex<Vec<FakeContact>>,
1123 }
1124
1125 #[derive(Debug)]
1126 struct FakeContact {
1127 requester_id: UserId,
1128 responder_id: UserId,
1129 accepted: bool,
1130 should_notify: bool,
1131 }
1132
1133 impl FakeDb {
1134 pub fn new(background: Arc<Background>) -> Self {
1135 Self {
1136 background,
1137 users: Default::default(),
1138 next_user_id: Mutex::new(1),
1139 orgs: Default::default(),
1140 next_org_id: Mutex::new(1),
1141 org_memberships: Default::default(),
1142 channels: Default::default(),
1143 next_channel_id: Mutex::new(1),
1144 channel_memberships: Default::default(),
1145 channel_messages: Default::default(),
1146 next_channel_message_id: Mutex::new(1),
1147 contacts: Default::default(),
1148 }
1149 }
1150 }
1151
1152 #[async_trait]
1153 impl Db for FakeDb {
1154 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
1155 self.background.simulate_random_delay().await;
1156
1157 let mut users = self.users.lock();
1158 if let Some(user) = users
1159 .values()
1160 .find(|user| user.github_login == github_login)
1161 {
1162 Ok(user.id)
1163 } else {
1164 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
1165 users.insert(
1166 user_id,
1167 User {
1168 id: user_id,
1169 github_login: github_login.to_string(),
1170 admin,
1171 },
1172 );
1173 Ok(user_id)
1174 }
1175 }
1176
1177 async fn get_all_users(&self) -> Result<Vec<User>> {
1178 unimplemented!()
1179 }
1180
1181 async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
1182 unimplemented!()
1183 }
1184
1185 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1186 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1187 }
1188
1189 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1190 self.background.simulate_random_delay().await;
1191 let users = self.users.lock();
1192 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1193 }
1194
1195 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
1196 Ok(self
1197 .users
1198 .lock()
1199 .values()
1200 .find(|user| user.github_login == github_login)
1201 .cloned())
1202 }
1203
1204 async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
1205 unimplemented!()
1206 }
1207
1208 async fn destroy_user(&self, _id: UserId) -> Result<()> {
1209 unimplemented!()
1210 }
1211
1212 async fn get_contacts(&self, id: UserId) -> Result<Contacts> {
1213 self.background.simulate_random_delay().await;
1214 let mut current = Vec::new();
1215 let mut outgoing_requests = Vec::new();
1216 let mut incoming_requests = Vec::new();
1217
1218 for contact in self.contacts.lock().iter() {
1219 if contact.requester_id == id {
1220 if contact.accepted {
1221 current.push(contact.responder_id);
1222 } else {
1223 outgoing_requests.push(contact.responder_id);
1224 }
1225 } else if contact.responder_id == id {
1226 if contact.accepted {
1227 current.push(contact.requester_id);
1228 } else {
1229 incoming_requests.push(IncomingContactRequest {
1230 requester_id: contact.requester_id,
1231 should_notify: contact.should_notify,
1232 });
1233 }
1234 }
1235 }
1236
1237 Ok(Contacts {
1238 current,
1239 outgoing_requests,
1240 incoming_requests,
1241 })
1242 }
1243
1244 async fn send_contact_request(
1245 &self,
1246 requester_id: UserId,
1247 responder_id: UserId,
1248 ) -> Result<()> {
1249 let mut contacts = self.contacts.lock();
1250 for contact in contacts.iter_mut() {
1251 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1252 if contact.accepted {
1253 Err(anyhow!("contact already exists"))?;
1254 } else {
1255 Err(anyhow!("contact already requested"))?;
1256 }
1257 }
1258 if contact.responder_id == requester_id && contact.requester_id == responder_id {
1259 if contact.accepted {
1260 Err(anyhow!("contact already exists"))?;
1261 } else {
1262 contact.accepted = true;
1263 return Ok(());
1264 }
1265 }
1266 }
1267 contacts.push(FakeContact {
1268 requester_id,
1269 responder_id,
1270 accepted: false,
1271 should_notify: true,
1272 });
1273 Ok(())
1274 }
1275
1276 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1277 self.contacts.lock().retain(|contact| {
1278 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
1279 });
1280 Ok(())
1281 }
1282
1283 async fn dismiss_contact_request(
1284 &self,
1285 responder_id: UserId,
1286 requester_id: UserId,
1287 ) -> Result<()> {
1288 let mut contacts = self.contacts.lock();
1289 for contact in contacts.iter_mut() {
1290 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1291 if contact.accepted {
1292 return Err(anyhow!("contact already confirmed"));
1293 }
1294 contact.should_notify = false;
1295 return Ok(());
1296 }
1297 }
1298 Err(anyhow!("no such contact request"))
1299 }
1300
1301 async fn respond_to_contact_request(
1302 &self,
1303 responder_id: UserId,
1304 requester_id: UserId,
1305 accept: bool,
1306 ) -> Result<()> {
1307 let mut contacts = self.contacts.lock();
1308 for (ix, contact) in contacts.iter_mut().enumerate() {
1309 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1310 if contact.accepted {
1311 return Err(anyhow!("contact already confirmed"));
1312 }
1313 if accept {
1314 contact.accepted = true;
1315 } else {
1316 contacts.remove(ix);
1317 }
1318 return Ok(());
1319 }
1320 }
1321 Err(anyhow!("no such contact request"))
1322 }
1323
1324 async fn create_access_token_hash(
1325 &self,
1326 _user_id: UserId,
1327 _access_token_hash: &str,
1328 _max_access_token_count: usize,
1329 ) -> Result<()> {
1330 unimplemented!()
1331 }
1332
1333 async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
1334 unimplemented!()
1335 }
1336
1337 async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
1338 unimplemented!()
1339 }
1340
1341 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1342 self.background.simulate_random_delay().await;
1343 let mut orgs = self.orgs.lock();
1344 if orgs.values().any(|org| org.slug == slug) {
1345 Err(anyhow!("org already exists"))
1346 } else {
1347 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
1348 orgs.insert(
1349 org_id,
1350 Org {
1351 id: org_id,
1352 name: name.to_string(),
1353 slug: slug.to_string(),
1354 },
1355 );
1356 Ok(org_id)
1357 }
1358 }
1359
1360 async fn add_org_member(
1361 &self,
1362 org_id: OrgId,
1363 user_id: UserId,
1364 is_admin: bool,
1365 ) -> Result<()> {
1366 self.background.simulate_random_delay().await;
1367 if !self.orgs.lock().contains_key(&org_id) {
1368 return Err(anyhow!("org does not exist"));
1369 }
1370 if !self.users.lock().contains_key(&user_id) {
1371 return Err(anyhow!("user does not exist"));
1372 }
1373
1374 self.org_memberships
1375 .lock()
1376 .entry((org_id, user_id))
1377 .or_insert(is_admin);
1378 Ok(())
1379 }
1380
1381 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1382 self.background.simulate_random_delay().await;
1383 if !self.orgs.lock().contains_key(&org_id) {
1384 return Err(anyhow!("org does not exist"));
1385 }
1386
1387 let mut channels = self.channels.lock();
1388 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
1389 channels.insert(
1390 channel_id,
1391 Channel {
1392 id: channel_id,
1393 name: name.to_string(),
1394 owner_id: org_id.0,
1395 owner_is_user: false,
1396 },
1397 );
1398 Ok(channel_id)
1399 }
1400
1401 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1402 self.background.simulate_random_delay().await;
1403 Ok(self
1404 .channels
1405 .lock()
1406 .values()
1407 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
1408 .cloned()
1409 .collect())
1410 }
1411
1412 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1413 self.background.simulate_random_delay().await;
1414 let channels = self.channels.lock();
1415 let memberships = self.channel_memberships.lock();
1416 Ok(channels
1417 .values()
1418 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
1419 .cloned()
1420 .collect())
1421 }
1422
1423 async fn can_user_access_channel(
1424 &self,
1425 user_id: UserId,
1426 channel_id: ChannelId,
1427 ) -> Result<bool> {
1428 self.background.simulate_random_delay().await;
1429 Ok(self
1430 .channel_memberships
1431 .lock()
1432 .contains_key(&(channel_id, user_id)))
1433 }
1434
1435 async fn add_channel_member(
1436 &self,
1437 channel_id: ChannelId,
1438 user_id: UserId,
1439 is_admin: bool,
1440 ) -> Result<()> {
1441 self.background.simulate_random_delay().await;
1442 if !self.channels.lock().contains_key(&channel_id) {
1443 return Err(anyhow!("channel does not exist"));
1444 }
1445 if !self.users.lock().contains_key(&user_id) {
1446 return Err(anyhow!("user does not exist"));
1447 }
1448
1449 self.channel_memberships
1450 .lock()
1451 .entry((channel_id, user_id))
1452 .or_insert(is_admin);
1453 Ok(())
1454 }
1455
1456 async fn create_channel_message(
1457 &self,
1458 channel_id: ChannelId,
1459 sender_id: UserId,
1460 body: &str,
1461 timestamp: OffsetDateTime,
1462 nonce: u128,
1463 ) -> Result<MessageId> {
1464 self.background.simulate_random_delay().await;
1465 if !self.channels.lock().contains_key(&channel_id) {
1466 return Err(anyhow!("channel does not exist"));
1467 }
1468 if !self.users.lock().contains_key(&sender_id) {
1469 return Err(anyhow!("user does not exist"));
1470 }
1471
1472 let mut messages = self.channel_messages.lock();
1473 if let Some(message) = messages
1474 .values()
1475 .find(|message| message.nonce.as_u128() == nonce)
1476 {
1477 Ok(message.id)
1478 } else {
1479 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
1480 messages.insert(
1481 message_id,
1482 ChannelMessage {
1483 id: message_id,
1484 channel_id,
1485 sender_id,
1486 body: body.to_string(),
1487 sent_at: timestamp,
1488 nonce: Uuid::from_u128(nonce),
1489 },
1490 );
1491 Ok(message_id)
1492 }
1493 }
1494
1495 async fn get_channel_messages(
1496 &self,
1497 channel_id: ChannelId,
1498 count: usize,
1499 before_id: Option<MessageId>,
1500 ) -> Result<Vec<ChannelMessage>> {
1501 let mut messages = self
1502 .channel_messages
1503 .lock()
1504 .values()
1505 .rev()
1506 .filter(|message| {
1507 message.channel_id == channel_id
1508 && message.id < before_id.unwrap_or(MessageId::MAX)
1509 })
1510 .take(count)
1511 .cloned()
1512 .collect::<Vec<_>>();
1513 messages.sort_unstable_by_key(|message| message.id);
1514 Ok(messages)
1515 }
1516
1517 async fn teardown(&self, _: &str) {}
1518 }
1519}