1use anyhow::Context;
2use async_std::task::{block_on, yield_now};
3use serde::Serialize;
4use sqlx::{FromRow, Result};
5use time::OffsetDateTime;
6
7pub use async_sqlx_session::PostgresSessionStore as SessionStore;
8pub use sqlx::postgres::PgPoolOptions as DbOptions;
9
/// Wraps a block of async database code and drives it according to the
/// `test_mode` flag of `$self`.
///
/// * When `test_mode` is `false`, the block is simply awaited in place.
/// * When `test_mode` is `true`, the current task first yields once via
///   `yield_now`, and the block is then run to completion synchronously with
///   `block_on`.
///
/// The macro must be expanded inside an `async` context, because both
/// branches contain an `.await`.
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        let operation = async {
            $($token)*
        };
        if !$self.test_mode {
            operation.await
        } else {
            yield_now().await;
            block_on(operation)
        }
    }};
}
23
24#[derive(Clone)]
25pub struct Db {
26 pool: sqlx::PgPool,
27 test_mode: bool,
28}
29
30impl Db {
31 pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
32 let pool = DbOptions::new()
33 .max_connections(max_connections)
34 .connect(url)
35 .await
36 .context("failed to connect to postgres database")?;
37 Ok(Self {
38 pool,
39 test_mode: false,
40 })
41 }
42
43 // signups
44
45 pub async fn create_signup(
46 &self,
47 github_login: &str,
48 email_address: &str,
49 about: &str,
50 ) -> Result<SignupId> {
51 test_support!(self, {
52 let query = "
53 INSERT INTO signups (github_login, email_address, about)
54 VALUES ($1, $2, $3)
55 RETURNING id
56 ";
57 sqlx::query_scalar(query)
58 .bind(github_login)
59 .bind(email_address)
60 .bind(about)
61 .fetch_one(&self.pool)
62 .await
63 .map(SignupId)
64 })
65 }
66
67 pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
68 test_support!(self, {
69 let query = "SELECT * FROM users ORDER BY github_login ASC";
70 sqlx::query_as(query).fetch_all(&self.pool).await
71 })
72 }
73
74 pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
75 test_support!(self, {
76 let query = "DELETE FROM signups WHERE id = $1";
77 sqlx::query(query)
78 .bind(id.0)
79 .execute(&self.pool)
80 .await
81 .map(drop)
82 })
83 }
84
85 // users
86
87 #[allow(unused)] // Help rust-analyzer
88 #[cfg(any(test, feature = "seed-support"))]
89 pub async fn get_user(&self, github_login: &str) -> Result<Option<UserId>> {
90 test_support!(self, {
91 let query = "
92 SELECT id
93 FROM users
94 WHERE github_login = $1
95 ";
96 sqlx::query_scalar(query)
97 .bind(github_login)
98 .fetch_optional(&self.pool)
99 .await
100 })
101 }
102
103 pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
104 test_support!(self, {
105 let query = "
106 INSERT INTO users (github_login, admin)
107 VALUES ($1, $2)
108 RETURNING id
109 ";
110 sqlx::query_scalar(query)
111 .bind(github_login)
112 .bind(admin)
113 .fetch_one(&self.pool)
114 .await
115 .map(UserId)
116 })
117 }
118
119 pub async fn get_all_users(&self) -> Result<Vec<User>> {
120 test_support!(self, {
121 let query = "SELECT * FROM users ORDER BY github_login ASC";
122 sqlx::query_as(query).fetch_all(&self.pool).await
123 })
124 }
125
126 pub async fn get_users_by_ids(
127 &self,
128 requester_id: UserId,
129 ids: impl Iterator<Item = UserId>,
130 ) -> Result<Vec<User>> {
131 test_support!(self, {
132 // Only return users that are in a common channel with the requesting user.
133 let query = "
134 SELECT users.*
135 FROM
136 users, channel_memberships
137 WHERE
138 users.id = ANY ($1) AND
139 channel_memberships.user_id = users.id AND
140 channel_memberships.channel_id IN (
141 SELECT channel_id
142 FROM channel_memberships
143 WHERE channel_memberships.user_id = $2
144 )
145 ";
146
147 sqlx::query_as(query)
148 .bind(&ids.map(|id| id.0).collect::<Vec<_>>())
149 .bind(requester_id)
150 .fetch_all(&self.pool)
151 .await
152 })
153 }
154
155 pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
156 test_support!(self, {
157 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
158 sqlx::query_as(query)
159 .bind(github_login)
160 .fetch_optional(&self.pool)
161 .await
162 })
163 }
164
165 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
166 test_support!(self, {
167 let query = "UPDATE users SET admin = $1 WHERE id = $2";
168 sqlx::query(query)
169 .bind(is_admin)
170 .bind(id.0)
171 .execute(&self.pool)
172 .await
173 .map(drop)
174 })
175 }
176
177 pub async fn delete_user(&self, id: UserId) -> Result<()> {
178 test_support!(self, {
179 let query = "DELETE FROM users WHERE id = $1;";
180 sqlx::query(query)
181 .bind(id.0)
182 .execute(&self.pool)
183 .await
184 .map(drop)
185 })
186 }
187
188 // access tokens
189
190 pub async fn create_access_token_hash(
191 &self,
192 user_id: UserId,
193 access_token_hash: String,
194 ) -> Result<()> {
195 test_support!(self, {
196 let query = "
197 INSERT INTO access_tokens (user_id, hash)
198 VALUES ($1, $2)
199 ";
200 sqlx::query(query)
201 .bind(user_id.0)
202 .bind(access_token_hash)
203 .execute(&self.pool)
204 .await
205 .map(drop)
206 })
207 }
208
209 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
210 test_support!(self, {
211 let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
212 sqlx::query_scalar(query)
213 .bind(user_id.0)
214 .fetch_all(&self.pool)
215 .await
216 })
217 }
218
219 // orgs
220
221 #[allow(unused)] // Help rust-analyzer
222 #[cfg(any(test, feature = "seed-support"))]
223 pub async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
224 test_support!(self, {
225 let query = "
226 SELECT *
227 FROM orgs
228 WHERE slug = $1
229 ";
230 sqlx::query_as(query)
231 .bind(slug)
232 .fetch_optional(&self.pool)
233 .await
234 })
235 }
236
237 #[cfg(any(test, feature = "seed-support"))]
238 pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
239 test_support!(self, {
240 let query = "
241 INSERT INTO orgs (name, slug)
242 VALUES ($1, $2)
243 RETURNING id
244 ";
245 sqlx::query_scalar(query)
246 .bind(name)
247 .bind(slug)
248 .fetch_one(&self.pool)
249 .await
250 .map(OrgId)
251 })
252 }
253
254 #[cfg(any(test, feature = "seed-support"))]
255 pub async fn add_org_member(
256 &self,
257 org_id: OrgId,
258 user_id: UserId,
259 is_admin: bool,
260 ) -> Result<()> {
261 test_support!(self, {
262 let query = "
263 INSERT INTO org_memberships (org_id, user_id, admin)
264 VALUES ($1, $2, $3)
265 ON CONFLICT DO NOTHING
266 ";
267 sqlx::query(query)
268 .bind(org_id.0)
269 .bind(user_id.0)
270 .bind(is_admin)
271 .execute(&self.pool)
272 .await
273 .map(drop)
274 })
275 }
276
277 // channels
278
279 #[cfg(any(test, feature = "seed-support"))]
280 pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
281 test_support!(self, {
282 let query = "
283 INSERT INTO channels (owner_id, owner_is_user, name)
284 VALUES ($1, false, $2)
285 RETURNING id
286 ";
287 sqlx::query_scalar(query)
288 .bind(org_id.0)
289 .bind(name)
290 .fetch_one(&self.pool)
291 .await
292 .map(ChannelId)
293 })
294 }
295
296 #[allow(unused)] // Help rust-analyzer
297 #[cfg(any(test, feature = "seed-support"))]
298 pub async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
299 test_support!(self, {
300 let query = "
301 SELECT *
302 FROM channels
303 WHERE
304 channels.owner_is_user = false AND
305 channels.owner_id = $1
306 ";
307 sqlx::query_as(query)
308 .bind(org_id.0)
309 .fetch_all(&self.pool)
310 .await
311 })
312 }
313
314 pub async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
315 test_support!(self, {
316 let query = "
317 SELECT
318 channels.id, channels.name
319 FROM
320 channel_memberships, channels
321 WHERE
322 channel_memberships.user_id = $1 AND
323 channel_memberships.channel_id = channels.id
324 ";
325 sqlx::query_as(query)
326 .bind(user_id.0)
327 .fetch_all(&self.pool)
328 .await
329 })
330 }
331
332 pub async fn can_user_access_channel(
333 &self,
334 user_id: UserId,
335 channel_id: ChannelId,
336 ) -> Result<bool> {
337 test_support!(self, {
338 let query = "
339 SELECT id
340 FROM channel_memberships
341 WHERE user_id = $1 AND channel_id = $2
342 LIMIT 1
343 ";
344 sqlx::query_scalar::<_, i32>(query)
345 .bind(user_id.0)
346 .bind(channel_id.0)
347 .fetch_optional(&self.pool)
348 .await
349 .map(|e| e.is_some())
350 })
351 }
352
353 #[cfg(any(test, feature = "seed-support"))]
354 pub async fn add_channel_member(
355 &self,
356 channel_id: ChannelId,
357 user_id: UserId,
358 is_admin: bool,
359 ) -> Result<()> {
360 test_support!(self, {
361 let query = "
362 INSERT INTO channel_memberships (channel_id, user_id, admin)
363 VALUES ($1, $2, $3)
364 ON CONFLICT DO NOTHING
365 ";
366 sqlx::query(query)
367 .bind(channel_id.0)
368 .bind(user_id.0)
369 .bind(is_admin)
370 .execute(&self.pool)
371 .await
372 .map(drop)
373 })
374 }
375
376 // messages
377
378 pub async fn create_channel_message(
379 &self,
380 channel_id: ChannelId,
381 sender_id: UserId,
382 body: &str,
383 timestamp: OffsetDateTime,
384 ) -> Result<MessageId> {
385 test_support!(self, {
386 let query = "
387 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
388 VALUES ($1, $2, $3, $4)
389 RETURNING id
390 ";
391 sqlx::query_scalar(query)
392 .bind(channel_id.0)
393 .bind(sender_id.0)
394 .bind(body)
395 .bind(timestamp)
396 .fetch_one(&self.pool)
397 .await
398 .map(MessageId)
399 })
400 }
401
402 pub async fn get_channel_messages(
403 &self,
404 channel_id: ChannelId,
405 count: usize,
406 before_id: Option<MessageId>,
407 ) -> Result<Vec<ChannelMessage>> {
408 test_support!(self, {
409 let query = r#"
410 SELECT * FROM (
411 SELECT
412 id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at
413 FROM
414 channel_messages
415 WHERE
416 channel_id = $1 AND
417 id < $2
418 ORDER BY id DESC
419 LIMIT $3
420 ) as recent_messages
421 ORDER BY id ASC
422 "#;
423 sqlx::query_as(query)
424 .bind(channel_id.0)
425 .bind(before_id.unwrap_or(MessageId::MAX))
426 .bind(count as i64)
427 .fetch_all(&self.pool)
428 .await
429 })
430 }
431}
432
/// Declares a public newtype `$name(pub i32)` for use as a database id.
///
/// The generated type is transparent for both sqlx and serde, and gets a
/// `MAX` constant plus conversions to and from the `u64` ids used by the
/// protocol layer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// Largest representable id (`i32::MAX`).
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Builds an id from a `u64`.
            ///
            /// The conversion is a plain `as` cast, so values that do not fit
            /// in an `i32` are truncated rather than rejected.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                $name(raw)
            }

            /// Converts the id to a `u64`.
            ///
            /// The conversion is a plain `as` cast, so a negative id is
            /// sign-extended into a large `u64`.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let $name(raw) = *self;
                raw as u64
            }
        }
    };
}
456
457id_type!(UserId);
458#[derive(Debug, FromRow, Serialize)]
459pub struct User {
460 pub id: UserId,
461 pub github_login: String,
462 pub admin: bool,
463}
464
465id_type!(OrgId);
466#[derive(FromRow)]
467pub struct Org {
468 pub id: OrgId,
469 pub name: String,
470 pub slug: String,
471}
472
473id_type!(SignupId);
474#[derive(Debug, FromRow, Serialize)]
475pub struct Signup {
476 pub id: SignupId,
477 pub github_login: String,
478 pub email_address: String,
479 pub about: String,
480}
481
482id_type!(ChannelId);
483#[derive(Debug, FromRow, Serialize)]
484pub struct Channel {
485 pub id: ChannelId,
486 pub name: String,
487}
488
489id_type!(MessageId);
490#[derive(Debug, FromRow)]
491pub struct ChannelMessage {
492 pub id: MessageId,
493 pub sender_id: UserId,
494 pub body: String,
495 pub sent_at: OffsetDateTime,
496}
497
#[cfg(test)]
pub mod tests {
    use super::*;
    use rand::prelude::*;
    use sqlx::{
        migrate::{MigrateDatabase, Migrator},
        Postgres,
    };
    use std::path::Path;

