1use super::*;
2use rpc::Notification;
3use sea_orm::TryInsertResult;
4use time::OffsetDateTime;
5
6impl Database {
7 /// Inserts a record representing a user joining the chat for a given channel.
8 pub async fn join_channel_chat(
9 &self,
10 channel_id: ChannelId,
11 connection_id: ConnectionId,
12 user_id: UserId,
13 ) -> Result<()> {
14 self.transaction(|tx| async move {
15 let channel = self.get_channel_internal(channel_id, &*tx).await?;
16 self.check_user_is_channel_participant(&channel, user_id, &*tx)
17 .await?;
18 channel_chat_participant::ActiveModel {
19 id: ActiveValue::NotSet,
20 channel_id: ActiveValue::Set(channel_id),
21 user_id: ActiveValue::Set(user_id),
22 connection_id: ActiveValue::Set(connection_id.id as i32),
23 connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
24 }
25 .insert(&*tx)
26 .await?;
27 Ok(())
28 })
29 .await
30 }
31
32 /// Removes `channel_chat_participant` records associated with the given connection ID.
33 pub async fn channel_chat_connection_lost(
34 &self,
35 connection_id: ConnectionId,
36 tx: &DatabaseTransaction,
37 ) -> Result<()> {
38 channel_chat_participant::Entity::delete_many()
39 .filter(
40 Condition::all()
41 .add(
42 channel_chat_participant::Column::ConnectionServerId
43 .eq(connection_id.owner_id),
44 )
45 .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id)),
46 )
47 .exec(tx)
48 .await?;
49 Ok(())
50 }
51
52 /// Removes `channel_chat_participant` records associated with the given user ID so they
53 /// will no longer get chat notifications.
54 pub async fn leave_channel_chat(
55 &self,
56 channel_id: ChannelId,
57 connection_id: ConnectionId,
58 _user_id: UserId,
59 ) -> Result<()> {
60 self.transaction(|tx| async move {
61 channel_chat_participant::Entity::delete_many()
62 .filter(
63 Condition::all()
64 .add(
65 channel_chat_participant::Column::ConnectionServerId
66 .eq(connection_id.owner_id),
67 )
68 .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id))
69 .add(channel_chat_participant::Column::ChannelId.eq(channel_id)),
70 )
71 .exec(&*tx)
72 .await?;
73
74 Ok(())
75 })
76 .await
77 }
78
79 /// Retrieves the messages in the specified channel.
80 ///
81 /// Use `before_message_id` to paginate through the channel's messages.
82 pub async fn get_channel_messages(
83 &self,
84 channel_id: ChannelId,
85 user_id: UserId,
86 count: usize,
87 before_message_id: Option<MessageId>,
88 ) -> Result<Vec<proto::ChannelMessage>> {
89 self.transaction(|tx| async move {
90 let channel = self.get_channel_internal(channel_id, &*tx).await?;
91 self.check_user_is_channel_participant(&channel, user_id, &*tx)
92 .await?;
93
94 let mut condition =
95 Condition::all().add(channel_message::Column::ChannelId.eq(channel_id));
96
97 if let Some(before_message_id) = before_message_id {
98 condition = condition.add(channel_message::Column::Id.lt(before_message_id));
99 }
100
101 let rows = channel_message::Entity::find()
102 .filter(condition)
103 .order_by_desc(channel_message::Column::Id)
104 .limit(count as u64)
105 .all(&*tx)
106 .await?;
107
108 self.load_channel_messages(rows, &*tx).await
109 })
110 .await
111 }
112
113 /// Returns the channel messages with the given IDs.
114 pub async fn get_channel_messages_by_id(
115 &self,
116 user_id: UserId,
117 message_ids: &[MessageId],
118 ) -> Result<Vec<proto::ChannelMessage>> {
119 self.transaction(|tx| async move {
120 let rows = channel_message::Entity::find()
121 .filter(channel_message::Column::Id.is_in(message_ids.iter().copied()))
122 .order_by_desc(channel_message::Column::Id)
123 .all(&*tx)
124 .await?;
125
126 let mut channels = HashMap::<ChannelId, channel::Model>::default();
127 for row in &rows {
128 channels.insert(
129 row.channel_id,
130 self.get_channel_internal(row.channel_id, &*tx).await?,
131 );
132 }
133
134 for (_, channel) in channels {
135 self.check_user_is_channel_participant(&channel, user_id, &*tx)
136 .await?;
137 }
138
139 let messages = self.load_channel_messages(rows, &*tx).await?;
140 Ok(messages)
141 })
142 .await
143 }
144
145 async fn load_channel_messages(
146 &self,
147 rows: Vec<channel_message::Model>,
148 tx: &DatabaseTransaction,
149 ) -> Result<Vec<proto::ChannelMessage>> {
150 let mut messages = rows
151 .into_iter()
152 .map(|row| {
153 let nonce = row.nonce.as_u64_pair();
154 proto::ChannelMessage {
155 id: row.id.to_proto(),
156 sender_id: row.sender_id.to_proto(),
157 body: row.body,
158 timestamp: row.sent_at.assume_utc().unix_timestamp() as u64,
159 mentions: vec![],
160 nonce: Some(proto::Nonce {
161 upper_half: nonce.0,
162 lower_half: nonce.1,
163 }),
164 reply_to_message_id: row.reply_to_message_id.map(|id| id.to_proto()),
165 }
166 })
167 .collect::<Vec<_>>();
168 messages.reverse();
169
170 let mut mentions = channel_message_mention::Entity::find()
171 .filter(channel_message_mention::Column::MessageId.is_in(messages.iter().map(|m| m.id)))
172 .order_by_asc(channel_message_mention::Column::MessageId)
173 .order_by_asc(channel_message_mention::Column::StartOffset)
174 .stream(&*tx)
175 .await?;
176
177 let mut message_ix = 0;
178 while let Some(mention) = mentions.next().await {
179 let mention = mention?;
180 let message_id = mention.message_id.to_proto();
181 while let Some(message) = messages.get_mut(message_ix) {
182 if message.id < message_id {
183 message_ix += 1;
184 } else {
185 if message.id == message_id {
186 message.mentions.push(proto::ChatMention {
187 range: Some(proto::Range {
188 start: mention.start_offset as u64,
189 end: mention.end_offset as u64,
190 }),
191 user_id: mention.user_id.to_proto(),
192 });
193 }
194 break;
195 }
196 }
197 }
198
199 Ok(messages)
200 }
201
202 /// Creates a new channel message.
203 pub async fn create_channel_message(
204 &self,
205 channel_id: ChannelId,
206 user_id: UserId,
207 body: &str,
208 mentions: &[proto::ChatMention],
209 timestamp: OffsetDateTime,
210 nonce: u128,
211 reply_to_message_id: Option<MessageId>,
212 ) -> Result<CreatedChannelMessage> {
213 self.transaction(|tx| async move {
214 let channel = self.get_channel_internal(channel_id, &*tx).await?;
215 self.check_user_is_channel_participant(&channel, user_id, &*tx)
216 .await?;
217
218 let mut rows = channel_chat_participant::Entity::find()
219 .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
220 .stream(&*tx)
221 .await?;
222
223 let mut is_participant = false;
224 let mut participant_connection_ids = Vec::new();
225 let mut participant_user_ids = Vec::new();
226 while let Some(row) = rows.next().await {
227 let row = row?;
228 if row.user_id == user_id {
229 is_participant = true;
230 }
231 participant_user_ids.push(row.user_id);
232 participant_connection_ids.push(row.connection());
233 }
234 drop(rows);
235
236 if !is_participant {
237 Err(anyhow!("not a chat participant"))?;
238 }
239
240 let timestamp = timestamp.to_offset(time::UtcOffset::UTC);
241 let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time());
242
243 let result = channel_message::Entity::insert(channel_message::ActiveModel {
244 channel_id: ActiveValue::Set(channel_id),
245 sender_id: ActiveValue::Set(user_id),
246 body: ActiveValue::Set(body.to_string()),
247 sent_at: ActiveValue::Set(timestamp),
248 nonce: ActiveValue::Set(Uuid::from_u128(nonce)),
249 id: ActiveValue::NotSet,
250 reply_to_message_id: ActiveValue::Set(reply_to_message_id),
251 })
252 .on_conflict(
253 OnConflict::columns([
254 channel_message::Column::SenderId,
255 channel_message::Column::Nonce,
256 ])
257 .do_nothing()
258 .to_owned(),
259 )
260 .do_nothing()
261 .exec(&*tx)
262 .await?;
263
264 let message_id;
265 let mut notifications = Vec::new();
266 match result {
267 TryInsertResult::Inserted(result) => {
268 message_id = result.last_insert_id;
269 let mentioned_user_ids =
270 mentions.iter().map(|m| m.user_id).collect::<HashSet<_>>();
271
272 let mentions = mentions
273 .iter()
274 .filter_map(|mention| {
275 let range = mention.range.as_ref()?;
276 if !body.is_char_boundary(range.start as usize)
277 || !body.is_char_boundary(range.end as usize)
278 {
279 return None;
280 }
281 Some(channel_message_mention::ActiveModel {
282 message_id: ActiveValue::Set(message_id),
283 start_offset: ActiveValue::Set(range.start as i32),
284 end_offset: ActiveValue::Set(range.end as i32),
285 user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)),
286 })
287 })
288 .collect::<Vec<_>>();
289 if !mentions.is_empty() {
290 channel_message_mention::Entity::insert_many(mentions)
291 .exec(&*tx)
292 .await?;
293 }
294
295 for mentioned_user in mentioned_user_ids {
296 notifications.extend(
297 self.create_notification(
298 UserId::from_proto(mentioned_user),
299 rpc::Notification::ChannelMessageMention {
300 message_id: message_id.to_proto(),
301 sender_id: user_id.to_proto(),
302 channel_id: channel_id.to_proto(),
303 },
304 false,
305 &*tx,
306 )
307 .await?,
308 );
309 }
310
311 self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
312 .await?;
313 }
314 _ => {
315 message_id = channel_message::Entity::find()
316 .filter(channel_message::Column::Nonce.eq(Uuid::from_u128(nonce)))
317 .one(&*tx)
318 .await?
319 .ok_or_else(|| anyhow!("failed to insert message"))?
320 .id;
321 }
322 }
323
324 let mut channel_members = self.get_channel_participants(&channel, &*tx).await?;
325 channel_members.retain(|member| !participant_user_ids.contains(member));
326
327 Ok(CreatedChannelMessage {
328 message_id,
329 participant_connection_ids,
330 channel_members,
331 notifications,
332 })
333 })
334 .await
335 }
336
337 pub async fn observe_channel_message(
338 &self,
339 channel_id: ChannelId,
340 user_id: UserId,
341 message_id: MessageId,
342 ) -> Result<NotificationBatch> {
343 self.transaction(|tx| async move {
344 self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
345 .await?;
346 let mut batch = NotificationBatch::default();
347 batch.extend(
348 self.mark_notification_as_read(
349 user_id,
350 &Notification::ChannelMessageMention {
351 message_id: message_id.to_proto(),
352 sender_id: Default::default(),
353 channel_id: Default::default(),
354 },
355 &*tx,
356 )
357 .await?,
358 );
359 Ok(batch)
360 })
361 .await
362 }
363
364 async fn observe_channel_message_internal(
365 &self,
366 channel_id: ChannelId,
367 user_id: UserId,
368 message_id: MessageId,
369 tx: &DatabaseTransaction,
370 ) -> Result<()> {
371 observed_channel_messages::Entity::insert(observed_channel_messages::ActiveModel {
372 user_id: ActiveValue::Set(user_id),
373 channel_id: ActiveValue::Set(channel_id),
374 channel_message_id: ActiveValue::Set(message_id),
375 })
376 .on_conflict(
377 OnConflict::columns([
378 observed_channel_messages::Column::ChannelId,
379 observed_channel_messages::Column::UserId,
380 ])
381 .update_column(observed_channel_messages::Column::ChannelMessageId)
382 .action_cond_where(observed_channel_messages::Column::ChannelMessageId.lt(message_id))
383 .to_owned(),
384 )
385 // TODO: Try to upgrade SeaORM so we don't have to do this hack around their bug
386 .exec_without_returning(&*tx)
387 .await?;
388 Ok(())
389 }
390
391 pub async fn latest_channel_messages(
392 &self,
393 channel_ids: &[ChannelId],
394 tx: &DatabaseTransaction,
395 ) -> Result<Vec<proto::ChannelMessageId>> {
396 let mut values = String::new();
397 for id in channel_ids {
398 if !values.is_empty() {
399 values.push_str(", ");
400 }
401 write!(&mut values, "({})", id).unwrap();
402 }
403
404 if values.is_empty() {
405 return Ok(Vec::default());
406 }
407
408 let sql = format!(
409 r#"
410 SELECT
411 *
412 FROM (
413 SELECT
414 *,
415 row_number() OVER (
416 PARTITION BY channel_id
417 ORDER BY id DESC
418 ) as row_number
419 FROM channel_messages
420 WHERE
421 channel_id in ({values})
422 ) AS messages
423 WHERE
424 row_number = 1
425 "#,
426 );
427
428 let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
429 let mut last_messages = channel_message::Model::find_by_statement(stmt)
430 .stream(&*tx)
431 .await?;
432
433 let mut results = Vec::new();
434 while let Some(result) = last_messages.next().await {
435 let message = result?;
436 results.push(proto::ChannelMessageId {
437 channel_id: message.channel_id.to_proto(),
438 message_id: message.id.to_proto(),
439 });
440 }
441
442 Ok(results)
443 }
444
445 /// Removes the channel message with the given ID.
446 pub async fn remove_channel_message(
447 &self,
448 channel_id: ChannelId,
449 message_id: MessageId,
450 user_id: UserId,
451 ) -> Result<Vec<ConnectionId>> {
452 self.transaction(|tx| async move {
453 let mut rows = channel_chat_participant::Entity::find()
454 .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
455 .stream(&*tx)
456 .await?;
457
458 let mut is_participant = false;
459 let mut participant_connection_ids = Vec::new();
460 while let Some(row) = rows.next().await {
461 let row = row?;
462 if row.user_id == user_id {
463 is_participant = true;
464 }
465 participant_connection_ids.push(row.connection());
466 }
467 drop(rows);
468
469 if !is_participant {
470 Err(anyhow!("not a chat participant"))?;
471 }
472
473 let result = channel_message::Entity::delete_by_id(message_id)
474 .filter(channel_message::Column::SenderId.eq(user_id))
475 .exec(&*tx)
476 .await?;
477
478 if result.rows_affected == 0 {
479 let channel = self.get_channel_internal(channel_id, &*tx).await?;
480 if self
481 .check_user_is_channel_admin(&channel, user_id, &*tx)
482 .await
483 .is_ok()
484 {
485 let result = channel_message::Entity::delete_by_id(message_id)
486 .exec(&*tx)
487 .await?;
488 if result.rows_affected == 0 {
489 Err(anyhow!("no such message"))?;
490 }
491 } else {
492 Err(anyhow!("operation could not be completed"))?;
493 }
494 }
495
496 Ok(participant_connection_ids)
497 })
498 .await
499 }
500}