1use serde::Serialize;
2use sqlx::{FromRow, Result};
3use time::OffsetDateTime;
4
5pub use async_sqlx_session::PostgresSessionStore as SessionStore;
6pub use sqlx::postgres::PgPoolOptions as DbOptions;
7
/// Database handle wrapping a Postgres connection pool.
///
/// The pool is a public tuple field, so callers can also reach it directly as `db.0`.
pub struct Db(pub sqlx::PgPool);
9
/// A row from the `users` table, as returned by the user queries on `Db`.
#[derive(Debug, FromRow, Serialize)]
pub struct User {
    pub id: UserId,
    /// Unique lookup key used by `Db::get_user_by_github_login`.
    pub github_login: String,
    /// Value of the `admin` column (written by `Db::create_user` / `Db::set_user_is_admin`).
    pub admin: bool,
}
16
/// A row from the `signups` table (inserted by `Db::create_signup`).
#[derive(Debug, FromRow, Serialize)]
pub struct Signup {
    pub id: SignupId,
    pub github_login: String,
    pub email_address: String,
    pub about: String,
}
24
/// A channel's id and name, as selected by `Db::get_channels_for_user`.
#[derive(Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
}
30
/// A message in a channel, as selected by `Db::get_recent_channel_messages`.
#[derive(Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub sender_id: UserId,
    pub body: String,
    /// Send time; the select query applies `AT TIME ZONE 'UTC'` before it is decoded here.
    pub sent_at: OffsetDateTime,
}
38
39impl Db {
40 // signups
41
42 pub async fn create_signup(
43 &self,
44 github_login: &str,
45 email_address: &str,
46 about: &str,
47 ) -> Result<SignupId> {
48 let query = "
49 INSERT INTO signups (github_login, email_address, about)
50 VALUES ($1, $2, $3)
51 RETURNING id
52 ";
53 sqlx::query_scalar(query)
54 .bind(github_login)
55 .bind(email_address)
56 .bind(about)
57 .fetch_one(&self.0)
58 .await
59 .map(SignupId)
60 }
61
62 pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
63 let query = "SELECT * FROM users ORDER BY github_login ASC";
64 sqlx::query_as(query).fetch_all(&self.0).await
65 }
66
67 pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
68 let query = "DELETE FROM signups WHERE id = $1";
69 sqlx::query(query)
70 .bind(id.0)
71 .execute(&self.0)
72 .await
73 .map(drop)
74 }
75
76 // users
77
78 pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
79 let query = "
80 INSERT INTO users (github_login, admin)
81 VALUES ($1, $2)
82 RETURNING id
83 ";
84 sqlx::query_scalar(query)
85 .bind(github_login)
86 .bind(admin)
87 .fetch_one(&self.0)
88 .await
89 .map(UserId)
90 }
91
92 pub async fn get_all_users(&self) -> Result<Vec<User>> {
93 let query = "SELECT * FROM users ORDER BY github_login ASC";
94 sqlx::query_as(query).fetch_all(&self.0).await
95 }
96
97 pub async fn get_users_by_ids(
98 &self,
99 requester_id: UserId,
100 ids: impl Iterator<Item = UserId>,
101 ) -> Result<Vec<User>> {
102 // Only return users that are in a common channel with the requesting user.
103 let query = "
104 SELECT users.*
105 FROM
106 users, channel_memberships
107 WHERE
108 users.id IN $1 AND
109 channel_memberships.user_id = users.id AND
110 channel_memberships.channel_id IN (
111 SELECT channel_id
112 FROM channel_memberships
113 WHERE channel_memberships.user_id = $2
114 )
115 ";
116
117 sqlx::query_as(query)
118 .bind(&ids.map(|id| id.0).collect::<Vec<_>>())
119 .bind(requester_id)
120 .fetch_all(&self.0)
121 .await
122 }
123
124 pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
125 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
126 sqlx::query_as(query)
127 .bind(github_login)
128 .fetch_optional(&self.0)
129 .await
130 }
131
132 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
133 let query = "UPDATE users SET admin = $1 WHERE id = $2";
134 sqlx::query(query)
135 .bind(is_admin)
136 .bind(id.0)
137 .execute(&self.0)
138 .await
139 .map(drop)
140 }
141
142 pub async fn delete_user(&self, id: UserId) -> Result<()> {
143 let query = "DELETE FROM users WHERE id = $1;";
144 sqlx::query(query)
145 .bind(id.0)
146 .execute(&self.0)
147 .await
148 .map(drop)
149 }
150
151 // access tokens
152
153 pub async fn create_access_token_hash(
154 &self,
155 user_id: UserId,
156 access_token_hash: String,
157 ) -> Result<()> {
158 let query = "
159 INSERT INTO access_tokens (user_id, hash)
160 VALUES ($1, $2)
161 ";
162 sqlx::query(query)
163 .bind(user_id.0)
164 .bind(access_token_hash)
165 .execute(&self.0)
166 .await
167 .map(drop)
168 }
169
170 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
171 let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
172 sqlx::query_scalar(query)
173 .bind(user_id.0)
174 .fetch_all(&self.0)
175 .await
176 }
177
178 // orgs
179
180 #[cfg(test)]
181 pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
182 let query = "
183 INSERT INTO orgs (name, slug)
184 VALUES ($1, $2)
185 RETURNING id
186 ";
187 sqlx::query_scalar(query)
188 .bind(name)
189 .bind(slug)
190 .fetch_one(&self.0)
191 .await
192 .map(OrgId)
193 }
194
195 #[cfg(test)]
196 pub async fn add_org_member(
197 &self,
198 org_id: OrgId,
199 user_id: UserId,
200 is_admin: bool,
201 ) -> Result<()> {
202 let query = "
203 INSERT INTO org_memberships (org_id, user_id, admin)
204 VALUES ($1, $2, $3)
205 ";
206 sqlx::query(query)
207 .bind(org_id.0)
208 .bind(user_id.0)
209 .bind(is_admin)
210 .execute(&self.0)
211 .await
212 .map(drop)
213 }
214
215 // channels
216
217 #[cfg(test)]
218 pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
219 let query = "
220 INSERT INTO channels (owner_id, owner_is_user, name)
221 VALUES ($1, false, $2)
222 RETURNING id
223 ";
224 sqlx::query_scalar(query)
225 .bind(org_id.0)
226 .bind(name)
227 .fetch_one(&self.0)
228 .await
229 .map(ChannelId)
230 }
231
232 pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
233 let query = "
234 SELECT
235 channels.id, channels.name
236 FROM
237 channel_memberships, channels
238 WHERE
239 channel_memberships.user_id = $1 AND
240 channel_memberships.channel_id = channels.id
241 ";
242 sqlx::query_as(query)
243 .bind(user_id.0)
244 .fetch_all(&self.0)
245 .await
246 }
247
248 pub async fn can_user_access_channel(
249 &self,
250 user_id: UserId,
251 channel_id: ChannelId,
252 ) -> Result<bool> {
253 let query = "
254 SELECT id
255 FROM channel_memberships
256 WHERE user_id = $1 AND channel_id = $2
257 LIMIT 1
258 ";
259 sqlx::query_scalar::<_, i32>(query)
260 .bind(user_id.0)
261 .bind(channel_id.0)
262 .fetch_optional(&self.0)
263 .await
264 .map(|e| e.is_some())
265 }
266
267 #[cfg(test)]
268 pub async fn add_channel_member(
269 &self,
270 channel_id: ChannelId,
271 user_id: UserId,
272 is_admin: bool,
273 ) -> Result<()> {
274 let query = "
275 INSERT INTO channel_memberships (channel_id, user_id, admin)
276 VALUES ($1, $2, $3)
277 ";
278 sqlx::query(query)
279 .bind(channel_id.0)
280 .bind(user_id.0)
281 .bind(is_admin)
282 .execute(&self.0)
283 .await
284 .map(drop)
285 }
286
287 // messages
288
289 pub async fn create_channel_message(
290 &self,
291 channel_id: ChannelId,
292 sender_id: UserId,
293 body: &str,
294 timestamp: OffsetDateTime,
295 ) -> Result<MessageId> {
296 let query = "
297 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
298 VALUES ($1, $2, $3, $4)
299 RETURNING id
300 ";
301 sqlx::query_scalar(query)
302 .bind(channel_id.0)
303 .bind(sender_id.0)
304 .bind(body)
305 .bind(timestamp)
306 .fetch_one(&self.0)
307 .await
308 .map(MessageId)
309 }
310
311 pub async fn get_recent_channel_messages(
312 &self,
313 channel_id: ChannelId,
314 count: usize,
315 ) -> Result<Vec<ChannelMessage>> {
316 let query = r#"
317 SELECT
318 id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at
319 FROM
320 channel_messages
321 WHERE
322 channel_id = $1
323 LIMIT $2
324 "#;
325 sqlx::query_as(query)
326 .bind(channel_id.0)
327 .bind(count as i64)
328 .fetch_all(&self.0)
329 .await
330 }
331}
332
333impl std::ops::Deref for Db {
334 type Target = sqlx::PgPool;
335
336 fn deref(&self) -> &Self::Target {
337 &self.0
338 }
339}
340
341macro_rules! id_type {
342 ($name:ident) => {
343 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)]
344 #[sqlx(transparent)]
345 #[serde(transparent)]
346 pub struct $name(pub i32);
347
348 impl $name {
349 #[allow(unused)]
350 pub fn from_proto(value: u64) -> Self {
351 Self(value as i32)
352 }
353
354 #[allow(unused)]
355 pub fn to_proto(&self) -> u64 {
356 self.0 as u64
357 }
358 }
359 };
360}
361
362id_type!(UserId);
363id_type!(OrgId);
364id_type!(ChannelId);
365id_type!(SignupId);
366id_type!(MessageId);