1use serde::Serialize;
2use sqlx::{FromRow, Result};
3
4pub use async_sqlx_session::PostgresSessionStore as SessionStore;
5pub use sqlx::postgres::PgPoolOptions as DbOptions;
6
7pub struct Db(pub sqlx::PgPool);
8
9#[derive(Debug, FromRow, Serialize)]
10pub struct User {
11 id: i32,
12 pub github_login: String,
13 pub admin: bool,
14}
15
16#[derive(Debug, FromRow, Serialize)]
17pub struct Signup {
18 id: i32,
19 pub github_login: String,
20 pub email_address: String,
21 pub about: String,
22}
23
24#[derive(Debug, FromRow)]
25pub struct ChannelMessage {
26 id: i32,
27 sender_id: i32,
28 body: String,
29 sent_at: i64,
30}
31
// Typed wrappers around raw `i32` primary keys, so ids for different tables
// cannot be mixed up at call sites. They derive the usual value-type traits so
// they can be compared, used as map/set keys, and printed in logs.

/// Primary key of a row in `users`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct UserId(pub i32);

/// Primary key of a row in `orgs`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct OrgId(pub i32);

/// Primary key of a row in `channels`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ChannelId(pub i32);

/// Primary key of a row in `signups`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct SignupId(pub i32);

/// Primary key of a row in `channel_messages`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct MessageId(pub i32);
46
47impl Db {
48 // signups
49
50 pub async fn create_signup(
51 &self,
52 github_login: &str,
53 email_address: &str,
54 about: &str,
55 ) -> Result<SignupId> {
56 let query = "
57 INSERT INTO signups (github_login, email_address, about)
58 VALUES ($1, $2, $3)
59 RETURNING id
60 ";
61 sqlx::query_scalar(query)
62 .bind(github_login)
63 .bind(email_address)
64 .bind(about)
65 .fetch_one(&self.0)
66 .await
67 .map(SignupId)
68 }
69
70 pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
71 let query = "SELECT * FROM users ORDER BY github_login ASC";
72 sqlx::query_as(query).fetch_all(&self.0).await
73 }
74
75 pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
76 let query = "DELETE FROM signups WHERE id = $1";
77 sqlx::query(query)
78 .bind(id.0)
79 .execute(&self.0)
80 .await
81 .map(drop)
82 }
83
84 // users
85
86 pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
87 let query = "
88 INSERT INTO users (github_login, admin)
89 VALUES ($1, $2)
90 RETURNING id
91 ";
92 sqlx::query_scalar(query)
93 .bind(github_login)
94 .bind(admin)
95 .fetch_one(&self.0)
96 .await
97 .map(UserId)
98 }
99
100 pub async fn get_all_users(&self) -> Result<Vec<User>> {
101 let query = "SELECT * FROM users ORDER BY github_login ASC";
102 sqlx::query_as(query).fetch_all(&self.0).await
103 }
104
105 pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
106 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
107 sqlx::query_as(query)
108 .bind(github_login)
109 .fetch_optional(&self.0)
110 .await
111 }
112
113 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
114 let query = "UPDATE users SET admin = $1 WHERE id = $2";
115 sqlx::query(query)
116 .bind(is_admin)
117 .bind(id.0)
118 .execute(&self.0)
119 .await
120 .map(drop)
121 }
122
123 pub async fn delete_user(&self, id: UserId) -> Result<()> {
124 let query = "DELETE FROM users WHERE id = $1;";
125 sqlx::query(query)
126 .bind(id.0)
127 .execute(&self.0)
128 .await
129 .map(drop)
130 }
131
132 // access tokens
133
134 pub async fn create_access_token_hash(
135 &self,
136 user_id: UserId,
137 access_token_hash: String,
138 ) -> Result<()> {
139 let query = "
140 INSERT INTO access_tokens (user_id, hash)
141 VALUES ($1, $2)
142 ";
143 sqlx::query(query)
144 .bind(user_id.0 as i32)
145 .bind(access_token_hash)
146 .execute(&self.0)
147 .await
148 .map(drop)
149 }
150
151 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
152 let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
153 sqlx::query_scalar::<_, String>(query)
154 .bind(user_id.0 as i32)
155 .fetch_all(&self.0)
156 .await
157 }
158
159 // orgs
160
161 pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
162 let query = "
163 INSERT INTO orgs (name, slug)
164 VALUES ($1, $2)
165 RETURNING id
166 ";
167 sqlx::query_scalar(query)
168 .bind(name)
169 .bind(slug)
170 .fetch_one(&self.0)
171 .await
172 .map(OrgId)
173 }
174
175 pub async fn add_org_member(&self, org_id: OrgId, user_id: UserId) -> Result<()> {
176 let query = "
177 INSERT INTO org_memberships (org_id, user_id)
178 VALUES ($1, $2)
179 ";
180 sqlx::query(query)
181 .bind(org_id.0)
182 .bind(user_id.0)
183 .execute(&self.0)
184 .await
185 .map(drop)
186 }
187
188 // channels
189
190 pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
191 let query = "
192 INSERT INTO channels (owner_id, owner_is_user, name)
193 VALUES ($1, false, $2)
194 RETURNING id
195 ";
196 sqlx::query_scalar(query)
197 .bind(org_id.0)
198 .bind(name)
199 .fetch_one(&self.0)
200 .await
201 .map(ChannelId)
202 }
203
204 pub async fn add_channel_member(
205 &self,
206 channel_id: ChannelId,
207 user_id: UserId,
208 is_admin: bool,
209 ) -> Result<()> {
210 let query = "
211 INSERT INTO channel_memberships (channel_id, user_id, admin)
212 VALUES ($1, $2, $3)
213 ";
214 sqlx::query(query)
215 .bind(channel_id.0)
216 .bind(user_id.0)
217 .bind(is_admin)
218 .execute(&self.0)
219 .await
220 .map(drop)
221 }
222
223 // messages
224
225 pub async fn create_channel_message(
226 &self,
227 channel_id: ChannelId,
228 sender_id: UserId,
229 body: &str,
230 ) -> Result<MessageId> {
231 let query = "
232 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
233 VALUES ($1, $2, $3, NOW()::timestamp)
234 RETURNING id
235 ";
236 sqlx::query_scalar(query)
237 .bind(channel_id.0)
238 .bind(sender_id.0)
239 .bind(body)
240 .fetch_one(&self.0)
241 .await
242 .map(MessageId)
243 }
244
245 pub async fn get_recent_channel_messages(
246 &self,
247 channel_id: ChannelId,
248 count: usize,
249 ) -> Result<Vec<ChannelMessage>> {
250 let query = "
251 SELECT id, sender_id, body, sent_at
252 FROM channel_messages
253 WHERE channel_id = $1
254 LIMIT $2
255 ";
256 sqlx::query_as(query)
257 .bind(channel_id.0)
258 .bind(count as i64)
259 .fetch_all(&self.0)
260 .await
261 }
262}
263
264impl std::ops::Deref for Db {
265 type Target = sqlx::PgPool;
266
267 fn deref(&self) -> &Self::Target {
268 &self.0
269 }
270}
271
272impl User {
273 pub fn id(&self) -> UserId {
274 UserId(self.id)
275 }
276}