    /// A freshly created, migrated Postgres database for a single test.
    ///
    /// The database is dropped again when the `TestDb` is dropped.
    pub struct TestDb {
        /// Handle to the test database, with `test_mode` enabled.
        pub db: Db,
        /// Randomly generated name of the test database.
        pub name: String,
        /// Connection URL of the test database.
        pub url: String,
    }

    impl TestDb {
        /// Creates a uniquely named database on the local Postgres server,
        /// connects to it, and runs the migrations found in the crate's
        /// `migrations` directory.
        ///
        /// # Panics
        ///
        /// Panics if the database cannot be created, connected to, or migrated.
        pub fn new() -> Self {
            // Enable tests to run in parallel by serializing the creation of each test database.
            lazy_static::lazy_static! {
                static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
            }

            let mut rng = StdRng::from_entropy();
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            let db = block_on(async {
                {
                    // Held only for the duration of `create_database`.
                    let _lock = DB_CREATION.lock();
                    Postgres::create_database(&url)
                        .await
                        .expect("failed to create test db");
                }
                let mut db = Db::new(&url, 5).await.unwrap();
                db.test_mode = true;
                let migrator = Migrator::new(migrations_path).await.unwrap();
                migrator.run(&db.pool).await.unwrap();
                db
            });

            Self { db, name, url }
        }

