messages.rs

  1use super::*;
  2use rpc::Notification;
  3use sea_orm::TryInsertResult;
  4use time::OffsetDateTime;
  5
  6impl Database {
  7    /// Inserts a record representing a user joining the chat for a given channel.
  8    pub async fn join_channel_chat(
  9        &self,
 10        channel_id: ChannelId,
 11        connection_id: ConnectionId,
 12        user_id: UserId,
 13    ) -> Result<()> {
 14        self.transaction(|tx| async move {
 15            let channel = self.get_channel_internal(channel_id, &*tx).await?;
 16            self.check_user_is_channel_participant(&channel, user_id, &*tx)
 17                .await?;
 18            channel_chat_participant::ActiveModel {
 19                id: ActiveValue::NotSet,
 20                channel_id: ActiveValue::Set(channel_id),
 21                user_id: ActiveValue::Set(user_id),
 22                connection_id: ActiveValue::Set(connection_id.id as i32),
 23                connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
 24            }
 25            .insert(&*tx)
 26            .await?;
 27            Ok(())
 28        })
 29        .await
 30    }
 31
 32    /// Removes `channel_chat_participant` records associated with the given connection ID.
 33    pub async fn channel_chat_connection_lost(
 34        &self,
 35        connection_id: ConnectionId,
 36        tx: &DatabaseTransaction,
 37    ) -> Result<()> {
 38        channel_chat_participant::Entity::delete_many()
 39            .filter(
 40                Condition::all()
 41                    .add(
 42                        channel_chat_participant::Column::ConnectionServerId
 43                            .eq(connection_id.owner_id),
 44                    )
 45                    .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id)),
 46            )
 47            .exec(tx)
 48            .await?;
 49        Ok(())
 50    }
 51
 52    /// Removes `channel_chat_participant` records associated with the given user ID so they
 53    /// will no longer get chat notifications.
 54    pub async fn leave_channel_chat(
 55        &self,
 56        channel_id: ChannelId,
 57        connection_id: ConnectionId,
 58        _user_id: UserId,
 59    ) -> Result<()> {
 60        self.transaction(|tx| async move {
 61            channel_chat_participant::Entity::delete_many()
 62                .filter(
 63                    Condition::all()
 64                        .add(
 65                            channel_chat_participant::Column::ConnectionServerId
 66                                .eq(connection_id.owner_id),
 67                        )
 68                        .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id))
 69                        .add(channel_chat_participant::Column::ChannelId.eq(channel_id)),
 70                )
 71                .exec(&*tx)
 72                .await?;
 73
 74            Ok(())
 75        })
 76        .await
 77    }
 78
 79    /// Retrieves the messages in the specified channel.
 80    ///
 81    /// Use `before_message_id` to paginate through the channel's messages.
 82    pub async fn get_channel_messages(
 83        &self,
 84        channel_id: ChannelId,
 85        user_id: UserId,
 86        count: usize,
 87        before_message_id: Option<MessageId>,
 88    ) -> Result<Vec<proto::ChannelMessage>> {
 89        self.transaction(|tx| async move {
 90            let channel = self.get_channel_internal(channel_id, &*tx).await?;
 91            self.check_user_is_channel_participant(&channel, user_id, &*tx)
 92                .await?;
 93
 94            let mut condition =
 95                Condition::all().add(channel_message::Column::ChannelId.eq(channel_id));
 96
 97            if let Some(before_message_id) = before_message_id {
 98                condition = condition.add(channel_message::Column::Id.lt(before_message_id));
 99            }
100
101            let rows = channel_message::Entity::find()
102                .filter(condition)
103                .order_by_desc(channel_message::Column::Id)
104                .limit(count as u64)
105                .all(&*tx)
106                .await?;
107
108            self.load_channel_messages(rows, &*tx).await
109        })
110        .await
111    }
112
113    /// Returns the channel messages with the given IDs.
114    pub async fn get_channel_messages_by_id(
115        &self,
116        user_id: UserId,
117        message_ids: &[MessageId],
118    ) -> Result<Vec<proto::ChannelMessage>> {
119        self.transaction(|tx| async move {
120            let rows = channel_message::Entity::find()
121                .filter(channel_message::Column::Id.is_in(message_ids.iter().copied()))
122                .order_by_desc(channel_message::Column::Id)
123                .all(&*tx)
124                .await?;
125
126            let mut channels = HashMap::<ChannelId, channel::Model>::default();
127            for row in &rows {
128                channels.insert(
129                    row.channel_id,
130                    self.get_channel_internal(row.channel_id, &*tx).await?,
131                );
132            }
133
134            for (_, channel) in channels {
135                self.check_user_is_channel_participant(&channel, user_id, &*tx)
136                    .await?;
137            }
138
139            let messages = self.load_channel_messages(rows, &*tx).await?;
140            Ok(messages)
141        })
142        .await
143    }
144
145    async fn load_channel_messages(
146        &self,
147        rows: Vec<channel_message::Model>,
148        tx: &DatabaseTransaction,
149    ) -> Result<Vec<proto::ChannelMessage>> {
150        let mut messages = rows
151            .into_iter()
152            .map(|row| {
153                let nonce = row.nonce.as_u64_pair();
154                proto::ChannelMessage {
155                    id: row.id.to_proto(),
156                    sender_id: row.sender_id.to_proto(),
157                    body: row.body,
158                    timestamp: row.sent_at.assume_utc().unix_timestamp() as u64,
159                    mentions: vec![],
160                    nonce: Some(proto::Nonce {
161                        upper_half: nonce.0,
162                        lower_half: nonce.1,
163                    }),
164                }
165            })
166            .collect::<Vec<_>>();
167        messages.reverse();
168
169        let mut mentions = channel_message_mention::Entity::find()
170            .filter(channel_message_mention::Column::MessageId.is_in(messages.iter().map(|m| m.id)))
171            .order_by_asc(channel_message_mention::Column::MessageId)
172            .order_by_asc(channel_message_mention::Column::StartOffset)
173            .stream(&*tx)
174            .await?;
175
176        let mut message_ix = 0;
177        while let Some(mention) = mentions.next().await {
178            let mention = mention?;
179            let message_id = mention.message_id.to_proto();
180            while let Some(message) = messages.get_mut(message_ix) {
181                if message.id < message_id {
182                    message_ix += 1;
183                } else {
184                    if message.id == message_id {
185                        message.mentions.push(proto::ChatMention {
186                            range: Some(proto::Range {
187                                start: mention.start_offset as u64,
188                                end: mention.end_offset as u64,
189                            }),
190                            user_id: mention.user_id.to_proto(),
191                        });
192                    }
193                    break;
194                }
195            }
196        }
197
198        Ok(messages)
199    }
200
201    /// Creates a new channel message.
202    pub async fn create_channel_message(
203        &self,
204        channel_id: ChannelId,
205        user_id: UserId,
206        body: &str,
207        mentions: &[proto::ChatMention],
208        timestamp: OffsetDateTime,
209        nonce: u128,
210    ) -> Result<CreatedChannelMessage> {
211        self.transaction(|tx| async move {
212            let channel = self.get_channel_internal(channel_id, &*tx).await?;
213            self.check_user_is_channel_participant(&channel, user_id, &*tx)
214                .await?;
215
216            let mut rows = channel_chat_participant::Entity::find()
217                .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
218                .stream(&*tx)
219                .await?;
220
221            let mut is_participant = false;
222            let mut participant_connection_ids = Vec::new();
223            let mut participant_user_ids = Vec::new();
224            while let Some(row) = rows.next().await {
225                let row = row?;
226                if row.user_id == user_id {
227                    is_participant = true;
228                }
229                participant_user_ids.push(row.user_id);
230                participant_connection_ids.push(row.connection());
231            }
232            drop(rows);
233
234            if !is_participant {
235                Err(anyhow!("not a chat participant"))?;
236            }
237
238            let timestamp = timestamp.to_offset(time::UtcOffset::UTC);
239            let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time());
240
241            let result = channel_message::Entity::insert(channel_message::ActiveModel {
242                channel_id: ActiveValue::Set(channel_id),
243                sender_id: ActiveValue::Set(user_id),
244                body: ActiveValue::Set(body.to_string()),
245                sent_at: ActiveValue::Set(timestamp),
246                nonce: ActiveValue::Set(Uuid::from_u128(nonce)),
247                id: ActiveValue::NotSet,
248            })
249            .on_conflict(
250                OnConflict::columns([
251                    channel_message::Column::SenderId,
252                    channel_message::Column::Nonce,
253                ])
254                .do_nothing()
255                .to_owned(),
256            )
257            .do_nothing()
258            .exec(&*tx)
259            .await?;
260
261            let message_id;
262            let mut notifications = Vec::new();
263            match result {
264                TryInsertResult::Inserted(result) => {
265                    message_id = result.last_insert_id;
266                    let mentioned_user_ids =
267                        mentions.iter().map(|m| m.user_id).collect::<HashSet<_>>();
268
269                    let mentions = mentions
270                        .iter()
271                        .filter_map(|mention| {
272                            let range = mention.range.as_ref()?;
273                            if !body.is_char_boundary(range.start as usize)
274                                || !body.is_char_boundary(range.end as usize)
275                            {
276                                return None;
277                            }
278                            Some(channel_message_mention::ActiveModel {
279                                message_id: ActiveValue::Set(message_id),
280                                start_offset: ActiveValue::Set(range.start as i32),
281                                end_offset: ActiveValue::Set(range.end as i32),
282                                user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)),
283                            })
284                        })
285                        .collect::<Vec<_>>();
286                    if !mentions.is_empty() {
287                        channel_message_mention::Entity::insert_many(mentions)
288                            .exec(&*tx)
289                            .await?;
290                    }
291
292                    for mentioned_user in mentioned_user_ids {
293                        notifications.extend(
294                            self.create_notification(
295                                UserId::from_proto(mentioned_user),
296                                rpc::Notification::ChannelMessageMention {
297                                    message_id: message_id.to_proto(),
298                                    sender_id: user_id.to_proto(),
299                                    channel_id: channel_id.to_proto(),
300                                },
301                                false,
302                                &*tx,
303                            )
304                            .await?,
305                        );
306                    }
307
308                    self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
309                        .await?;
310                }
311                _ => {
312                    message_id = channel_message::Entity::find()
313                        .filter(channel_message::Column::Nonce.eq(Uuid::from_u128(nonce)))
314                        .one(&*tx)
315                        .await?
316                        .ok_or_else(|| anyhow!("failed to insert message"))?
317                        .id;
318                }
319            }
320
321            let mut channel_members = self.get_channel_participants(&channel, &*tx).await?;
322            channel_members.retain(|member| !participant_user_ids.contains(member));
323
324            Ok(CreatedChannelMessage {
325                message_id,
326                participant_connection_ids,
327                channel_members,
328                notifications,
329            })
330        })
331        .await
332    }
333
334    pub async fn observe_channel_message(
335        &self,
336        channel_id: ChannelId,
337        user_id: UserId,
338        message_id: MessageId,
339    ) -> Result<NotificationBatch> {
340        self.transaction(|tx| async move {
341            self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
342                .await?;
343            let mut batch = NotificationBatch::default();
344            batch.extend(
345                self.mark_notification_as_read(
346                    user_id,
347                    &Notification::ChannelMessageMention {
348                        message_id: message_id.to_proto(),
349                        sender_id: Default::default(),
350                        channel_id: Default::default(),
351                    },
352                    &*tx,
353                )
354                .await?,
355            );
356            Ok(batch)
357        })
358        .await
359    }
360
361    async fn observe_channel_message_internal(
362        &self,
363        channel_id: ChannelId,
364        user_id: UserId,
365        message_id: MessageId,
366        tx: &DatabaseTransaction,
367    ) -> Result<()> {
368        observed_channel_messages::Entity::insert(observed_channel_messages::ActiveModel {
369            user_id: ActiveValue::Set(user_id),
370            channel_id: ActiveValue::Set(channel_id),
371            channel_message_id: ActiveValue::Set(message_id),
372        })
373        .on_conflict(
374            OnConflict::columns([
375                observed_channel_messages::Column::ChannelId,
376                observed_channel_messages::Column::UserId,
377            ])
378            .update_column(observed_channel_messages::Column::ChannelMessageId)
379            .action_cond_where(observed_channel_messages::Column::ChannelMessageId.lt(message_id))
380            .to_owned(),
381        )
382        // TODO: Try to upgrade SeaORM so we don't have to do this hack around their bug
383        .exec_without_returning(&*tx)
384        .await?;
385        Ok(())
386    }
387
388    pub async fn latest_channel_messages(
389        &self,
390        channel_ids: &[ChannelId],
391        tx: &DatabaseTransaction,
392    ) -> Result<Vec<proto::ChannelMessageId>> {
393        let mut values = String::new();
394        for id in channel_ids {
395            if !values.is_empty() {
396                values.push_str(", ");
397            }
398            write!(&mut values, "({})", id).unwrap();
399        }
400
401        if values.is_empty() {
402            return Ok(Vec::default());
403        }
404
405        let sql = format!(
406            r#"
407            SELECT
408                *
409            FROM (
410                SELECT
411                    *,
412                    row_number() OVER (
413                        PARTITION BY channel_id
414                        ORDER BY id DESC
415                    ) as row_number
416                FROM channel_messages
417                WHERE
418                    channel_id in ({values})
419            ) AS messages
420            WHERE
421                row_number = 1
422            "#,
423        );
424
425        let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
426        let mut last_messages = channel_message::Model::find_by_statement(stmt)
427            .stream(&*tx)
428            .await?;
429
430        let mut results = Vec::new();
431        while let Some(result) = last_messages.next().await {
432            let message = result?;
433            results.push(proto::ChannelMessageId {
434                channel_id: message.channel_id.to_proto(),
435                message_id: message.id.to_proto(),
436            });
437        }
438
439        Ok(results)
440    }
441
442    /// Removes the channel message with the given ID.
443    pub async fn remove_channel_message(
444        &self,
445        channel_id: ChannelId,
446        message_id: MessageId,
447        user_id: UserId,
448    ) -> Result<Vec<ConnectionId>> {
449        self.transaction(|tx| async move {
450            let mut rows = channel_chat_participant::Entity::find()
451                .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
452                .stream(&*tx)
453                .await?;
454
455            let mut is_participant = false;
456            let mut participant_connection_ids = Vec::new();
457            while let Some(row) = rows.next().await {
458                let row = row?;
459                if row.user_id == user_id {
460                    is_participant = true;
461                }
462                participant_connection_ids.push(row.connection());
463            }
464            drop(rows);
465
466            if !is_participant {
467                Err(anyhow!("not a chat participant"))?;
468            }
469
470            let result = channel_message::Entity::delete_by_id(message_id)
471                .filter(channel_message::Column::SenderId.eq(user_id))
472                .exec(&*tx)
473                .await?;
474
475            if result.rows_affected == 0 {
476                let channel = self.get_channel_internal(channel_id, &*tx).await?;
477                if self
478                    .check_user_is_channel_admin(&channel, user_id, &*tx)
479                    .await
480                    .is_ok()
481                {
482                    let result = channel_message::Entity::delete_by_id(message_id)
483                        .exec(&*tx)
484                        .await?;
485                    if result.rows_affected == 0 {
486                        Err(anyhow!("no such message"))?;
487                    }
488                } else {
489                    Err(anyhow!("operation could not be completed"))?;
490                }
491            }
492
493            Ok(participant_connection_ids)
494        })
495        .await
496    }
497}