messages.rs

use super::*;
use sea_orm::QueryOrder;
use time::OffsetDateTime;
  3
  4impl Database {
  5    pub async fn join_channel_chat(
  6        &self,
  7        channel_id: ChannelId,
  8        connection_id: ConnectionId,
  9        user_id: UserId,
 10    ) -> Result<()> {
 11        self.transaction(|tx| async move {
 12            self.check_user_is_channel_member(channel_id, user_id, &*tx)
 13                .await?;
 14            channel_chat_participant::ActiveModel {
 15                id: ActiveValue::NotSet,
 16                channel_id: ActiveValue::Set(channel_id),
 17                user_id: ActiveValue::Set(user_id),
 18                connection_id: ActiveValue::Set(connection_id.id as i32),
 19                connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
 20            }
 21            .insert(&*tx)
 22            .await?;
 23            Ok(())
 24        })
 25        .await
 26    }
 27
 28    pub async fn channel_chat_connection_lost(
 29        &self,
 30        connection_id: ConnectionId,
 31        tx: &DatabaseTransaction,
 32    ) -> Result<()> {
 33        channel_chat_participant::Entity::delete_many()
 34            .filter(
 35                Condition::all()
 36                    .add(
 37                        channel_chat_participant::Column::ConnectionServerId
 38                            .eq(connection_id.owner_id),
 39                    )
 40                    .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id)),
 41            )
 42            .exec(tx)
 43            .await?;
 44        Ok(())
 45    }
 46
 47    pub async fn leave_channel_chat(
 48        &self,
 49        channel_id: ChannelId,
 50        connection_id: ConnectionId,
 51        _user_id: UserId,
 52    ) -> Result<()> {
 53        self.transaction(|tx| async move {
 54            channel_chat_participant::Entity::delete_many()
 55                .filter(
 56                    Condition::all()
 57                        .add(
 58                            channel_chat_participant::Column::ConnectionServerId
 59                                .eq(connection_id.owner_id),
 60                        )
 61                        .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id))
 62                        .add(channel_chat_participant::Column::ChannelId.eq(channel_id)),
 63                )
 64                .exec(&*tx)
 65                .await?;
 66
 67            Ok(())
 68        })
 69        .await
 70    }
 71
 72    pub async fn get_channel_messages(
 73        &self,
 74        channel_id: ChannelId,
 75        user_id: UserId,
 76        count: usize,
 77        before_message_id: Option<MessageId>,
 78    ) -> Result<Vec<proto::ChannelMessage>> {
 79        self.transaction(|tx| async move {
 80            self.check_user_is_channel_member(channel_id, user_id, &*tx)
 81                .await?;
 82
 83            let mut condition =
 84                Condition::all().add(channel_message::Column::ChannelId.eq(channel_id));
 85
 86            if let Some(before_message_id) = before_message_id {
 87                condition = condition.add(channel_message::Column::Id.lt(before_message_id));
 88            }
 89
 90            let mut rows = channel_message::Entity::find()
 91                .filter(condition)
 92                .limit(count as u64)
 93                .stream(&*tx)
 94                .await?;
 95
 96            let mut max_id = None;
 97            let mut messages = Vec::new();
 98            while let Some(row) = rows.next().await {
 99                let row = row?;
100
101                max_assign(&mut max_id, row.id);
102
103                let nonce = row.nonce.as_u64_pair();
104                messages.push(proto::ChannelMessage {
105                    id: row.id.to_proto(),
106                    sender_id: row.sender_id.to_proto(),
107                    body: row.body,
108                    timestamp: row.sent_at.assume_utc().unix_timestamp() as u64,
109                    nonce: Some(proto::Nonce {
110                        upper_half: nonce.0,
111                        lower_half: nonce.1,
112                    }),
113                });
114            }
115            drop(rows);
116
117            if let Some(max_id) = max_id {
118                let has_older_message = observed_channel_messages::Entity::find()
119                    .filter(
120                        observed_channel_messages::Column::UserId
121                            .eq(user_id)
122                            .and(observed_channel_messages::Column::ChannelId.eq(channel_id))
123                            .and(observed_channel_messages::Column::ChannelMessageId.lt(max_id)),
124                    )
125                    .one(&*tx)
126                    .await?
127                    .is_some();
128
129                if has_older_message {
130                    observed_channel_messages::Entity::update(
131                        observed_channel_messages::ActiveModel {
132                            user_id: ActiveValue::Unchanged(user_id),
133                            channel_id: ActiveValue::Unchanged(channel_id),
134                            channel_message_id: ActiveValue::Set(max_id),
135                        },
136                    )
137                    .exec(&*tx)
138                    .await?;
139                } else {
140                    observed_channel_messages::Entity::insert(
141                        observed_channel_messages::ActiveModel {
142                            user_id: ActiveValue::Set(user_id),
143                            channel_id: ActiveValue::Set(channel_id),
144                            channel_message_id: ActiveValue::Set(max_id),
145                        },
146                    )
147                    .on_conflict(
148                        OnConflict::columns([
149                            observed_channel_messages::Column::UserId,
150                            observed_channel_messages::Column::ChannelId,
151                        ])
152                        .update_columns([observed_channel_messages::Column::ChannelMessageId])
153                        .to_owned(),
154                    )
155                    .exec(&*tx)
156                    .await?;
157                }
158            }
159
160            Ok(messages)
161        })
162        .await
163    }
164
165    pub async fn create_channel_message(
166        &self,
167        channel_id: ChannelId,
168        user_id: UserId,
169        body: &str,
170        timestamp: OffsetDateTime,
171        nonce: u128,
172    ) -> Result<(MessageId, Vec<ConnectionId>, Vec<UserId>)> {
173        self.transaction(|tx| async move {
174            let mut rows = channel_chat_participant::Entity::find()
175                .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
176                .stream(&*tx)
177                .await?;
178
179            let mut is_participant = false;
180            let mut participant_connection_ids = Vec::new();
181            let mut participant_user_ids = Vec::new();
182            while let Some(row) = rows.next().await {
183                let row = row?;
184                if row.user_id == user_id {
185                    is_participant = true;
186                }
187                participant_user_ids.push(row.user_id);
188                participant_connection_ids.push(row.connection());
189            }
190            drop(rows);
191
192            if !is_participant {
193                Err(anyhow!("not a chat participant"))?;
194            }
195
196            let timestamp = timestamp.to_offset(time::UtcOffset::UTC);
197            let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time());
198
199            let message = channel_message::Entity::insert(channel_message::ActiveModel {
200                channel_id: ActiveValue::Set(channel_id),
201                sender_id: ActiveValue::Set(user_id),
202                body: ActiveValue::Set(body.to_string()),
203                sent_at: ActiveValue::Set(timestamp),
204                nonce: ActiveValue::Set(Uuid::from_u128(nonce)),
205                id: ActiveValue::NotSet,
206            })
207            .on_conflict(
208                OnConflict::column(channel_message::Column::Nonce)
209                    .update_column(channel_message::Column::Nonce)
210                    .to_owned(),
211            )
212            .exec(&*tx)
213            .await?;
214
215            #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
216            enum QueryConnectionId {
217                ConnectionId,
218            }
219
220            // Observe this message for the sender
221            self.observe_channel_message_internal(
222                channel_id,
223                user_id,
224                message.last_insert_id,
225                &*tx,
226            )
227            .await?;
228
229            let mut channel_members = self.get_channel_members_internal(channel_id, &*tx).await?;
230            channel_members.retain(|member| !participant_user_ids.contains(member));
231
232            Ok((
233                message.last_insert_id,
234                participant_connection_ids,
235                channel_members,
236            ))
237        })
238        .await
239    }
240
241    pub async fn observe_channel_message(
242        &self,
243        channel_id: ChannelId,
244        user_id: UserId,
245        message_id: MessageId,
246    ) -> Result<()> {
247        self.transaction(|tx| async move {
248            self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
249                .await?;
250            Ok(())
251        })
252        .await
253    }
254
255    async fn observe_channel_message_internal(
256        &self,
257        channel_id: ChannelId,
258        user_id: UserId,
259        message_id: MessageId,
260        tx: &DatabaseTransaction,
261    ) -> Result<()> {
262        observed_channel_messages::Entity::insert(observed_channel_messages::ActiveModel {
263            user_id: ActiveValue::Set(user_id),
264            channel_id: ActiveValue::Set(channel_id),
265            channel_message_id: ActiveValue::Set(message_id),
266        })
267        .on_conflict(
268            OnConflict::columns([
269                observed_channel_messages::Column::ChannelId,
270                observed_channel_messages::Column::UserId,
271            ])
272            .update_column(observed_channel_messages::Column::ChannelMessageId)
273            .action_cond_where(observed_channel_messages::Column::ChannelMessageId.lt(message_id))
274            .to_owned(),
275        )
276        // TODO: Try to upgrade SeaORM so we don't have to do this hack around their bug
277        .exec_without_returning(&*tx)
278        .await?;
279        Ok(())
280    }
281
282    pub async fn unseen_channel_messages(
283        &self,
284        user_id: UserId,
285        channel_ids: &[ChannelId],
286        tx: &DatabaseTransaction,
287    ) -> Result<Vec<proto::UnseenChannelMessage>> {
288        let mut observed_messages_by_channel_id = HashMap::default();
289        let mut rows = observed_channel_messages::Entity::find()
290            .filter(observed_channel_messages::Column::UserId.eq(user_id))
291            .filter(observed_channel_messages::Column::ChannelId.is_in(channel_ids.iter().copied()))
292            .stream(&*tx)
293            .await?;
294
295        while let Some(row) = rows.next().await {
296            let row = row?;
297            observed_messages_by_channel_id.insert(row.channel_id, row);
298        }
299        drop(rows);
300        let mut values = String::new();
301        for id in channel_ids {
302            if !values.is_empty() {
303                values.push_str(", ");
304            }
305            write!(&mut values, "({})", id).unwrap();
306        }
307
308        if values.is_empty() {
309            return Ok(Default::default());
310        }
311
312        let sql = format!(
313            r#"
314            SELECT
315                *
316            FROM (
317                SELECT
318                    *,
319                    row_number() OVER (
320                        PARTITION BY channel_id
321                        ORDER BY id DESC
322                    ) as row_number
323                FROM channel_messages
324                WHERE
325                    channel_id in ({values})
326            ) AS messages
327            WHERE
328                row_number = 1
329            "#,
330        );
331
332        let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
333        let last_messages = channel_message::Model::find_by_statement(stmt)
334            .all(&*tx)
335            .await?;
336
337        let mut changes = Vec::new();
338        for last_message in last_messages {
339            if let Some(observed_message) =
340                observed_messages_by_channel_id.get(&last_message.channel_id)
341            {
342                if observed_message.channel_message_id == last_message.id {
343                    continue;
344                }
345            }
346            changes.push(proto::UnseenChannelMessage {
347                channel_id: last_message.channel_id.to_proto(),
348                message_id: last_message.id.to_proto(),
349            });
350        }
351
352        Ok(changes)
353    }
354
355    pub async fn remove_channel_message(
356        &self,
357        channel_id: ChannelId,
358        message_id: MessageId,
359        user_id: UserId,
360    ) -> Result<Vec<ConnectionId>> {
361        self.transaction(|tx| async move {
362            let mut rows = channel_chat_participant::Entity::find()
363                .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
364                .stream(&*tx)
365                .await?;
366
367            let mut is_participant = false;
368            let mut participant_connection_ids = Vec::new();
369            while let Some(row) = rows.next().await {
370                let row = row?;
371                if row.user_id == user_id {
372                    is_participant = true;
373                }
374                participant_connection_ids.push(row.connection());
375            }
376            drop(rows);
377
378            if !is_participant {
379                Err(anyhow!("not a chat participant"))?;
380            }
381
382            let result = channel_message::Entity::delete_by_id(message_id)
383                .filter(channel_message::Column::SenderId.eq(user_id))
384                .exec(&*tx)
385                .await?;
386            if result.rows_affected == 0 {
387                Err(anyhow!("no such message"))?;
388            }
389
390            Ok(participant_connection_ids)
391        })
392        .await
393    }
394}