        /// Returns the database handle for use in a test.
        pub fn db(&self) -> &Db {
            &self.db
        }
    }

    impl Drop for TestDb {
        /// Terminates the other backends connected to the test database,
        /// closes the pool, and drops the database.
        fn drop(&mut self) {
            block_on(async {
                // The database name is passed as the bound parameter `$1`.
                // A literal `'{}'` here would never match any database, so no
                // lingering backend would be terminated before the drop.
                let query = "
                    SELECT pg_terminate_backend(pg_stat_activity.pid)
                    FROM pg_stat_activity
                    WHERE pg_stat_activity.datname = $1 AND pid <> pg_backend_pid();
                ";
                sqlx::query(query)
                    .bind(&self.name)
                    .execute(&self.db.pool)
                    .await
                    .unwrap();
                self.db.pool.close().await;
                Postgres::drop_database(&self.url).await.unwrap();
            });
        }
    }

    /// Checks that `get_channel_messages` returns the newest messages in
    /// ascending order and that `before_id` pages backwards correctly.
    #[gpui::test]
    async fn test_recent_channel_messages() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("user", false).await.unwrap();
        let org = db.create_org("org", "org").await.unwrap();
        let channel = db.create_org_channel(org, "channel").await.unwrap();
        for i in 0..10 {
            db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc())
                .await
                .unwrap();
        }

        let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
        assert_eq!(
            messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["5", "6", "7", "8", "9"]
        );

        let prev_messages = db
            .get_channel_messages(channel, 4, Some(messages[0].id))
            .await
            .unwrap();
        assert_eq!(
            prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["1", "2", "3", "4"]
        );
    }
}