1use super::*;
2use rpc::Notification;
3use sea_orm::TryInsertResult;
4use time::OffsetDateTime;
5
6impl Database {
7 /// Inserts a record representing a user joining the chat for a given channel.
8 pub async fn join_channel_chat(
9 &self,
10 channel_id: ChannelId,
11 connection_id: ConnectionId,
12 user_id: UserId,
13 ) -> Result<()> {
14 self.transaction(|tx| async move {
15 let channel = self.get_channel_internal(channel_id, &*tx).await?;
16 self.check_user_is_channel_participant(&channel, user_id, &*tx)
17 .await?;
18 channel_chat_participant::ActiveModel {
19 id: ActiveValue::NotSet,
20 channel_id: ActiveValue::Set(channel_id),
21 user_id: ActiveValue::Set(user_id),
22 connection_id: ActiveValue::Set(connection_id.id as i32),
23 connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
24 }
25 .insert(&*tx)
26 .await?;
27 Ok(())
28 })
29 .await
30 }
31
32 /// Removes `channel_chat_participant` records associated with the given connection ID.
33 pub async fn channel_chat_connection_lost(
34 &self,
35 connection_id: ConnectionId,
36 tx: &DatabaseTransaction,
37 ) -> Result<()> {
38 channel_chat_participant::Entity::delete_many()
39 .filter(
40 Condition::all()
41 .add(
42 channel_chat_participant::Column::ConnectionServerId
43 .eq(connection_id.owner_id),
44 )
45 .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id)),
46 )
47 .exec(tx)
48 .await?;
49 Ok(())
50 }
51
52 /// Removes `channel_chat_participant` records associated with the given user ID so they
53 /// will no longer get chat notifications.
54 pub async fn leave_channel_chat(
55 &self,
56 channel_id: ChannelId,
57 connection_id: ConnectionId,
58 _user_id: UserId,
59 ) -> Result<()> {
60 self.transaction(|tx| async move {
61 channel_chat_participant::Entity::delete_many()
62 .filter(
63 Condition::all()
64 .add(
65 channel_chat_participant::Column::ConnectionServerId
66 .eq(connection_id.owner_id),
67 )
68 .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id))
69 .add(channel_chat_participant::Column::ChannelId.eq(channel_id)),
70 )
71 .exec(&*tx)
72 .await?;
73
74 Ok(())
75 })
76 .await
77 }
78
79 /// Retrieves the messages in the specified channel.
80 ///
81 /// Use `before_message_id` to paginate through the channel's messages.
82 pub async fn get_channel_messages(
83 &self,
84 channel_id: ChannelId,
85 user_id: UserId,
86 count: usize,
87 before_message_id: Option<MessageId>,
88 ) -> Result<Vec<proto::ChannelMessage>> {
89 self.transaction(|tx| async move {
90 let channel = self.get_channel_internal(channel_id, &*tx).await?;
91 self.check_user_is_channel_participant(&channel, user_id, &*tx)
92 .await?;
93
94 let mut condition =
95 Condition::all().add(channel_message::Column::ChannelId.eq(channel_id));
96
97 if let Some(before_message_id) = before_message_id {
98 condition = condition.add(channel_message::Column::Id.lt(before_message_id));
99 }
100
101 let rows = channel_message::Entity::find()
102 .filter(condition)
103 .order_by_desc(channel_message::Column::Id)
104 .limit(count as u64)
105 .all(&*tx)
106 .await?;
107
108 self.load_channel_messages(rows, &*tx).await
109 })
110 .await
111 }
112
113 /// Returns the channel messages with the given IDs.
114 pub async fn get_channel_messages_by_id(
115 &self,
116 user_id: UserId,
117 message_ids: &[MessageId],
118 ) -> Result<Vec<proto::ChannelMessage>> {
119 self.transaction(|tx| async move {
120 let rows = channel_message::Entity::find()
121 .filter(channel_message::Column::Id.is_in(message_ids.iter().copied()))
122 .order_by_desc(channel_message::Column::Id)
123 .all(&*tx)
124 .await?;
125
126 let mut channels = HashMap::<ChannelId, channel::Model>::default();
127 for row in &rows {
128 channels.insert(
129 row.channel_id,
130 self.get_channel_internal(row.channel_id, &*tx).await?,
131 );
132 }
133
134 for (_, channel) in channels {
135 self.check_user_is_channel_participant(&channel, user_id, &*tx)
136 .await?;
137 }
138
139 let messages = self.load_channel_messages(rows, &*tx).await?;
140 Ok(messages)
141 })
142 .await
143 }
144
145 async fn load_channel_messages(
146 &self,
147 rows: Vec<channel_message::Model>,
148 tx: &DatabaseTransaction,
149 ) -> Result<Vec<proto::ChannelMessage>> {
150 let mut messages = rows
151 .into_iter()
152 .map(|row| {
153 let nonce = row.nonce.as_u64_pair();
154 proto::ChannelMessage {
155 id: row.id.to_proto(),
156 sender_id: row.sender_id.to_proto(),
157 body: row.body,
158 timestamp: row.sent_at.assume_utc().unix_timestamp() as u64,
159 mentions: vec![],
160 nonce: Some(proto::Nonce {
161 upper_half: nonce.0,
162 lower_half: nonce.1,
163 }),
164 }
165 })
166 .collect::<Vec<_>>();
167 messages.reverse();
168
169 let mut mentions = channel_message_mention::Entity::find()
170 .filter(channel_message_mention::Column::MessageId.is_in(messages.iter().map(|m| m.id)))
171 .order_by_asc(channel_message_mention::Column::MessageId)
172 .order_by_asc(channel_message_mention::Column::StartOffset)
173 .stream(&*tx)
174 .await?;
175
176 let mut message_ix = 0;
177 while let Some(mention) = mentions.next().await {
178 let mention = mention?;
179 let message_id = mention.message_id.to_proto();
180 while let Some(message) = messages.get_mut(message_ix) {
181 if message.id < message_id {
182 message_ix += 1;
183 } else {
184 if message.id == message_id {
185 message.mentions.push(proto::ChatMention {
186 range: Some(proto::Range {
187 start: mention.start_offset as u64,
188 end: mention.end_offset as u64,
189 }),
190 user_id: mention.user_id.to_proto(),
191 });
192 }
193 break;
194 }
195 }
196 }
197
198 Ok(messages)
199 }
200
201 /// Creates a new channel message.
202 pub async fn create_channel_message(
203 &self,
204 channel_id: ChannelId,
205 user_id: UserId,
206 body: &str,
207 mentions: &[proto::ChatMention],
208 timestamp: OffsetDateTime,
209 nonce: u128,
210 ) -> Result<CreatedChannelMessage> {
211 self.transaction(|tx| async move {
212 let channel = self.get_channel_internal(channel_id, &*tx).await?;
213 self.check_user_is_channel_participant(&channel, user_id, &*tx)
214 .await?;
215
216 let mut rows = channel_chat_participant::Entity::find()
217 .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
218 .stream(&*tx)
219 .await?;
220
221 let mut is_participant = false;
222 let mut participant_connection_ids = Vec::new();
223 let mut participant_user_ids = Vec::new();
224 while let Some(row) = rows.next().await {
225 let row = row?;
226 if row.user_id == user_id {
227 is_participant = true;
228 }
229 participant_user_ids.push(row.user_id);
230 participant_connection_ids.push(row.connection());
231 }
232 drop(rows);
233
234 if !is_participant {
235 Err(anyhow!("not a chat participant"))?;
236 }
237
238 let timestamp = timestamp.to_offset(time::UtcOffset::UTC);
239 let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time());
240
241 let result = channel_message::Entity::insert(channel_message::ActiveModel {
242 channel_id: ActiveValue::Set(channel_id),
243 sender_id: ActiveValue::Set(user_id),
244 body: ActiveValue::Set(body.to_string()),
245 sent_at: ActiveValue::Set(timestamp),
246 nonce: ActiveValue::Set(Uuid::from_u128(nonce)),
247 id: ActiveValue::NotSet,
248 })
249 .on_conflict(
250 OnConflict::columns([
251 channel_message::Column::SenderId,
252 channel_message::Column::Nonce,
253 ])
254 .do_nothing()
255 .to_owned(),
256 )
257 .do_nothing()
258 .exec(&*tx)
259 .await?;
260
261 let message_id;
262 let mut notifications = Vec::new();
263 match result {
264 TryInsertResult::Inserted(result) => {
265 message_id = result.last_insert_id;
266 let mentioned_user_ids =
267 mentions.iter().map(|m| m.user_id).collect::<HashSet<_>>();
268
269 let mentions = mentions
270 .iter()
271 .filter_map(|mention| {
272 let range = mention.range.as_ref()?;
273 if !body.is_char_boundary(range.start as usize)
274 || !body.is_char_boundary(range.end as usize)
275 {
276 return None;
277 }
278 Some(channel_message_mention::ActiveModel {
279 message_id: ActiveValue::Set(message_id),
280 start_offset: ActiveValue::Set(range.start as i32),
281 end_offset: ActiveValue::Set(range.end as i32),
282 user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)),
283 })
284 })
285 .collect::<Vec<_>>();
286 if !mentions.is_empty() {
287 channel_message_mention::Entity::insert_many(mentions)
288 .exec(&*tx)
289 .await?;
290 }
291
292 for mentioned_user in mentioned_user_ids {
293 notifications.extend(
294 self.create_notification(
295 UserId::from_proto(mentioned_user),
296 rpc::Notification::ChannelMessageMention {
297 message_id: message_id.to_proto(),
298 sender_id: user_id.to_proto(),
299 channel_id: channel_id.to_proto(),
300 },
301 false,
302 &*tx,
303 )
304 .await?,
305 );
306 }
307
308 self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
309 .await?;
310 }
311 _ => {
312 message_id = channel_message::Entity::find()
313 .filter(channel_message::Column::Nonce.eq(Uuid::from_u128(nonce)))
314 .one(&*tx)
315 .await?
316 .ok_or_else(|| anyhow!("failed to insert message"))?
317 .id;
318 }
319 }
320
321 let mut channel_members = self.get_channel_participants(&channel, &*tx).await?;
322 channel_members.retain(|member| !participant_user_ids.contains(member));
323
324 Ok(CreatedChannelMessage {
325 message_id,
326 participant_connection_ids,
327 channel_members,
328 notifications,
329 })
330 })
331 .await
332 }
333
334 pub async fn observe_channel_message(
335 &self,
336 channel_id: ChannelId,
337 user_id: UserId,
338 message_id: MessageId,
339 ) -> Result<NotificationBatch> {
340 self.transaction(|tx| async move {
341 self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
342 .await?;
343 let mut batch = NotificationBatch::default();
344 batch.extend(
345 self.mark_notification_as_read(
346 user_id,
347 &Notification::ChannelMessageMention {
348 message_id: message_id.to_proto(),
349 sender_id: Default::default(),
350 channel_id: Default::default(),
351 },
352 &*tx,
353 )
354 .await?,
355 );
356 Ok(batch)
357 })
358 .await
359 }
360
361 async fn observe_channel_message_internal(
362 &self,
363 channel_id: ChannelId,
364 user_id: UserId,
365 message_id: MessageId,
366 tx: &DatabaseTransaction,
367 ) -> Result<()> {
368 observed_channel_messages::Entity::insert(observed_channel_messages::ActiveModel {
369 user_id: ActiveValue::Set(user_id),
370 channel_id: ActiveValue::Set(channel_id),
371 channel_message_id: ActiveValue::Set(message_id),
372 })
373 .on_conflict(
374 OnConflict::columns([
375 observed_channel_messages::Column::ChannelId,
376 observed_channel_messages::Column::UserId,
377 ])
378 .update_column(observed_channel_messages::Column::ChannelMessageId)
379 .action_cond_where(observed_channel_messages::Column::ChannelMessageId.lt(message_id))
380 .to_owned(),
381 )
382 // TODO: Try to upgrade SeaORM so we don't have to do this hack around their bug
383 .exec_without_returning(&*tx)
384 .await?;
385 Ok(())
386 }
387
388 /// Returns the unseen messages for the given user in the specified channels.
389 pub async fn unseen_channel_messages(
390 &self,
391 user_id: UserId,
392 channel_ids: &[ChannelId],
393 tx: &DatabaseTransaction,
394 ) -> Result<Vec<proto::UnseenChannelMessage>> {
395 let mut observed_messages_by_channel_id = HashMap::default();
396 let mut rows = observed_channel_messages::Entity::find()
397 .filter(observed_channel_messages::Column::UserId.eq(user_id))
398 .filter(observed_channel_messages::Column::ChannelId.is_in(channel_ids.iter().copied()))
399 .stream(&*tx)
400 .await?;
401
402 while let Some(row) = rows.next().await {
403 let row = row?;
404 observed_messages_by_channel_id.insert(row.channel_id, row);
405 }
406 drop(rows);
407 let mut values = String::new();
408 for id in channel_ids {
409 if !values.is_empty() {
410 values.push_str(", ");
411 }
412 write!(&mut values, "({})", id).unwrap();
413 }
414
415 if values.is_empty() {
416 return Ok(Default::default());
417 }
418
419 let sql = format!(
420 r#"
421 SELECT
422 *
423 FROM (
424 SELECT
425 *,
426 row_number() OVER (
427 PARTITION BY channel_id
428 ORDER BY id DESC
429 ) as row_number
430 FROM channel_messages
431 WHERE
432 channel_id in ({values})
433 ) AS messages
434 WHERE
435 row_number = 1
436 "#,
437 );
438
439 let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
440 let last_messages = channel_message::Model::find_by_statement(stmt)
441 .all(&*tx)
442 .await?;
443
444 let mut changes = Vec::new();
445 for last_message in last_messages {
446 if let Some(observed_message) =
447 observed_messages_by_channel_id.get(&last_message.channel_id)
448 {
449 if observed_message.channel_message_id == last_message.id {
450 continue;
451 }
452 }
453 changes.push(proto::UnseenChannelMessage {
454 channel_id: last_message.channel_id.to_proto(),
455 message_id: last_message.id.to_proto(),
456 });
457 }
458
459 Ok(changes)
460 }
461
462 /// Removes the channel message with the given ID.
463 pub async fn remove_channel_message(
464 &self,
465 channel_id: ChannelId,
466 message_id: MessageId,
467 user_id: UserId,
468 ) -> Result<Vec<ConnectionId>> {
469 self.transaction(|tx| async move {
470 let mut rows = channel_chat_participant::Entity::find()
471 .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
472 .stream(&*tx)
473 .await?;
474
475 let mut is_participant = false;
476 let mut participant_connection_ids = Vec::new();
477 while let Some(row) = rows.next().await {
478 let row = row?;
479 if row.user_id == user_id {
480 is_participant = true;
481 }
482 participant_connection_ids.push(row.connection());
483 }
484 drop(rows);
485
486 if !is_participant {
487 Err(anyhow!("not a chat participant"))?;
488 }
489
490 let result = channel_message::Entity::delete_by_id(message_id)
491 .filter(channel_message::Column::SenderId.eq(user_id))
492 .exec(&*tx)
493 .await?;
494
495 if result.rows_affected == 0 {
496 let channel = self.get_channel_internal(channel_id, &*tx).await?;
497 if self
498 .check_user_is_channel_admin(&channel, user_id, &*tx)
499 .await
500 .is_ok()
501 {
502 let result = channel_message::Entity::delete_by_id(message_id)
503 .exec(&*tx)
504 .await?;
505 if result.rows_affected == 0 {
506 Err(anyhow!("no such message"))?;
507 }
508 } else {
509 Err(anyhow!("operation could not be completed"))?;
510 }
511 }
512
513 Ok(participant_connection_ids)
514 })
515 .await
516 }
517}