messages.rs

  1use super::*;
  2use futures::Stream;
  3use sea_orm::TryInsertResult;
  4use time::OffsetDateTime;
  5
  6impl Database {
  7    pub async fn join_channel_chat(
  8        &self,
  9        channel_id: ChannelId,
 10        connection_id: ConnectionId,
 11        user_id: UserId,
 12    ) -> Result<()> {
 13        self.transaction(|tx| async move {
 14            self.check_user_is_channel_member(channel_id, user_id, &*tx)
 15                .await?;
 16            channel_chat_participant::ActiveModel {
 17                id: ActiveValue::NotSet,
 18                channel_id: ActiveValue::Set(channel_id),
 19                user_id: ActiveValue::Set(user_id),
 20                connection_id: ActiveValue::Set(connection_id.id as i32),
 21                connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
 22            }
 23            .insert(&*tx)
 24            .await?;
 25            Ok(())
 26        })
 27        .await
 28    }
 29
 30    pub async fn channel_chat_connection_lost(
 31        &self,
 32        connection_id: ConnectionId,
 33        tx: &DatabaseTransaction,
 34    ) -> Result<()> {
 35        channel_chat_participant::Entity::delete_many()
 36            .filter(
 37                Condition::all()
 38                    .add(
 39                        channel_chat_participant::Column::ConnectionServerId
 40                            .eq(connection_id.owner_id),
 41                    )
 42                    .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id)),
 43            )
 44            .exec(tx)
 45            .await?;
 46        Ok(())
 47    }
 48
 49    pub async fn leave_channel_chat(
 50        &self,
 51        channel_id: ChannelId,
 52        connection_id: ConnectionId,
 53        _user_id: UserId,
 54    ) -> Result<()> {
 55        self.transaction(|tx| async move {
 56            channel_chat_participant::Entity::delete_many()
 57                .filter(
 58                    Condition::all()
 59                        .add(
 60                            channel_chat_participant::Column::ConnectionServerId
 61                                .eq(connection_id.owner_id),
 62                        )
 63                        .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id))
 64                        .add(channel_chat_participant::Column::ChannelId.eq(channel_id)),
 65                )
 66                .exec(&*tx)
 67                .await?;
 68
 69            Ok(())
 70        })
 71        .await
 72    }
 73
 74    pub async fn get_channel_messages(
 75        &self,
 76        channel_id: ChannelId,
 77        user_id: UserId,
 78        count: usize,
 79        before_message_id: Option<MessageId>,
 80    ) -> Result<Vec<proto::ChannelMessage>> {
 81        self.transaction(|tx| async move {
 82            self.check_user_is_channel_member(channel_id, user_id, &*tx)
 83                .await?;
 84
 85            let mut condition =
 86                Condition::all().add(channel_message::Column::ChannelId.eq(channel_id));
 87
 88            if let Some(before_message_id) = before_message_id {
 89                condition = condition.add(channel_message::Column::Id.lt(before_message_id));
 90            }
 91
 92            let rows = channel_message::Entity::find()
 93                .filter(condition)
 94                .order_by_desc(channel_message::Column::Id)
 95                .limit(count as u64)
 96                .stream(&*tx)
 97                .await?;
 98
 99            self.load_channel_messages(rows, &*tx).await
100        })
101        .await
102    }
103
104    pub async fn get_channel_messages_by_id(
105        &self,
106        user_id: UserId,
107        message_ids: &[MessageId],
108    ) -> Result<Vec<proto::ChannelMessage>> {
109        self.transaction(|tx| async move {
110            let rows = channel_message::Entity::find()
111                .filter(channel_message::Column::Id.is_in(message_ids.iter().copied()))
112                .order_by_desc(channel_message::Column::Id)
113                .stream(&*tx)
114                .await?;
115
116            let mut channel_ids = HashSet::<ChannelId>::default();
117            let messages = self
118                .load_channel_messages(
119                    rows.map(|row| {
120                        row.map(|row| {
121                            channel_ids.insert(row.channel_id);
122                            row
123                        })
124                    }),
125                    &*tx,
126                )
127                .await?;
128
129            for channel_id in channel_ids {
130                self.check_user_is_channel_member(channel_id, user_id, &*tx)
131                    .await?;
132            }
133
134            Ok(messages)
135        })
136        .await
137    }
138
139    async fn load_channel_messages(
140        &self,
141        mut rows: impl Send + Unpin + Stream<Item = Result<channel_message::Model, sea_orm::DbErr>>,
142        tx: &DatabaseTransaction,
143    ) -> Result<Vec<proto::ChannelMessage>> {
144        let mut messages = Vec::new();
145        while let Some(row) = rows.next().await {
146            let row = row?;
147            let nonce = row.nonce.as_u64_pair();
148            messages.push(proto::ChannelMessage {
149                id: row.id.to_proto(),
150                sender_id: row.sender_id.to_proto(),
151                body: row.body,
152                timestamp: row.sent_at.assume_utc().unix_timestamp() as u64,
153                mentions: vec![],
154                nonce: Some(proto::Nonce {
155                    upper_half: nonce.0,
156                    lower_half: nonce.1,
157                }),
158            });
159        }
160        drop(rows);
161        messages.reverse();
162
163        let mut mentions = channel_message_mention::Entity::find()
164            .filter(channel_message_mention::Column::MessageId.is_in(messages.iter().map(|m| m.id)))
165            .order_by_asc(channel_message_mention::Column::MessageId)
166            .order_by_asc(channel_message_mention::Column::StartOffset)
167            .stream(&*tx)
168            .await?;
169
170        let mut message_ix = 0;
171        while let Some(mention) = mentions.next().await {
172            let mention = mention?;
173            let message_id = mention.message_id.to_proto();
174            while let Some(message) = messages.get_mut(message_ix) {
175                if message.id < message_id {
176                    message_ix += 1;
177                } else {
178                    if message.id == message_id {
179                        message.mentions.push(proto::ChatMention {
180                            range: Some(proto::Range {
181                                start: mention.start_offset as u64,
182                                end: mention.end_offset as u64,
183                            }),
184                            user_id: mention.user_id.to_proto(),
185                        });
186                    }
187                    break;
188                }
189            }
190        }
191
192        Ok(messages)
193    }
194
195    pub async fn create_channel_message(
196        &self,
197        channel_id: ChannelId,
198        user_id: UserId,
199        body: &str,
200        mentions: &[proto::ChatMention],
201        timestamp: OffsetDateTime,
202        nonce: u128,
203    ) -> Result<CreatedChannelMessage> {
204        self.transaction(|tx| async move {
205            let mut rows = channel_chat_participant::Entity::find()
206                .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
207                .stream(&*tx)
208                .await?;
209
210            let mut is_participant = false;
211            let mut participant_connection_ids = Vec::new();
212            let mut participant_user_ids = Vec::new();
213            while let Some(row) = rows.next().await {
214                let row = row?;
215                if row.user_id == user_id {
216                    is_participant = true;
217                }
218                participant_user_ids.push(row.user_id);
219                participant_connection_ids.push(row.connection());
220            }
221            drop(rows);
222
223            if !is_participant {
224                Err(anyhow!("not a chat participant"))?;
225            }
226
227            let timestamp = timestamp.to_offset(time::UtcOffset::UTC);
228            let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time());
229
230            let result = channel_message::Entity::insert(channel_message::ActiveModel {
231                channel_id: ActiveValue::Set(channel_id),
232                sender_id: ActiveValue::Set(user_id),
233                body: ActiveValue::Set(body.to_string()),
234                sent_at: ActiveValue::Set(timestamp),
235                nonce: ActiveValue::Set(Uuid::from_u128(nonce)),
236                id: ActiveValue::NotSet,
237            })
238            .on_conflict(
239                OnConflict::columns([
240                    channel_message::Column::SenderId,
241                    channel_message::Column::Nonce,
242                ])
243                .do_nothing()
244                .to_owned(),
245            )
246            .do_nothing()
247            .exec(&*tx)
248            .await?;
249
250            let message_id;
251            let mut notifications = Vec::new();
252            match result {
253                TryInsertResult::Inserted(result) => {
254                    message_id = result.last_insert_id;
255                    let mentioned_user_ids =
256                        mentions.iter().map(|m| m.user_id).collect::<HashSet<_>>();
257                    let mentions = mentions
258                        .iter()
259                        .filter_map(|mention| {
260                            let range = mention.range.as_ref()?;
261                            if !body.is_char_boundary(range.start as usize)
262                                || !body.is_char_boundary(range.end as usize)
263                            {
264                                return None;
265                            }
266                            Some(channel_message_mention::ActiveModel {
267                                message_id: ActiveValue::Set(message_id),
268                                start_offset: ActiveValue::Set(range.start as i32),
269                                end_offset: ActiveValue::Set(range.end as i32),
270                                user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)),
271                            })
272                        })
273                        .collect::<Vec<_>>();
274                    if !mentions.is_empty() {
275                        channel_message_mention::Entity::insert_many(mentions)
276                            .exec(&*tx)
277                            .await?;
278                    }
279
280                    for mentioned_user in mentioned_user_ids {
281                        notifications.extend(
282                            self.create_notification(
283                                UserId::from_proto(mentioned_user),
284                                rpc::Notification::ChannelMessageMention {
285                                    message_id: message_id.to_proto(),
286                                    sender_id: user_id.to_proto(),
287                                    channel_id: channel_id.to_proto(),
288                                },
289                                false,
290                                &*tx,
291                            )
292                            .await?,
293                        );
294                    }
295
296                    self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
297                        .await?;
298                }
299                _ => {
300                    message_id = channel_message::Entity::find()
301                        .filter(channel_message::Column::Nonce.eq(Uuid::from_u128(nonce)))
302                        .one(&*tx)
303                        .await?
304                        .ok_or_else(|| anyhow!("failed to insert message"))?
305                        .id;
306                }
307            }
308
309            let mut channel_members = self
310                .get_channel_participants_internal(channel_id, &*tx)
311                .await?;
312            channel_members.retain(|member| !participant_user_ids.contains(member));
313
314            Ok(CreatedChannelMessage {
315                message_id,
316                participant_connection_ids,
317                channel_members,
318                notifications,
319            })
320        })
321        .await
322    }
323
324    pub async fn observe_channel_message(
325        &self,
326        channel_id: ChannelId,
327        user_id: UserId,
328        message_id: MessageId,
329    ) -> Result<()> {
330        self.transaction(|tx| async move {
331            self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
332                .await?;
333            Ok(())
334        })
335        .await
336    }
337
338    async fn observe_channel_message_internal(
339        &self,
340        channel_id: ChannelId,
341        user_id: UserId,
342        message_id: MessageId,
343        tx: &DatabaseTransaction,
344    ) -> Result<()> {
345        observed_channel_messages::Entity::insert(observed_channel_messages::ActiveModel {
346            user_id: ActiveValue::Set(user_id),
347            channel_id: ActiveValue::Set(channel_id),
348            channel_message_id: ActiveValue::Set(message_id),
349        })
350        .on_conflict(
351            OnConflict::columns([
352                observed_channel_messages::Column::ChannelId,
353                observed_channel_messages::Column::UserId,
354            ])
355            .update_column(observed_channel_messages::Column::ChannelMessageId)
356            .action_cond_where(observed_channel_messages::Column::ChannelMessageId.lt(message_id))
357            .to_owned(),
358        )
359        // TODO: Try to upgrade SeaORM so we don't have to do this hack around their bug
360        .exec_without_returning(&*tx)
361        .await?;
362        Ok(())
363    }
364
365    pub async fn unseen_channel_messages(
366        &self,
367        user_id: UserId,
368        channel_ids: &[ChannelId],
369        tx: &DatabaseTransaction,
370    ) -> Result<Vec<proto::UnseenChannelMessage>> {
371        let mut observed_messages_by_channel_id = HashMap::default();
372        let mut rows = observed_channel_messages::Entity::find()
373            .filter(observed_channel_messages::Column::UserId.eq(user_id))
374            .filter(observed_channel_messages::Column::ChannelId.is_in(channel_ids.iter().copied()))
375            .stream(&*tx)
376            .await?;
377
378        while let Some(row) = rows.next().await {
379            let row = row?;
380            observed_messages_by_channel_id.insert(row.channel_id, row);
381        }
382        drop(rows);
383        let mut values = String::new();
384        for id in channel_ids {
385            if !values.is_empty() {
386                values.push_str(", ");
387            }
388            write!(&mut values, "({})", id).unwrap();
389        }
390
391        if values.is_empty() {
392            return Ok(Default::default());
393        }
394
395        let sql = format!(
396            r#"
397            SELECT
398                *
399            FROM (
400                SELECT
401                    *,
402                    row_number() OVER (
403                        PARTITION BY channel_id
404                        ORDER BY id DESC
405                    ) as row_number
406                FROM channel_messages
407                WHERE
408                    channel_id in ({values})
409            ) AS messages
410            WHERE
411                row_number = 1
412            "#,
413        );
414
415        let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
416        let last_messages = channel_message::Model::find_by_statement(stmt)
417            .all(&*tx)
418            .await?;
419
420        let mut changes = Vec::new();
421        for last_message in last_messages {
422            if let Some(observed_message) =
423                observed_messages_by_channel_id.get(&last_message.channel_id)
424            {
425                if observed_message.channel_message_id == last_message.id {
426                    continue;
427                }
428            }
429            changes.push(proto::UnseenChannelMessage {
430                channel_id: last_message.channel_id.to_proto(),
431                message_id: last_message.id.to_proto(),
432            });
433        }
434
435        Ok(changes)
436    }
437
438    pub async fn remove_channel_message(
439        &self,
440        channel_id: ChannelId,
441        message_id: MessageId,
442        user_id: UserId,
443    ) -> Result<Vec<ConnectionId>> {
444        self.transaction(|tx| async move {
445            let mut rows = channel_chat_participant::Entity::find()
446                .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
447                .stream(&*tx)
448                .await?;
449
450            let mut is_participant = false;
451            let mut participant_connection_ids = Vec::new();
452            while let Some(row) = rows.next().await {
453                let row = row?;
454                if row.user_id == user_id {
455                    is_participant = true;
456                }
457                participant_connection_ids.push(row.connection());
458            }
459            drop(rows);
460
461            if !is_participant {
462                Err(anyhow!("not a chat participant"))?;
463            }
464
465            let result = channel_message::Entity::delete_by_id(message_id)
466                .filter(channel_message::Column::SenderId.eq(user_id))
467                .exec(&*tx)
468                .await?;
469
470            if result.rows_affected == 0 {
471                if self
472                    .check_user_is_channel_admin(channel_id, user_id, &*tx)
473                    .await
474                    .is_ok()
475                {
476                    let result = channel_message::Entity::delete_by_id(message_id)
477                        .exec(&*tx)
478                        .await?;
479                    if result.rows_affected == 0 {
480                        Err(anyhow!("no such message"))?;
481                    }
482                } else {
483                    Err(anyhow!("operation could not be completed"))?;
484                }
485            }
486
487            Ok(participant_connection_ids)
488        })
489        .await
490    }
491}