messages.rs

  1use super::*;
  2use rpc::Notification;
  3use sea_orm::TryInsertResult;
  4use time::OffsetDateTime;
  5
  6impl Database {
  7    /// Inserts a record representing a user joining the chat for a given channel.
  8    pub async fn join_channel_chat(
  9        &self,
 10        channel_id: ChannelId,
 11        connection_id: ConnectionId,
 12        user_id: UserId,
 13    ) -> Result<()> {
 14        self.transaction(|tx| async move {
 15            let channel = self.get_channel_internal(channel_id, &tx).await?;
 16            self.check_user_is_channel_participant(&channel, user_id, &tx)
 17                .await?;
 18            channel_chat_participant::ActiveModel {
 19                id: ActiveValue::NotSet,
 20                channel_id: ActiveValue::Set(channel_id),
 21                user_id: ActiveValue::Set(user_id),
 22                connection_id: ActiveValue::Set(connection_id.id as i32),
 23                connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
 24            }
 25            .insert(&*tx)
 26            .await?;
 27            Ok(())
 28        })
 29        .await
 30    }
 31
 32    /// Removes `channel_chat_participant` records associated with the given connection ID.
 33    pub async fn channel_chat_connection_lost(
 34        &self,
 35        connection_id: ConnectionId,
 36        tx: &DatabaseTransaction,
 37    ) -> Result<()> {
 38        channel_chat_participant::Entity::delete_many()
 39            .filter(
 40                Condition::all()
 41                    .add(
 42                        channel_chat_participant::Column::ConnectionServerId
 43                            .eq(connection_id.owner_id),
 44                    )
 45                    .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id)),
 46            )
 47            .exec(tx)
 48            .await?;
 49        Ok(())
 50    }
 51
 52    /// Removes `channel_chat_participant` records associated with the given user ID so they
 53    /// will no longer get chat notifications.
 54    pub async fn leave_channel_chat(
 55        &self,
 56        channel_id: ChannelId,
 57        connection_id: ConnectionId,
 58        _user_id: UserId,
 59    ) -> Result<()> {
 60        self.transaction(|tx| async move {
 61            channel_chat_participant::Entity::delete_many()
 62                .filter(
 63                    Condition::all()
 64                        .add(
 65                            channel_chat_participant::Column::ConnectionServerId
 66                                .eq(connection_id.owner_id),
 67                        )
 68                        .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id))
 69                        .add(channel_chat_participant::Column::ChannelId.eq(channel_id)),
 70                )
 71                .exec(&*tx)
 72                .await?;
 73
 74            Ok(())
 75        })
 76        .await
 77    }
 78
 79    /// Retrieves the messages in the specified channel.
 80    ///
 81    /// Use `before_message_id` to paginate through the channel's messages.
 82    pub async fn get_channel_messages(
 83        &self,
 84        channel_id: ChannelId,
 85        user_id: UserId,
 86        count: usize,
 87        before_message_id: Option<MessageId>,
 88    ) -> Result<Vec<proto::ChannelMessage>> {
 89        self.transaction(|tx| async move {
 90            let channel = self.get_channel_internal(channel_id, &tx).await?;
 91            self.check_user_is_channel_participant(&channel, user_id, &tx)
 92                .await?;
 93
 94            let mut condition =
 95                Condition::all().add(channel_message::Column::ChannelId.eq(channel_id));
 96
 97            if let Some(before_message_id) = before_message_id {
 98                condition = condition.add(channel_message::Column::Id.lt(before_message_id));
 99            }
100
101            let rows = channel_message::Entity::find()
102                .filter(condition)
103                .order_by_desc(channel_message::Column::Id)
104                .limit(count as u64)
105                .all(&*tx)
106                .await?;
107
108            self.load_channel_messages(rows, &tx).await
109        })
110        .await
111    }
112
113    /// Returns the channel messages with the given IDs.
114    pub async fn get_channel_messages_by_id(
115        &self,
116        user_id: UserId,
117        message_ids: &[MessageId],
118    ) -> Result<Vec<proto::ChannelMessage>> {
119        self.transaction(|tx| async move {
120            let rows = channel_message::Entity::find()
121                .filter(channel_message::Column::Id.is_in(message_ids.iter().copied()))
122                .order_by_desc(channel_message::Column::Id)
123                .all(&*tx)
124                .await?;
125
126            let mut channels = HashMap::<ChannelId, channel::Model>::default();
127            for row in &rows {
128                channels.insert(
129                    row.channel_id,
130                    self.get_channel_internal(row.channel_id, &tx).await?,
131                );
132            }
133
134            for (_, channel) in channels {
135                self.check_user_is_channel_participant(&channel, user_id, &tx)
136                    .await?;
137            }
138
139            let messages = self.load_channel_messages(rows, &tx).await?;
140            Ok(messages)
141        })
142        .await
143    }
144
145    async fn load_channel_messages(
146        &self,
147        rows: Vec<channel_message::Model>,
148        tx: &DatabaseTransaction,
149    ) -> Result<Vec<proto::ChannelMessage>> {
150        let mut messages = rows
151            .into_iter()
152            .map(|row| {
153                let nonce = row.nonce.as_u64_pair();
154                proto::ChannelMessage {
155                    id: row.id.to_proto(),
156                    sender_id: row.sender_id.to_proto(),
157                    body: row.body,
158                    timestamp: row.sent_at.assume_utc().unix_timestamp() as u64,
159                    mentions: vec![],
160                    nonce: Some(proto::Nonce {
161                        upper_half: nonce.0,
162                        lower_half: nonce.1,
163                    }),
164                    reply_to_message_id: row.reply_to_message_id.map(|id| id.to_proto()),
165                    edited_at: row
166                        .edited_at
167                        .map(|t| t.assume_utc().unix_timestamp() as u64),
168                }
169            })
170            .collect::<Vec<_>>();
171        messages.reverse();
172
173        let mut mentions = channel_message_mention::Entity::find()
174            .filter(channel_message_mention::Column::MessageId.is_in(messages.iter().map(|m| m.id)))
175            .order_by_asc(channel_message_mention::Column::MessageId)
176            .order_by_asc(channel_message_mention::Column::StartOffset)
177            .stream(tx)
178            .await?;
179
180        let mut message_ix = 0;
181        while let Some(mention) = mentions.next().await {
182            let mention = mention?;
183            let message_id = mention.message_id.to_proto();
184            while let Some(message) = messages.get_mut(message_ix) {
185                if message.id < message_id {
186                    message_ix += 1;
187                } else {
188                    if message.id == message_id {
189                        message.mentions.push(proto::ChatMention {
190                            range: Some(proto::Range {
191                                start: mention.start_offset as u64,
192                                end: mention.end_offset as u64,
193                            }),
194                            user_id: mention.user_id.to_proto(),
195                        });
196                    }
197                    break;
198                }
199            }
200        }
201
202        Ok(messages)
203    }
204
205    fn format_mentions_to_entities(
206        &self,
207        message_id: MessageId,
208        body: &str,
209        mentions: &[proto::ChatMention],
210    ) -> Result<Vec<tables::channel_message_mention::ActiveModel>> {
211        Ok(mentions
212            .iter()
213            .filter_map(|mention| {
214                let range = mention.range.as_ref()?;
215                if !body.is_char_boundary(range.start as usize)
216                    || !body.is_char_boundary(range.end as usize)
217                {
218                    return None;
219                }
220                Some(channel_message_mention::ActiveModel {
221                    message_id: ActiveValue::Set(message_id),
222                    start_offset: ActiveValue::Set(range.start as i32),
223                    end_offset: ActiveValue::Set(range.end as i32),
224                    user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)),
225                })
226            })
227            .collect::<Vec<_>>())
228    }
229
230    /// Creates a new channel message.
231    #[allow(clippy::too_many_arguments)]
232    pub async fn create_channel_message(
233        &self,
234        channel_id: ChannelId,
235        user_id: UserId,
236        body: &str,
237        mentions: &[proto::ChatMention],
238        timestamp: OffsetDateTime,
239        nonce: u128,
240        reply_to_message_id: Option<MessageId>,
241    ) -> Result<CreatedChannelMessage> {
242        self.transaction(|tx| async move {
243            let channel = self.get_channel_internal(channel_id, &tx).await?;
244            self.check_user_is_channel_participant(&channel, user_id, &tx)
245                .await?;
246
247            let mut rows = channel_chat_participant::Entity::find()
248                .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
249                .stream(&*tx)
250                .await?;
251
252            let mut is_participant = false;
253            let mut participant_connection_ids = Vec::new();
254            let mut participant_user_ids = Vec::new();
255            while let Some(row) = rows.next().await {
256                let row = row?;
257                if row.user_id == user_id {
258                    is_participant = true;
259                }
260                participant_user_ids.push(row.user_id);
261                participant_connection_ids.push(row.connection());
262            }
263            drop(rows);
264
265            if !is_participant {
266                Err(anyhow!("not a chat participant"))?;
267            }
268
269            let timestamp = timestamp.to_offset(time::UtcOffset::UTC);
270            let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time());
271
272            let result = channel_message::Entity::insert(channel_message::ActiveModel {
273                channel_id: ActiveValue::Set(channel_id),
274                sender_id: ActiveValue::Set(user_id),
275                body: ActiveValue::Set(body.to_string()),
276                sent_at: ActiveValue::Set(timestamp),
277                nonce: ActiveValue::Set(Uuid::from_u128(nonce)),
278                id: ActiveValue::NotSet,
279                reply_to_message_id: ActiveValue::Set(reply_to_message_id),
280                edited_at: ActiveValue::NotSet,
281            })
282            .on_conflict(
283                OnConflict::columns([
284                    channel_message::Column::SenderId,
285                    channel_message::Column::Nonce,
286                ])
287                .do_nothing()
288                .to_owned(),
289            )
290            .do_nothing()
291            .exec(&*tx)
292            .await?;
293
294            let message_id;
295            let mut notifications = Vec::new();
296            match result {
297                TryInsertResult::Inserted(result) => {
298                    message_id = result.last_insert_id;
299                    let mentioned_user_ids =
300                        mentions.iter().map(|m| m.user_id).collect::<HashSet<_>>();
301
302                    let mentions = self.format_mentions_to_entities(message_id, body, mentions)?;
303                    if !mentions.is_empty() {
304                        channel_message_mention::Entity::insert_many(mentions)
305                            .exec(&*tx)
306                            .await?;
307                    }
308
309                    for mentioned_user in mentioned_user_ids {
310                        notifications.extend(
311                            self.create_notification(
312                                UserId::from_proto(mentioned_user),
313                                rpc::Notification::ChannelMessageMention {
314                                    message_id: message_id.to_proto(),
315                                    sender_id: user_id.to_proto(),
316                                    channel_id: channel_id.to_proto(),
317                                },
318                                false,
319                                &tx,
320                            )
321                            .await?,
322                        );
323                    }
324
325                    self.observe_channel_message_internal(channel_id, user_id, message_id, &tx)
326                        .await?;
327                }
328                _ => {
329                    message_id = channel_message::Entity::find()
330                        .filter(channel_message::Column::Nonce.eq(Uuid::from_u128(nonce)))
331                        .one(&*tx)
332                        .await?
333                        .ok_or_else(|| anyhow!("failed to insert message"))?
334                        .id;
335                }
336            }
337
338            let mut channel_members = self.get_channel_participants(&channel, &tx).await?;
339            channel_members.retain(|member| !participant_user_ids.contains(member));
340
341            Ok(CreatedChannelMessage {
342                message_id,
343                participant_connection_ids,
344                channel_members,
345                notifications,
346            })
347        })
348        .await
349    }
350
351    pub async fn observe_channel_message(
352        &self,
353        channel_id: ChannelId,
354        user_id: UserId,
355        message_id: MessageId,
356    ) -> Result<NotificationBatch> {
357        self.transaction(|tx| async move {
358            self.observe_channel_message_internal(channel_id, user_id, message_id, &tx)
359                .await?;
360            let mut batch = NotificationBatch::default();
361            batch.extend(
362                self.mark_notification_as_read(
363                    user_id,
364                    &Notification::ChannelMessageMention {
365                        message_id: message_id.to_proto(),
366                        sender_id: Default::default(),
367                        channel_id: Default::default(),
368                    },
369                    &tx,
370                )
371                .await?,
372            );
373            Ok(batch)
374        })
375        .await
376    }
377
378    async fn observe_channel_message_internal(
379        &self,
380        channel_id: ChannelId,
381        user_id: UserId,
382        message_id: MessageId,
383        tx: &DatabaseTransaction,
384    ) -> Result<()> {
385        observed_channel_messages::Entity::insert(observed_channel_messages::ActiveModel {
386            user_id: ActiveValue::Set(user_id),
387            channel_id: ActiveValue::Set(channel_id),
388            channel_message_id: ActiveValue::Set(message_id),
389        })
390        .on_conflict(
391            OnConflict::columns([
392                observed_channel_messages::Column::ChannelId,
393                observed_channel_messages::Column::UserId,
394            ])
395            .update_column(observed_channel_messages::Column::ChannelMessageId)
396            .action_cond_where(observed_channel_messages::Column::ChannelMessageId.lt(message_id))
397            .to_owned(),
398        )
399        // TODO: Try to upgrade SeaORM so we don't have to do this hack around their bug
400        .exec_without_returning(tx)
401        .await?;
402        Ok(())
403    }
404
405    pub async fn observed_channel_messages(
406        &self,
407        channel_ids: &[ChannelId],
408        user_id: UserId,
409        tx: &DatabaseTransaction,
410    ) -> Result<Vec<proto::ChannelMessageId>> {
411        let rows = observed_channel_messages::Entity::find()
412            .filter(observed_channel_messages::Column::UserId.eq(user_id))
413            .filter(
414                observed_channel_messages::Column::ChannelId
415                    .is_in(channel_ids.iter().map(|id| id.0)),
416            )
417            .all(tx)
418            .await?;
419
420        Ok(rows
421            .into_iter()
422            .map(|message| proto::ChannelMessageId {
423                channel_id: message.channel_id.to_proto(),
424                message_id: message.channel_message_id.to_proto(),
425            })
426            .collect())
427    }
428
429    pub async fn latest_channel_messages(
430        &self,
431        channel_ids: &[ChannelId],
432        tx: &DatabaseTransaction,
433    ) -> Result<Vec<proto::ChannelMessageId>> {
434        let mut values = String::new();
435        for id in channel_ids {
436            if !values.is_empty() {
437                values.push_str(", ");
438            }
439            write!(&mut values, "({})", id).unwrap();
440        }
441
442        if values.is_empty() {
443            return Ok(Vec::default());
444        }
445
446        let sql = format!(
447            r#"
448            SELECT
449                *
450            FROM (
451                SELECT
452                    *,
453                    row_number() OVER (
454                        PARTITION BY channel_id
455                        ORDER BY id DESC
456                    ) as row_number
457                FROM channel_messages
458                WHERE
459                    channel_id in ({values})
460            ) AS messages
461            WHERE
462                row_number = 1
463            "#,
464        );
465
466        let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
467        let mut last_messages = channel_message::Model::find_by_statement(stmt)
468            .stream(tx)
469            .await?;
470
471        let mut results = Vec::new();
472        while let Some(result) = last_messages.next().await {
473            let message = result?;
474            results.push(proto::ChannelMessageId {
475                channel_id: message.channel_id.to_proto(),
476                message_id: message.id.to_proto(),
477            });
478        }
479
480        Ok(results)
481    }
482
483    /// Removes the channel message with the given ID.
484    pub async fn remove_channel_message(
485        &self,
486        channel_id: ChannelId,
487        message_id: MessageId,
488        user_id: UserId,
489    ) -> Result<Vec<ConnectionId>> {
490        self.transaction(|tx| async move {
491            let mut rows = channel_chat_participant::Entity::find()
492                .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
493                .stream(&*tx)
494                .await?;
495
496            let mut is_participant = false;
497            let mut participant_connection_ids = Vec::new();
498            while let Some(row) = rows.next().await {
499                let row = row?;
500                if row.user_id == user_id {
501                    is_participant = true;
502                }
503                participant_connection_ids.push(row.connection());
504            }
505            drop(rows);
506
507            if !is_participant {
508                Err(anyhow!("not a chat participant"))?;
509            }
510
511            let result = channel_message::Entity::delete_by_id(message_id)
512                .filter(channel_message::Column::SenderId.eq(user_id))
513                .exec(&*tx)
514                .await?;
515
516            if result.rows_affected == 0 {
517                let channel = self.get_channel_internal(channel_id, &tx).await?;
518                if self
519                    .check_user_is_channel_admin(&channel, user_id, &tx)
520                    .await
521                    .is_ok()
522                {
523                    let result = channel_message::Entity::delete_by_id(message_id)
524                        .exec(&*tx)
525                        .await?;
526                    if result.rows_affected == 0 {
527                        Err(anyhow!("no such message"))?;
528                    }
529                } else {
530                    Err(anyhow!("operation could not be completed"))?;
531                }
532            }
533
534            Ok(participant_connection_ids)
535        })
536        .await
537    }
538
539    /// Updates the channel message with the given ID, body and timestamp(edited_at).
540    pub async fn update_channel_message(
541        &self,
542        channel_id: ChannelId,
543        message_id: MessageId,
544        user_id: UserId,
545        body: &str,
546        mentions: &[proto::ChatMention],
547        edited_at: OffsetDateTime,
548    ) -> Result<UpdatedChannelMessage> {
549        self.transaction(|tx| async move {
550            let channel = self.get_channel_internal(channel_id, &tx).await?;
551            self.check_user_is_channel_participant(&channel, user_id, &tx)
552                .await?;
553
554            let mut rows = channel_chat_participant::Entity::find()
555                .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
556                .stream(&*tx)
557                .await?;
558
559            let mut is_participant = false;
560            let mut participant_connection_ids = Vec::new();
561            let mut participant_user_ids = Vec::new();
562            while let Some(row) = rows.next().await {
563                let row = row?;
564                if row.user_id == user_id {
565                    is_participant = true;
566                }
567                participant_user_ids.push(row.user_id);
568                participant_connection_ids.push(row.connection());
569            }
570            drop(rows);
571
572            if !is_participant {
573                Err(anyhow!("not a chat participant"))?;
574            }
575
576            let channel_message = channel_message::Entity::find_by_id(message_id)
577                .filter(channel_message::Column::SenderId.eq(user_id))
578                .one(&*tx)
579                .await?;
580
581            let Some(channel_message) = channel_message else {
582                Err(anyhow!("Channel message not found"))?
583            };
584
585            let edited_at = edited_at.to_offset(time::UtcOffset::UTC);
586            let edited_at = time::PrimitiveDateTime::new(edited_at.date(), edited_at.time());
587
588            let updated_message = channel_message::ActiveModel {
589                body: ActiveValue::Set(body.to_string()),
590                edited_at: ActiveValue::Set(Some(edited_at)),
591                reply_to_message_id: ActiveValue::Unchanged(channel_message.reply_to_message_id),
592                id: ActiveValue::Unchanged(message_id),
593                channel_id: ActiveValue::Unchanged(channel_id),
594                sender_id: ActiveValue::Unchanged(user_id),
595                sent_at: ActiveValue::Unchanged(channel_message.sent_at),
596                nonce: ActiveValue::Unchanged(channel_message.nonce),
597            };
598
599            let result = channel_message::Entity::update_many()
600                .set(updated_message)
601                .filter(channel_message::Column::Id.eq(message_id))
602                .filter(channel_message::Column::SenderId.eq(user_id))
603                .exec(&*tx)
604                .await?;
605            if result.rows_affected == 0 {
606                return Err(anyhow!(
607                    "Attempted to edit a message (id: {message_id}) which does not exist anymore."
608                ))?;
609            }
610
611            // we have to fetch the old mentions,
612            // so we don't send a notification when the message has been edited that you are mentioned in
613            let old_mentions = channel_message_mention::Entity::find()
614                .filter(channel_message_mention::Column::MessageId.eq(message_id))
615                .all(&*tx)
616                .await?;
617
618            // remove all existing mentions
619            channel_message_mention::Entity::delete_many()
620                .filter(channel_message_mention::Column::MessageId.eq(message_id))
621                .exec(&*tx)
622                .await?;
623
624            let new_mentions = self.format_mentions_to_entities(message_id, body, mentions)?;
625            if !new_mentions.is_empty() {
626                // insert new mentions
627                channel_message_mention::Entity::insert_many(new_mentions)
628                    .exec(&*tx)
629                    .await?;
630            }
631
632            let mut mentioned_user_ids = mentions.iter().map(|m| m.user_id).collect::<HashSet<_>>();
633            // Filter out users that were mentioned before
634            for mention in old_mentions {
635                mentioned_user_ids.remove(&mention.user_id.to_proto());
636            }
637
638            let mut notifications = Vec::new();
639            for mentioned_user in mentioned_user_ids {
640                notifications.extend(
641                    self.create_notification(
642                        UserId::from_proto(mentioned_user),
643                        rpc::Notification::ChannelMessageMention {
644                            message_id: message_id.to_proto(),
645                            sender_id: user_id.to_proto(),
646                            channel_id: channel_id.to_proto(),
647                        },
648                        false,
649                        &tx,
650                    )
651                    .await?,
652                );
653            }
654
655            Ok(UpdatedChannelMessage {
656                message_id,
657                participant_connection_ids,
658                notifications,
659                reply_to_message_id: channel_message.reply_to_message_id,
660                timestamp: channel_message.sent_at,
661            })
662        })
663        .await
664    }
665}