messages.rs

  1use super::*;
  2use time::OffsetDateTime;
  3
  4impl Database {
  5    pub async fn join_channel_chat(
  6        &self,
  7        channel_id: ChannelId,
  8        connection_id: ConnectionId,
  9        user_id: UserId,
 10    ) -> Result<()> {
 11        self.transaction(|tx| async move {
 12            self.check_user_is_channel_member(channel_id, user_id, &*tx)
 13                .await?;
 14            channel_chat_participant::ActiveModel {
 15                id: ActiveValue::NotSet,
 16                channel_id: ActiveValue::Set(channel_id),
 17                user_id: ActiveValue::Set(user_id),
 18                connection_id: ActiveValue::Set(connection_id.id as i32),
 19                connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
 20            }
 21            .insert(&*tx)
 22            .await?;
 23            Ok(())
 24        })
 25        .await
 26    }
 27
 28    pub async fn channel_chat_connection_lost(
 29        &self,
 30        connection_id: ConnectionId,
 31        tx: &DatabaseTransaction,
 32    ) -> Result<()> {
 33        channel_chat_participant::Entity::delete_many()
 34            .filter(
 35                Condition::all()
 36                    .add(
 37                        channel_chat_participant::Column::ConnectionServerId
 38                            .eq(connection_id.owner_id),
 39                    )
 40                    .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id)),
 41            )
 42            .exec(tx)
 43            .await?;
 44        Ok(())
 45    }
 46
 47    pub async fn leave_channel_chat(
 48        &self,
 49        channel_id: ChannelId,
 50        connection_id: ConnectionId,
 51        _user_id: UserId,
 52    ) -> Result<()> {
 53        self.transaction(|tx| async move {
 54            channel_chat_participant::Entity::delete_many()
 55                .filter(
 56                    Condition::all()
 57                        .add(
 58                            channel_chat_participant::Column::ConnectionServerId
 59                                .eq(connection_id.owner_id),
 60                        )
 61                        .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id))
 62                        .add(channel_chat_participant::Column::ChannelId.eq(channel_id)),
 63                )
 64                .exec(&*tx)
 65                .await?;
 66
 67            Ok(())
 68        })
 69        .await
 70    }
 71
 72    pub async fn get_channel_messages(
 73        &self,
 74        channel_id: ChannelId,
 75        user_id: UserId,
 76        count: usize,
 77        before_message_id: Option<MessageId>,
 78    ) -> Result<Vec<proto::ChannelMessage>> {
 79        self.transaction(|tx| async move {
 80            self.check_user_is_channel_member(channel_id, user_id, &*tx)
 81                .await?;
 82
 83            let mut condition =
 84                Condition::all().add(channel_message::Column::ChannelId.eq(channel_id));
 85
 86            if let Some(before_message_id) = before_message_id {
 87                condition = condition.add(channel_message::Column::Id.lt(before_message_id));
 88            }
 89
 90            let mut rows = channel_message::Entity::find()
 91                .filter(condition)
 92                .order_by_desc(channel_message::Column::Id)
 93                .limit(count as u64)
 94                .stream(&*tx)
 95                .await?;
 96
 97            let mut messages = Vec::new();
 98            while let Some(row) = rows.next().await {
 99                let row = row?;
100                let nonce = row.nonce.as_u64_pair();
101                messages.push(proto::ChannelMessage {
102                    id: row.id.to_proto(),
103                    sender_id: row.sender_id.to_proto(),
104                    body: row.body,
105                    timestamp: row.sent_at.assume_utc().unix_timestamp() as u64,
106                    nonce: Some(proto::Nonce {
107                        upper_half: nonce.0,
108                        lower_half: nonce.1,
109                    }),
110                });
111            }
112            drop(rows);
113            messages.reverse();
114            Ok(messages)
115        })
116        .await
117    }
118
119    pub async fn create_channel_message(
120        &self,
121        channel_id: ChannelId,
122        user_id: UserId,
123        body: &str,
124        timestamp: OffsetDateTime,
125        nonce: u128,
126    ) -> Result<(MessageId, Vec<ConnectionId>, Vec<UserId>)> {
127        self.transaction(|tx| async move {
128            let mut rows = channel_chat_participant::Entity::find()
129                .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
130                .stream(&*tx)
131                .await?;
132
133            let mut is_participant = false;
134            let mut participant_connection_ids = Vec::new();
135            let mut participant_user_ids = Vec::new();
136            while let Some(row) = rows.next().await {
137                let row = row?;
138                if row.user_id == user_id {
139                    is_participant = true;
140                }
141                participant_user_ids.push(row.user_id);
142                participant_connection_ids.push(row.connection());
143            }
144            drop(rows);
145
146            if !is_participant {
147                Err(anyhow!("not a chat participant"))?;
148            }
149
150            let timestamp = timestamp.to_offset(time::UtcOffset::UTC);
151            let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time());
152
153            let message = channel_message::Entity::insert(channel_message::ActiveModel {
154                channel_id: ActiveValue::Set(channel_id),
155                sender_id: ActiveValue::Set(user_id),
156                body: ActiveValue::Set(body.to_string()),
157                sent_at: ActiveValue::Set(timestamp),
158                nonce: ActiveValue::Set(Uuid::from_u128(nonce)),
159                id: ActiveValue::NotSet,
160            })
161            .on_conflict(
162                OnConflict::column(channel_message::Column::Nonce)
163                    .update_column(channel_message::Column::Nonce)
164                    .to_owned(),
165            )
166            .exec(&*tx)
167            .await?;
168
169            #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
170            enum QueryConnectionId {
171                ConnectionId,
172            }
173
174            // Observe this message for the sender
175            self.observe_channel_message_internal(
176                channel_id,
177                user_id,
178                message.last_insert_id,
179                &*tx,
180            )
181            .await?;
182
183            let mut channel_members = self
184                .get_channel_participants_internal(channel_id, &*tx)
185                .await?;
186            channel_members.retain(|member| !participant_user_ids.contains(member));
187
188            Ok((
189                message.last_insert_id,
190                participant_connection_ids,
191                channel_members,
192            ))
193        })
194        .await
195    }
196
197    pub async fn observe_channel_message(
198        &self,
199        channel_id: ChannelId,
200        user_id: UserId,
201        message_id: MessageId,
202    ) -> Result<()> {
203        self.transaction(|tx| async move {
204            self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
205                .await?;
206            Ok(())
207        })
208        .await
209    }
210
211    async fn observe_channel_message_internal(
212        &self,
213        channel_id: ChannelId,
214        user_id: UserId,
215        message_id: MessageId,
216        tx: &DatabaseTransaction,
217    ) -> Result<()> {
218        observed_channel_messages::Entity::insert(observed_channel_messages::ActiveModel {
219            user_id: ActiveValue::Set(user_id),
220            channel_id: ActiveValue::Set(channel_id),
221            channel_message_id: ActiveValue::Set(message_id),
222        })
223        .on_conflict(
224            OnConflict::columns([
225                observed_channel_messages::Column::ChannelId,
226                observed_channel_messages::Column::UserId,
227            ])
228            .update_column(observed_channel_messages::Column::ChannelMessageId)
229            .action_cond_where(observed_channel_messages::Column::ChannelMessageId.lt(message_id))
230            .to_owned(),
231        )
232        // TODO: Try to upgrade SeaORM so we don't have to do this hack around their bug
233        .exec_without_returning(&*tx)
234        .await?;
235        Ok(())
236    }
237
238    pub async fn unseen_channel_messages(
239        &self,
240        user_id: UserId,
241        channel_ids: &[ChannelId],
242        tx: &DatabaseTransaction,
243    ) -> Result<Vec<proto::UnseenChannelMessage>> {
244        let mut observed_messages_by_channel_id = HashMap::default();
245        let mut rows = observed_channel_messages::Entity::find()
246            .filter(observed_channel_messages::Column::UserId.eq(user_id))
247            .filter(observed_channel_messages::Column::ChannelId.is_in(channel_ids.iter().copied()))
248            .stream(&*tx)
249            .await?;
250
251        while let Some(row) = rows.next().await {
252            let row = row?;
253            observed_messages_by_channel_id.insert(row.channel_id, row);
254        }
255        drop(rows);
256        let mut values = String::new();
257        for id in channel_ids {
258            if !values.is_empty() {
259                values.push_str(", ");
260            }
261            write!(&mut values, "({})", id).unwrap();
262        }
263
264        if values.is_empty() {
265            return Ok(Default::default());
266        }
267
268        let sql = format!(
269            r#"
270            SELECT
271                *
272            FROM (
273                SELECT
274                    *,
275                    row_number() OVER (
276                        PARTITION BY channel_id
277                        ORDER BY id DESC
278                    ) as row_number
279                FROM channel_messages
280                WHERE
281                    channel_id in ({values})
282            ) AS messages
283            WHERE
284                row_number = 1
285            "#,
286        );
287
288        let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
289        let last_messages = channel_message::Model::find_by_statement(stmt)
290            .all(&*tx)
291            .await?;
292
293        let mut changes = Vec::new();
294        for last_message in last_messages {
295            if let Some(observed_message) =
296                observed_messages_by_channel_id.get(&last_message.channel_id)
297            {
298                if observed_message.channel_message_id == last_message.id {
299                    continue;
300                }
301            }
302            changes.push(proto::UnseenChannelMessage {
303                channel_id: last_message.channel_id.to_proto(),
304                message_id: last_message.id.to_proto(),
305            });
306        }
307
308        Ok(changes)
309    }
310
311    pub async fn remove_channel_message(
312        &self,
313        channel_id: ChannelId,
314        message_id: MessageId,
315        user_id: UserId,
316    ) -> Result<Vec<ConnectionId>> {
317        self.transaction(|tx| async move {
318            let mut rows = channel_chat_participant::Entity::find()
319                .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
320                .stream(&*tx)
321                .await?;
322
323            let mut is_participant = false;
324            let mut participant_connection_ids = Vec::new();
325            while let Some(row) = rows.next().await {
326                let row = row?;
327                if row.user_id == user_id {
328                    is_participant = true;
329                }
330                participant_connection_ids.push(row.connection());
331            }
332            drop(rows);
333
334            if !is_participant {
335                Err(anyhow!("not a chat participant"))?;
336            }
337
338            let result = channel_message::Entity::delete_by_id(message_id)
339                .filter(channel_message::Column::SenderId.eq(user_id))
340                .exec(&*tx)
341                .await?;
342
343            if result.rows_affected == 0 {
344                if self
345                    .check_user_is_channel_admin(channel_id, user_id, &*tx)
346                    .await
347                    .is_ok()
348                {
349                    let result = channel_message::Entity::delete_by_id(message_id)
350                        .exec(&*tx)
351                        .await?;
352                    if result.rows_affected == 0 {
353                        Err(anyhow!("no such message"))?;
354                    }
355                } else {
356                    Err(anyhow!("operation could not be completed"))?;
357                }
358            }
359
360            Ok(participant_connection_ids)
361        })
362        .await
363    }
364}