1use super::*;
2use rpc::Notification;
3use sea_orm::TryInsertResult;
4use time::OffsetDateTime;
5
6impl Database {
7 /// Inserts a record representing a user joining the chat for a given channel.
8 pub async fn join_channel_chat(
9 &self,
10 channel_id: ChannelId,
11 connection_id: ConnectionId,
12 user_id: UserId,
13 ) -> Result<()> {
14 self.transaction(|tx| async move {
15 let channel = self.get_channel_internal(channel_id, &tx).await?;
16 self.check_user_is_channel_participant(&channel, user_id, &tx)
17 .await?;
18 channel_chat_participant::ActiveModel {
19 id: ActiveValue::NotSet,
20 channel_id: ActiveValue::Set(channel_id),
21 user_id: ActiveValue::Set(user_id),
22 connection_id: ActiveValue::Set(connection_id.id as i32),
23 connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
24 }
25 .insert(&*tx)
26 .await?;
27 Ok(())
28 })
29 .await
30 }
31
32 /// Removes `channel_chat_participant` records associated with the given connection ID.
33 pub async fn channel_chat_connection_lost(
34 &self,
35 connection_id: ConnectionId,
36 tx: &DatabaseTransaction,
37 ) -> Result<()> {
38 channel_chat_participant::Entity::delete_many()
39 .filter(
40 Condition::all()
41 .add(
42 channel_chat_participant::Column::ConnectionServerId
43 .eq(connection_id.owner_id),
44 )
45 .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id)),
46 )
47 .exec(tx)
48 .await?;
49 Ok(())
50 }
51
52 /// Removes `channel_chat_participant` records associated with the given user ID so they
53 /// will no longer get chat notifications.
54 pub async fn leave_channel_chat(
55 &self,
56 channel_id: ChannelId,
57 connection_id: ConnectionId,
58 _user_id: UserId,
59 ) -> Result<()> {
60 self.transaction(|tx| async move {
61 channel_chat_participant::Entity::delete_many()
62 .filter(
63 Condition::all()
64 .add(
65 channel_chat_participant::Column::ConnectionServerId
66 .eq(connection_id.owner_id),
67 )
68 .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id))
69 .add(channel_chat_participant::Column::ChannelId.eq(channel_id)),
70 )
71 .exec(&*tx)
72 .await?;
73
74 Ok(())
75 })
76 .await
77 }
78
79 /// Retrieves the messages in the specified channel.
80 ///
81 /// Use `before_message_id` to paginate through the channel's messages.
82 pub async fn get_channel_messages(
83 &self,
84 channel_id: ChannelId,
85 user_id: UserId,
86 count: usize,
87 before_message_id: Option<MessageId>,
88 ) -> Result<Vec<proto::ChannelMessage>> {
89 self.transaction(|tx| async move {
90 let channel = self.get_channel_internal(channel_id, &tx).await?;
91 self.check_user_is_channel_participant(&channel, user_id, &tx)
92 .await?;
93
94 let mut condition =
95 Condition::all().add(channel_message::Column::ChannelId.eq(channel_id));
96
97 if let Some(before_message_id) = before_message_id {
98 condition = condition.add(channel_message::Column::Id.lt(before_message_id));
99 }
100
101 let rows = channel_message::Entity::find()
102 .filter(condition)
103 .order_by_desc(channel_message::Column::Id)
104 .limit(count as u64)
105 .all(&*tx)
106 .await?;
107
108 self.load_channel_messages(rows, &tx).await
109 })
110 .await
111 }
112
113 /// Returns the channel messages with the given IDs.
114 pub async fn get_channel_messages_by_id(
115 &self,
116 user_id: UserId,
117 message_ids: &[MessageId],
118 ) -> Result<Vec<proto::ChannelMessage>> {
119 self.transaction(|tx| async move {
120 let rows = channel_message::Entity::find()
121 .filter(channel_message::Column::Id.is_in(message_ids.iter().copied()))
122 .order_by_desc(channel_message::Column::Id)
123 .all(&*tx)
124 .await?;
125
126 let mut channels = HashMap::<ChannelId, channel::Model>::default();
127 for row in &rows {
128 channels.insert(
129 row.channel_id,
130 self.get_channel_internal(row.channel_id, &tx).await?,
131 );
132 }
133
134 for (_, channel) in channels {
135 self.check_user_is_channel_participant(&channel, user_id, &tx)
136 .await?;
137 }
138
139 let messages = self.load_channel_messages(rows, &tx).await?;
140 Ok(messages)
141 })
142 .await
143 }
144
145 async fn load_channel_messages(
146 &self,
147 rows: Vec<channel_message::Model>,
148 tx: &DatabaseTransaction,
149 ) -> Result<Vec<proto::ChannelMessage>> {
150 let mut messages = rows
151 .into_iter()
152 .map(|row| {
153 let nonce = row.nonce.as_u64_pair();
154 proto::ChannelMessage {
155 id: row.id.to_proto(),
156 sender_id: row.sender_id.to_proto(),
157 body: row.body,
158 timestamp: row.sent_at.assume_utc().unix_timestamp() as u64,
159 mentions: vec![],
160 nonce: Some(proto::Nonce {
161 upper_half: nonce.0,
162 lower_half: nonce.1,
163 }),
164 reply_to_message_id: row.reply_to_message_id.map(|id| id.to_proto()),
165 edited_at: row
166 .edited_at
167 .map(|t| t.assume_utc().unix_timestamp() as u64),
168 }
169 })
170 .collect::<Vec<_>>();
171 messages.reverse();
172
173 let mut mentions = channel_message_mention::Entity::find()
174 .filter(channel_message_mention::Column::MessageId.is_in(messages.iter().map(|m| m.id)))
175 .order_by_asc(channel_message_mention::Column::MessageId)
176 .order_by_asc(channel_message_mention::Column::StartOffset)
177 .stream(tx)
178 .await?;
179
180 let mut message_ix = 0;
181 while let Some(mention) = mentions.next().await {
182 let mention = mention?;
183 let message_id = mention.message_id.to_proto();
184 while let Some(message) = messages.get_mut(message_ix) {
185 if message.id < message_id {
186 message_ix += 1;
187 } else {
188 if message.id == message_id {
189 message.mentions.push(proto::ChatMention {
190 range: Some(proto::Range {
191 start: mention.start_offset as u64,
192 end: mention.end_offset as u64,
193 }),
194 user_id: mention.user_id.to_proto(),
195 });
196 }
197 break;
198 }
199 }
200 }
201
202 Ok(messages)
203 }
204
205 fn format_mentions_to_entities(
206 &self,
207 message_id: MessageId,
208 body: &str,
209 mentions: &[proto::ChatMention],
210 ) -> Result<Vec<tables::channel_message_mention::ActiveModel>> {
211 Ok(mentions
212 .iter()
213 .filter_map(|mention| {
214 let range = mention.range.as_ref()?;
215 if !body.is_char_boundary(range.start as usize)
216 || !body.is_char_boundary(range.end as usize)
217 {
218 return None;
219 }
220 Some(channel_message_mention::ActiveModel {
221 message_id: ActiveValue::Set(message_id),
222 start_offset: ActiveValue::Set(range.start as i32),
223 end_offset: ActiveValue::Set(range.end as i32),
224 user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)),
225 })
226 })
227 .collect::<Vec<_>>())
228 }
229
230 /// Creates a new channel message.
231 #[allow(clippy::too_many_arguments)]
232 pub async fn create_channel_message(
233 &self,
234 channel_id: ChannelId,
235 user_id: UserId,
236 body: &str,
237 mentions: &[proto::ChatMention],
238 timestamp: OffsetDateTime,
239 nonce: u128,
240 reply_to_message_id: Option<MessageId>,
241 ) -> Result<CreatedChannelMessage> {
242 self.transaction(|tx| async move {
243 let channel = self.get_channel_internal(channel_id, &tx).await?;
244 self.check_user_is_channel_participant(&channel, user_id, &tx)
245 .await?;
246
247 let mut rows = channel_chat_participant::Entity::find()
248 .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
249 .stream(&*tx)
250 .await?;
251
252 let mut is_participant = false;
253 let mut participant_connection_ids = Vec::new();
254 let mut participant_user_ids = Vec::new();
255 while let Some(row) = rows.next().await {
256 let row = row?;
257 if row.user_id == user_id {
258 is_participant = true;
259 }
260 participant_user_ids.push(row.user_id);
261 participant_connection_ids.push(row.connection());
262 }
263 drop(rows);
264
265 if !is_participant {
266 Err(anyhow!("not a chat participant"))?;
267 }
268
269 let timestamp = timestamp.to_offset(time::UtcOffset::UTC);
270 let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time());
271
272 let result = channel_message::Entity::insert(channel_message::ActiveModel {
273 channel_id: ActiveValue::Set(channel_id),
274 sender_id: ActiveValue::Set(user_id),
275 body: ActiveValue::Set(body.to_string()),
276 sent_at: ActiveValue::Set(timestamp),
277 nonce: ActiveValue::Set(Uuid::from_u128(nonce)),
278 id: ActiveValue::NotSet,
279 reply_to_message_id: ActiveValue::Set(reply_to_message_id),
280 edited_at: ActiveValue::NotSet,
281 })
282 .on_conflict(
283 OnConflict::columns([
284 channel_message::Column::SenderId,
285 channel_message::Column::Nonce,
286 ])
287 .do_nothing()
288 .to_owned(),
289 )
290 .do_nothing()
291 .exec(&*tx)
292 .await?;
293
294 let message_id;
295 let mut notifications = Vec::new();
296 match result {
297 TryInsertResult::Inserted(result) => {
298 message_id = result.last_insert_id;
299 let mentioned_user_ids =
300 mentions.iter().map(|m| m.user_id).collect::<HashSet<_>>();
301
302 let mentions = self.format_mentions_to_entities(message_id, body, mentions)?;
303 if !mentions.is_empty() {
304 channel_message_mention::Entity::insert_many(mentions)
305 .exec(&*tx)
306 .await?;
307 }
308
309 for mentioned_user in mentioned_user_ids {
310 notifications.extend(
311 self.create_notification(
312 UserId::from_proto(mentioned_user),
313 rpc::Notification::ChannelMessageMention {
314 message_id: message_id.to_proto(),
315 sender_id: user_id.to_proto(),
316 channel_id: channel_id.to_proto(),
317 },
318 false,
319 &tx,
320 )
321 .await?,
322 );
323 }
324
325 self.observe_channel_message_internal(channel_id, user_id, message_id, &tx)
326 .await?;
327 }
328 _ => {
329 message_id = channel_message::Entity::find()
330 .filter(channel_message::Column::Nonce.eq(Uuid::from_u128(nonce)))
331 .one(&*tx)
332 .await?
333 .ok_or_else(|| anyhow!("failed to insert message"))?
334 .id;
335 }
336 }
337
338 let mut channel_members = self.get_channel_participants(&channel, &tx).await?;
339 channel_members.retain(|member| !participant_user_ids.contains(member));
340
341 Ok(CreatedChannelMessage {
342 message_id,
343 participant_connection_ids,
344 channel_members,
345 notifications,
346 })
347 })
348 .await
349 }
350
351 pub async fn observe_channel_message(
352 &self,
353 channel_id: ChannelId,
354 user_id: UserId,
355 message_id: MessageId,
356 ) -> Result<NotificationBatch> {
357 self.transaction(|tx| async move {
358 self.observe_channel_message_internal(channel_id, user_id, message_id, &tx)
359 .await?;
360 let mut batch = NotificationBatch::default();
361 batch.extend(
362 self.mark_notification_as_read(
363 user_id,
364 &Notification::ChannelMessageMention {
365 message_id: message_id.to_proto(),
366 sender_id: Default::default(),
367 channel_id: Default::default(),
368 },
369 &tx,
370 )
371 .await?,
372 );
373 Ok(batch)
374 })
375 .await
376 }
377
378 async fn observe_channel_message_internal(
379 &self,
380 channel_id: ChannelId,
381 user_id: UserId,
382 message_id: MessageId,
383 tx: &DatabaseTransaction,
384 ) -> Result<()> {
385 observed_channel_messages::Entity::insert(observed_channel_messages::ActiveModel {
386 user_id: ActiveValue::Set(user_id),
387 channel_id: ActiveValue::Set(channel_id),
388 channel_message_id: ActiveValue::Set(message_id),
389 })
390 .on_conflict(
391 OnConflict::columns([
392 observed_channel_messages::Column::ChannelId,
393 observed_channel_messages::Column::UserId,
394 ])
395 .update_column(observed_channel_messages::Column::ChannelMessageId)
396 .action_cond_where(observed_channel_messages::Column::ChannelMessageId.lt(message_id))
397 .to_owned(),
398 )
399 // TODO: Try to upgrade SeaORM so we don't have to do this hack around their bug
400 .exec_without_returning(tx)
401 .await?;
402 Ok(())
403 }
404
405 pub async fn observed_channel_messages(
406 &self,
407 channel_ids: &[ChannelId],
408 user_id: UserId,
409 tx: &DatabaseTransaction,
410 ) -> Result<Vec<proto::ChannelMessageId>> {
411 let rows = observed_channel_messages::Entity::find()
412 .filter(observed_channel_messages::Column::UserId.eq(user_id))
413 .filter(
414 observed_channel_messages::Column::ChannelId
415 .is_in(channel_ids.iter().map(|id| id.0)),
416 )
417 .all(tx)
418 .await?;
419
420 Ok(rows
421 .into_iter()
422 .map(|message| proto::ChannelMessageId {
423 channel_id: message.channel_id.to_proto(),
424 message_id: message.channel_message_id.to_proto(),
425 })
426 .collect())
427 }
428
429 pub async fn latest_channel_messages(
430 &self,
431 channel_ids: &[ChannelId],
432 tx: &DatabaseTransaction,
433 ) -> Result<Vec<proto::ChannelMessageId>> {
434 let mut values = String::new();
435 for id in channel_ids {
436 if !values.is_empty() {
437 values.push_str(", ");
438 }
439 write!(&mut values, "({})", id).unwrap();
440 }
441
442 if values.is_empty() {
443 return Ok(Vec::default());
444 }
445
446 let sql = format!(
447 r#"
448 SELECT
449 *
450 FROM (
451 SELECT
452 *,
453 row_number() OVER (
454 PARTITION BY channel_id
455 ORDER BY id DESC
456 ) as row_number
457 FROM channel_messages
458 WHERE
459 channel_id in ({values})
460 ) AS messages
461 WHERE
462 row_number = 1
463 "#,
464 );
465
466 let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
467 let mut last_messages = channel_message::Model::find_by_statement(stmt)
468 .stream(tx)
469 .await?;
470
471 let mut results = Vec::new();
472 while let Some(result) = last_messages.next().await {
473 let message = result?;
474 results.push(proto::ChannelMessageId {
475 channel_id: message.channel_id.to_proto(),
476 message_id: message.id.to_proto(),
477 });
478 }
479
480 Ok(results)
481 }
482
483 /// Removes the channel message with the given ID.
484 pub async fn remove_channel_message(
485 &self,
486 channel_id: ChannelId,
487 message_id: MessageId,
488 user_id: UserId,
489 ) -> Result<Vec<ConnectionId>> {
490 self.transaction(|tx| async move {
491 let mut rows = channel_chat_participant::Entity::find()
492 .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
493 .stream(&*tx)
494 .await?;
495
496 let mut is_participant = false;
497 let mut participant_connection_ids = Vec::new();
498 while let Some(row) = rows.next().await {
499 let row = row?;
500 if row.user_id == user_id {
501 is_participant = true;
502 }
503 participant_connection_ids.push(row.connection());
504 }
505 drop(rows);
506
507 if !is_participant {
508 Err(anyhow!("not a chat participant"))?;
509 }
510
511 let result = channel_message::Entity::delete_by_id(message_id)
512 .filter(channel_message::Column::SenderId.eq(user_id))
513 .exec(&*tx)
514 .await?;
515
516 if result.rows_affected == 0 {
517 let channel = self.get_channel_internal(channel_id, &tx).await?;
518 if self
519 .check_user_is_channel_admin(&channel, user_id, &tx)
520 .await
521 .is_ok()
522 {
523 let result = channel_message::Entity::delete_by_id(message_id)
524 .exec(&*tx)
525 .await?;
526 if result.rows_affected == 0 {
527 Err(anyhow!("no such message"))?;
528 }
529 } else {
530 Err(anyhow!("operation could not be completed"))?;
531 }
532 }
533
534 Ok(participant_connection_ids)
535 })
536 .await
537 }
538
539 /// Updates the channel message with the given ID, body and timestamp(edited_at).
540 pub async fn update_channel_message(
541 &self,
542 channel_id: ChannelId,
543 message_id: MessageId,
544 user_id: UserId,
545 body: &str,
546 mentions: &[proto::ChatMention],
547 edited_at: OffsetDateTime,
548 ) -> Result<UpdatedChannelMessage> {
549 self.transaction(|tx| async move {
550 let channel = self.get_channel_internal(channel_id, &tx).await?;
551 self.check_user_is_channel_participant(&channel, user_id, &tx)
552 .await?;
553
554 let mut rows = channel_chat_participant::Entity::find()
555 .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
556 .stream(&*tx)
557 .await?;
558
559 let mut is_participant = false;
560 let mut participant_connection_ids = Vec::new();
561 let mut participant_user_ids = Vec::new();
562 while let Some(row) = rows.next().await {
563 let row = row?;
564 if row.user_id == user_id {
565 is_participant = true;
566 }
567 participant_user_ids.push(row.user_id);
568 participant_connection_ids.push(row.connection());
569 }
570 drop(rows);
571
572 if !is_participant {
573 Err(anyhow!("not a chat participant"))?;
574 }
575
576 let channel_message = channel_message::Entity::find_by_id(message_id)
577 .filter(channel_message::Column::SenderId.eq(user_id))
578 .one(&*tx)
579 .await?;
580
581 let Some(channel_message) = channel_message else {
582 Err(anyhow!("Channel message not found"))?
583 };
584
585 let edited_at = edited_at.to_offset(time::UtcOffset::UTC);
586 let edited_at = time::PrimitiveDateTime::new(edited_at.date(), edited_at.time());
587
588 let updated_message = channel_message::ActiveModel {
589 body: ActiveValue::Set(body.to_string()),
590 edited_at: ActiveValue::Set(Some(edited_at)),
591 reply_to_message_id: ActiveValue::Unchanged(channel_message.reply_to_message_id),
592 id: ActiveValue::Unchanged(message_id),
593 channel_id: ActiveValue::Unchanged(channel_id),
594 sender_id: ActiveValue::Unchanged(user_id),
595 sent_at: ActiveValue::Unchanged(channel_message.sent_at),
596 nonce: ActiveValue::Unchanged(channel_message.nonce),
597 };
598
599 let result = channel_message::Entity::update_many()
600 .set(updated_message)
601 .filter(channel_message::Column::Id.eq(message_id))
602 .filter(channel_message::Column::SenderId.eq(user_id))
603 .exec(&*tx)
604 .await?;
605 if result.rows_affected == 0 {
606 return Err(anyhow!(
607 "Attempted to edit a message (id: {message_id}) which does not exist anymore."
608 ))?;
609 }
610
611 // we have to fetch the old mentions,
612 // so we don't send a notification when the message has been edited that you are mentioned in
613 let old_mentions = channel_message_mention::Entity::find()
614 .filter(channel_message_mention::Column::MessageId.eq(message_id))
615 .all(&*tx)
616 .await?;
617
618 // remove all existing mentions
619 channel_message_mention::Entity::delete_many()
620 .filter(channel_message_mention::Column::MessageId.eq(message_id))
621 .exec(&*tx)
622 .await?;
623
624 let new_mentions = self.format_mentions_to_entities(message_id, body, mentions)?;
625 if !new_mentions.is_empty() {
626 // insert new mentions
627 channel_message_mention::Entity::insert_many(new_mentions)
628 .exec(&*tx)
629 .await?;
630 }
631
632 let mut mentioned_user_ids = mentions.iter().map(|m| m.user_id).collect::<HashSet<_>>();
633 // Filter out users that were mentioned before
634 for mention in old_mentions {
635 mentioned_user_ids.remove(&mention.user_id.to_proto());
636 }
637
638 let mut notifications = Vec::new();
639 for mentioned_user in mentioned_user_ids {
640 notifications.extend(
641 self.create_notification(
642 UserId::from_proto(mentioned_user),
643 rpc::Notification::ChannelMessageMention {
644 message_id: message_id.to_proto(),
645 sender_id: user_id.to_proto(),
646 channel_id: channel_id.to_proto(),
647 },
648 false,
649 &tx,
650 )
651 .await?,
652 );
653 }
654
655 Ok(UpdatedChannelMessage {
656 message_id,
657 participant_connection_ids,
658 notifications,
659 reply_to_message_id: channel_message.reply_to_message_id,
660 timestamp: channel_message.sent_at,
661 })
662 })
663 .await
664 }
665}