1use super::*;
2use time::OffsetDateTime;
3
4impl Database {
5 pub async fn join_channel_chat(
6 &self,
7 channel_id: ChannelId,
8 connection_id: ConnectionId,
9 user_id: UserId,
10 ) -> Result<()> {
11 self.transaction(|tx| async move {
12 self.check_user_is_channel_member(channel_id, user_id, &*tx)
13 .await?;
14 channel_chat_participant::ActiveModel {
15 id: ActiveValue::NotSet,
16 channel_id: ActiveValue::Set(channel_id),
17 user_id: ActiveValue::Set(user_id),
18 connection_id: ActiveValue::Set(connection_id.id as i32),
19 connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
20 }
21 .insert(&*tx)
22 .await?;
23 Ok(())
24 })
25 .await
26 }
27
28 pub async fn channel_chat_connection_lost(
29 &self,
30 connection_id: ConnectionId,
31 tx: &DatabaseTransaction,
32 ) -> Result<()> {
33 channel_chat_participant::Entity::delete_many()
34 .filter(
35 Condition::all()
36 .add(
37 channel_chat_participant::Column::ConnectionServerId
38 .eq(connection_id.owner_id),
39 )
40 .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id)),
41 )
42 .exec(tx)
43 .await?;
44 Ok(())
45 }
46
47 pub async fn leave_channel_chat(
48 &self,
49 channel_id: ChannelId,
50 connection_id: ConnectionId,
51 _user_id: UserId,
52 ) -> Result<()> {
53 self.transaction(|tx| async move {
54 channel_chat_participant::Entity::delete_many()
55 .filter(
56 Condition::all()
57 .add(
58 channel_chat_participant::Column::ConnectionServerId
59 .eq(connection_id.owner_id),
60 )
61 .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id))
62 .add(channel_chat_participant::Column::ChannelId.eq(channel_id)),
63 )
64 .exec(&*tx)
65 .await?;
66
67 Ok(())
68 })
69 .await
70 }
71
72 pub async fn get_channel_messages(
73 &self,
74 channel_id: ChannelId,
75 user_id: UserId,
76 count: usize,
77 before_message_id: Option<MessageId>,
78 ) -> Result<Vec<proto::ChannelMessage>> {
79 self.transaction(|tx| async move {
80 self.check_user_is_channel_member(channel_id, user_id, &*tx)
81 .await?;
82
83 let mut condition =
84 Condition::all().add(channel_message::Column::ChannelId.eq(channel_id));
85
86 if let Some(before_message_id) = before_message_id {
87 condition = condition.add(channel_message::Column::Id.lt(before_message_id));
88 }
89
90 let mut rows = channel_message::Entity::find()
91 .filter(condition)
92 .limit(count as u64)
93 .stream(&*tx)
94 .await?;
95
96 let mut max_id = None;
97 let mut messages = Vec::new();
98 while let Some(row) = rows.next().await {
99 let row = row?;
100
101 max_assign(&mut max_id, row.id);
102
103 let nonce = row.nonce.as_u64_pair();
104 messages.push(proto::ChannelMessage {
105 id: row.id.to_proto(),
106 sender_id: row.sender_id.to_proto(),
107 body: row.body,
108 timestamp: row.sent_at.assume_utc().unix_timestamp() as u64,
109 nonce: Some(proto::Nonce {
110 upper_half: nonce.0,
111 lower_half: nonce.1,
112 }),
113 });
114 }
115 drop(rows);
116
117 if let Some(max_id) = max_id {
118 let has_older_message = observed_channel_messages::Entity::find()
119 .filter(
120 observed_channel_messages::Column::UserId
121 .eq(user_id)
122 .and(observed_channel_messages::Column::ChannelId.eq(channel_id))
123 .and(observed_channel_messages::Column::ChannelMessageId.lt(max_id)),
124 )
125 .one(&*tx)
126 .await?
127 .is_some();
128
129 if has_older_message {
130 observed_channel_messages::Entity::update(
131 observed_channel_messages::ActiveModel {
132 user_id: ActiveValue::Unchanged(user_id),
133 channel_id: ActiveValue::Unchanged(channel_id),
134 channel_message_id: ActiveValue::Set(max_id),
135 },
136 )
137 .exec(&*tx)
138 .await?;
139 } else {
140 observed_channel_messages::Entity::insert(
141 observed_channel_messages::ActiveModel {
142 user_id: ActiveValue::Set(user_id),
143 channel_id: ActiveValue::Set(channel_id),
144 channel_message_id: ActiveValue::Set(max_id),
145 },
146 )
147 .on_conflict(
148 OnConflict::columns([
149 observed_channel_messages::Column::UserId,
150 observed_channel_messages::Column::ChannelId,
151 ])
152 .update_columns([observed_channel_messages::Column::ChannelMessageId])
153 .to_owned(),
154 )
155 .exec(&*tx)
156 .await?;
157 }
158 }
159
160 Ok(messages)
161 })
162 .await
163 }
164
165 pub async fn create_channel_message(
166 &self,
167 channel_id: ChannelId,
168 user_id: UserId,
169 body: &str,
170 timestamp: OffsetDateTime,
171 nonce: u128,
172 ) -> Result<(MessageId, Vec<ConnectionId>, Vec<UserId>)> {
173 self.transaction(|tx| async move {
174 let mut rows = channel_chat_participant::Entity::find()
175 .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
176 .stream(&*tx)
177 .await?;
178
179 let mut is_participant = false;
180 let mut participant_connection_ids = Vec::new();
181 let mut participant_user_ids = Vec::new();
182 while let Some(row) = rows.next().await {
183 let row = row?;
184 if row.user_id == user_id {
185 is_participant = true;
186 }
187 participant_user_ids.push(row.user_id);
188 participant_connection_ids.push(row.connection());
189 }
190 drop(rows);
191
192 if !is_participant {
193 Err(anyhow!("not a chat participant"))?;
194 }
195
196 let timestamp = timestamp.to_offset(time::UtcOffset::UTC);
197 let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time());
198
199 let message = channel_message::Entity::insert(channel_message::ActiveModel {
200 channel_id: ActiveValue::Set(channel_id),
201 sender_id: ActiveValue::Set(user_id),
202 body: ActiveValue::Set(body.to_string()),
203 sent_at: ActiveValue::Set(timestamp),
204 nonce: ActiveValue::Set(Uuid::from_u128(nonce)),
205 id: ActiveValue::NotSet,
206 })
207 .on_conflict(
208 OnConflict::column(channel_message::Column::Nonce)
209 .update_column(channel_message::Column::Nonce)
210 .to_owned(),
211 )
212 .exec(&*tx)
213 .await?;
214
215 #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
216 enum QueryConnectionId {
217 ConnectionId,
218 }
219
220 // Observe this message for the sender
221 self.observe_channel_message_internal(
222 channel_id,
223 user_id,
224 message.last_insert_id,
225 &*tx,
226 )
227 .await?;
228
229 let mut channel_members = self.get_channel_members_internal(channel_id, &*tx).await?;
230 channel_members.retain(|member| !participant_user_ids.contains(member));
231
232 Ok((
233 message.last_insert_id,
234 participant_connection_ids,
235 channel_members,
236 ))
237 })
238 .await
239 }
240
241 pub async fn observe_channel_message(
242 &self,
243 channel_id: ChannelId,
244 user_id: UserId,
245 message_id: MessageId,
246 ) -> Result<()> {
247 self.transaction(|tx| async move {
248 self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
249 .await?;
250 Ok(())
251 })
252 .await
253 }
254
255 async fn observe_channel_message_internal(
256 &self,
257 channel_id: ChannelId,
258 user_id: UserId,
259 message_id: MessageId,
260 tx: &DatabaseTransaction,
261 ) -> Result<()> {
262 observed_channel_messages::Entity::insert(observed_channel_messages::ActiveModel {
263 user_id: ActiveValue::Set(user_id),
264 channel_id: ActiveValue::Set(channel_id),
265 channel_message_id: ActiveValue::Set(message_id),
266 })
267 .on_conflict(
268 OnConflict::columns([
269 observed_channel_messages::Column::ChannelId,
270 observed_channel_messages::Column::UserId,
271 ])
272 .update_column(observed_channel_messages::Column::ChannelMessageId)
273 .action_cond_where(observed_channel_messages::Column::ChannelMessageId.lt(message_id))
274 .to_owned(),
275 )
276 // TODO: Try to upgrade SeaORM so we don't have to do this hack around their bug
277 .exec_without_returning(&*tx)
278 .await?;
279 Ok(())
280 }
281
282 pub async fn unseen_channel_messages(
283 &self,
284 user_id: UserId,
285 channel_ids: &[ChannelId],
286 tx: &DatabaseTransaction,
287 ) -> Result<Vec<proto::UnseenChannelMessage>> {
288 let mut observed_messages_by_channel_id = HashMap::default();
289 let mut rows = observed_channel_messages::Entity::find()
290 .filter(observed_channel_messages::Column::UserId.eq(user_id))
291 .filter(observed_channel_messages::Column::ChannelId.is_in(channel_ids.iter().copied()))
292 .stream(&*tx)
293 .await?;
294
295 while let Some(row) = rows.next().await {
296 let row = row?;
297 observed_messages_by_channel_id.insert(row.channel_id, row);
298 }
299 drop(rows);
300 let mut values = String::new();
301 for id in channel_ids {
302 if !values.is_empty() {
303 values.push_str(", ");
304 }
305 write!(&mut values, "({})", id).unwrap();
306 }
307
308 if values.is_empty() {
309 return Ok(Default::default());
310 }
311
312 let sql = format!(
313 r#"
314 SELECT
315 *
316 FROM (
317 SELECT
318 *,
319 row_number() OVER (
320 PARTITION BY channel_id
321 ORDER BY id DESC
322 ) as row_number
323 FROM channel_messages
324 WHERE
325 channel_id in ({values})
326 ) AS messages
327 WHERE
328 row_number = 1
329 "#,
330 );
331
332 let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
333 let last_messages = channel_message::Model::find_by_statement(stmt)
334 .all(&*tx)
335 .await?;
336
337 let mut changes = Vec::new();
338 for last_message in last_messages {
339 if let Some(observed_message) =
340 observed_messages_by_channel_id.get(&last_message.channel_id)
341 {
342 if observed_message.channel_message_id == last_message.id {
343 continue;
344 }
345 }
346 changes.push(proto::UnseenChannelMessage {
347 channel_id: last_message.channel_id.to_proto(),
348 message_id: last_message.id.to_proto(),
349 });
350 }
351
352 Ok(changes)
353 }
354
355 pub async fn remove_channel_message(
356 &self,
357 channel_id: ChannelId,
358 message_id: MessageId,
359 user_id: UserId,
360 ) -> Result<Vec<ConnectionId>> {
361 self.transaction(|tx| async move {
362 let mut rows = channel_chat_participant::Entity::find()
363 .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
364 .stream(&*tx)
365 .await?;
366
367 let mut is_participant = false;
368 let mut participant_connection_ids = Vec::new();
369 while let Some(row) = rows.next().await {
370 let row = row?;
371 if row.user_id == user_id {
372 is_participant = true;
373 }
374 participant_connection_ids.push(row.connection());
375 }
376 drop(rows);
377
378 if !is_participant {
379 Err(anyhow!("not a chat participant"))?;
380 }
381
382 let result = channel_message::Entity::delete_by_id(message_id)
383 .filter(channel_message::Column::SenderId.eq(user_id))
384 .exec(&*tx)
385 .await?;
386 if result.rows_affected == 0 {
387 Err(anyhow!("no such message"))?;
388 }
389
390 Ok(participant_connection_ids)
391 })
392 .await
393 }
394}