1use super::*;
2use rpc::Notification;
3use sea_orm::TryInsertResult;
4use time::OffsetDateTime;
5
6impl Database {
7 /// Inserts a record representing a user joining the chat for a given channel.
8 pub async fn join_channel_chat(
9 &self,
10 channel_id: ChannelId,
11 connection_id: ConnectionId,
12 user_id: UserId,
13 ) -> Result<()> {
14 self.transaction(|tx| async move {
15 let channel = self.get_channel_internal(channel_id, &*tx).await?;
16 self.check_user_is_channel_participant(&channel, user_id, &*tx)
17 .await?;
18 channel_chat_participant::ActiveModel {
19 id: ActiveValue::NotSet,
20 channel_id: ActiveValue::Set(channel_id),
21 user_id: ActiveValue::Set(user_id),
22 connection_id: ActiveValue::Set(connection_id.id as i32),
23 connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
24 }
25 .insert(&*tx)
26 .await?;
27 Ok(())
28 })
29 .await
30 }
31
32 /// Removes `channel_chat_participant` records associated with the given connection ID.
33 pub async fn channel_chat_connection_lost(
34 &self,
35 connection_id: ConnectionId,
36 tx: &DatabaseTransaction,
37 ) -> Result<()> {
38 channel_chat_participant::Entity::delete_many()
39 .filter(
40 Condition::all()
41 .add(
42 channel_chat_participant::Column::ConnectionServerId
43 .eq(connection_id.owner_id),
44 )
45 .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id)),
46 )
47 .exec(tx)
48 .await?;
49 Ok(())
50 }
51
52 /// Removes `channel_chat_participant` records associated with the given user ID so they
53 /// will no longer get chat notifications.
54 pub async fn leave_channel_chat(
55 &self,
56 channel_id: ChannelId,
57 connection_id: ConnectionId,
58 _user_id: UserId,
59 ) -> Result<()> {
60 self.transaction(|tx| async move {
61 channel_chat_participant::Entity::delete_many()
62 .filter(
63 Condition::all()
64 .add(
65 channel_chat_participant::Column::ConnectionServerId
66 .eq(connection_id.owner_id),
67 )
68 .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id))
69 .add(channel_chat_participant::Column::ChannelId.eq(channel_id)),
70 )
71 .exec(&*tx)
72 .await?;
73
74 Ok(())
75 })
76 .await
77 }
78
79 /// Retrieves the messages in the specified channel.
80 ///
81 /// Use `before_message_id` to paginate through the channel's messages.
82 pub async fn get_channel_messages(
83 &self,
84 channel_id: ChannelId,
85 user_id: UserId,
86 count: usize,
87 before_message_id: Option<MessageId>,
88 ) -> Result<Vec<proto::ChannelMessage>> {
89 self.transaction(|tx| async move {
90 let channel = self.get_channel_internal(channel_id, &*tx).await?;
91 self.check_user_is_channel_participant(&channel, user_id, &*tx)
92 .await?;
93
94 let mut condition =
95 Condition::all().add(channel_message::Column::ChannelId.eq(channel_id));
96
97 if let Some(before_message_id) = before_message_id {
98 condition = condition.add(channel_message::Column::Id.lt(before_message_id));
99 }
100
101 let rows = channel_message::Entity::find()
102 .filter(condition)
103 .order_by_desc(channel_message::Column::Id)
104 .limit(count as u64)
105 .all(&*tx)
106 .await?;
107
108 self.load_channel_messages(rows, &*tx).await
109 })
110 .await
111 }
112
113 /// Returns the channel messages with the given IDs.
114 pub async fn get_channel_messages_by_id(
115 &self,
116 user_id: UserId,
117 message_ids: &[MessageId],
118 ) -> Result<Vec<proto::ChannelMessage>> {
119 self.transaction(|tx| async move {
120 let rows = channel_message::Entity::find()
121 .filter(channel_message::Column::Id.is_in(message_ids.iter().copied()))
122 .order_by_desc(channel_message::Column::Id)
123 .all(&*tx)
124 .await?;
125
126 let mut channels = HashMap::<ChannelId, channel::Model>::default();
127 for row in &rows {
128 channels.insert(
129 row.channel_id,
130 self.get_channel_internal(row.channel_id, &*tx).await?,
131 );
132 }
133
134 for (_, channel) in channels {
135 self.check_user_is_channel_participant(&channel, user_id, &*tx)
136 .await?;
137 }
138
139 let messages = self.load_channel_messages(rows, &*tx).await?;
140 Ok(messages)
141 })
142 .await
143 }
144
145 async fn load_channel_messages(
146 &self,
147 rows: Vec<channel_message::Model>,
148 tx: &DatabaseTransaction,
149 ) -> Result<Vec<proto::ChannelMessage>> {
150 let mut messages = rows
151 .into_iter()
152 .map(|row| {
153 let nonce = row.nonce.as_u64_pair();
154 proto::ChannelMessage {
155 id: row.id.to_proto(),
156 sender_id: row.sender_id.to_proto(),
157 body: row.body,
158 timestamp: row.sent_at.assume_utc().unix_timestamp() as u64,
159 mentions: vec![],
160 nonce: Some(proto::Nonce {
161 upper_half: nonce.0,
162 lower_half: nonce.1,
163 }),
164 }
165 })
166 .collect::<Vec<_>>();
167 messages.reverse();
168
169 let mut mentions = channel_message_mention::Entity::find()
170 .filter(channel_message_mention::Column::MessageId.is_in(messages.iter().map(|m| m.id)))
171 .order_by_asc(channel_message_mention::Column::MessageId)
172 .order_by_asc(channel_message_mention::Column::StartOffset)
173 .stream(&*tx)
174 .await?;
175
176 let mut message_ix = 0;
177 while let Some(mention) = mentions.next().await {
178 let mention = mention?;
179 let message_id = mention.message_id.to_proto();
180 while let Some(message) = messages.get_mut(message_ix) {
181 if message.id < message_id {
182 message_ix += 1;
183 } else {
184 if message.id == message_id {
185 message.mentions.push(proto::ChatMention {
186 range: Some(proto::Range {
187 start: mention.start_offset as u64,
188 end: mention.end_offset as u64,
189 }),
190 user_id: mention.user_id.to_proto(),
191 });
192 }
193 break;
194 }
195 }
196 }
197
198 Ok(messages)
199 }
200
201 /// Creates a new channel message.
202 pub async fn create_channel_message(
203 &self,
204 channel_id: ChannelId,
205 user_id: UserId,
206 body: &str,
207 mentions: &[proto::ChatMention],
208 timestamp: OffsetDateTime,
209 nonce: u128,
210 ) -> Result<CreatedChannelMessage> {
211 self.transaction(|tx| async move {
212 let channel = self.get_channel_internal(channel_id, &*tx).await?;
213 self.check_user_is_channel_participant(&channel, user_id, &*tx)
214 .await?;
215
216 let mut rows = channel_chat_participant::Entity::find()
217 .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
218 .stream(&*tx)
219 .await?;
220
221 let mut is_participant = false;
222 let mut participant_connection_ids = Vec::new();
223 let mut participant_user_ids = Vec::new();
224 while let Some(row) = rows.next().await {
225 let row = row?;
226 if row.user_id == user_id {
227 is_participant = true;
228 }
229 participant_user_ids.push(row.user_id);
230 participant_connection_ids.push(row.connection());
231 }
232 drop(rows);
233
234 if !is_participant {
235 Err(anyhow!("not a chat participant"))?;
236 }
237
238 let timestamp = timestamp.to_offset(time::UtcOffset::UTC);
239 let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time());
240
241 let result = channel_message::Entity::insert(channel_message::ActiveModel {
242 channel_id: ActiveValue::Set(channel_id),
243 sender_id: ActiveValue::Set(user_id),
244 body: ActiveValue::Set(body.to_string()),
245 sent_at: ActiveValue::Set(timestamp),
246 nonce: ActiveValue::Set(Uuid::from_u128(nonce)),
247 id: ActiveValue::NotSet,
248 })
249 .on_conflict(
250 OnConflict::columns([
251 channel_message::Column::SenderId,
252 channel_message::Column::Nonce,
253 ])
254 .do_nothing()
255 .to_owned(),
256 )
257 .do_nothing()
258 .exec(&*tx)
259 .await?;
260
261 let message_id;
262 let mut notifications = Vec::new();
263 match result {
264 TryInsertResult::Inserted(result) => {
265 message_id = result.last_insert_id;
266 let mentioned_user_ids =
267 mentions.iter().map(|m| m.user_id).collect::<HashSet<_>>();
268
269 let mentions = mentions
270 .iter()
271 .filter_map(|mention| {
272 let range = mention.range.as_ref()?;
273 if !body.is_char_boundary(range.start as usize)
274 || !body.is_char_boundary(range.end as usize)
275 {
276 return None;
277 }
278 Some(channel_message_mention::ActiveModel {
279 message_id: ActiveValue::Set(message_id),
280 start_offset: ActiveValue::Set(range.start as i32),
281 end_offset: ActiveValue::Set(range.end as i32),
282 user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)),
283 })
284 })
285 .collect::<Vec<_>>();
286 if !mentions.is_empty() {
287 channel_message_mention::Entity::insert_many(mentions)
288 .exec(&*tx)
289 .await?;
290 }
291
292 for mentioned_user in mentioned_user_ids {
293 notifications.extend(
294 self.create_notification(
295 UserId::from_proto(mentioned_user),
296 rpc::Notification::ChannelMessageMention {
297 message_id: message_id.to_proto(),
298 sender_id: user_id.to_proto(),
299 channel_id: channel_id.to_proto(),
300 },
301 false,
302 &*tx,
303 )
304 .await?,
305 );
306 }
307
308 self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
309 .await?;
310 }
311 _ => {
312 message_id = channel_message::Entity::find()
313 .filter(channel_message::Column::Nonce.eq(Uuid::from_u128(nonce)))
314 .one(&*tx)
315 .await?
316 .ok_or_else(|| anyhow!("failed to insert message"))?
317 .id;
318 }
319 }
320
321 let mut channel_members = self.get_channel_participants(&channel, &*tx).await?;
322 channel_members.retain(|member| !participant_user_ids.contains(member));
323
324 Ok(CreatedChannelMessage {
325 message_id,
326 participant_connection_ids,
327 channel_members,
328 notifications,
329 })
330 })
331 .await
332 }
333
334 pub async fn observe_channel_message(
335 &self,
336 channel_id: ChannelId,
337 user_id: UserId,
338 message_id: MessageId,
339 ) -> Result<NotificationBatch> {
340 self.transaction(|tx| async move {
341 self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
342 .await?;
343 let mut batch = NotificationBatch::default();
344 batch.extend(
345 self.mark_notification_as_read(
346 user_id,
347 &Notification::ChannelMessageMention {
348 message_id: message_id.to_proto(),
349 sender_id: Default::default(),
350 channel_id: Default::default(),
351 },
352 &*tx,
353 )
354 .await?,
355 );
356 Ok(batch)
357 })
358 .await
359 }
360
361 async fn observe_channel_message_internal(
362 &self,
363 channel_id: ChannelId,
364 user_id: UserId,
365 message_id: MessageId,
366 tx: &DatabaseTransaction,
367 ) -> Result<()> {
368 observed_channel_messages::Entity::insert(observed_channel_messages::ActiveModel {
369 user_id: ActiveValue::Set(user_id),
370 channel_id: ActiveValue::Set(channel_id),
371 channel_message_id: ActiveValue::Set(message_id),
372 })
373 .on_conflict(
374 OnConflict::columns([
375 observed_channel_messages::Column::ChannelId,
376 observed_channel_messages::Column::UserId,
377 ])
378 .update_column(observed_channel_messages::Column::ChannelMessageId)
379 .action_cond_where(observed_channel_messages::Column::ChannelMessageId.lt(message_id))
380 .to_owned(),
381 )
382 // TODO: Try to upgrade SeaORM so we don't have to do this hack around their bug
383 .exec_without_returning(&*tx)
384 .await?;
385 Ok(())
386 }
387
388 pub async fn latest_channel_messages(
389 &self,
390 channel_ids: &[ChannelId],
391 tx: &DatabaseTransaction,
392 ) -> Result<Vec<proto::ChannelMessageId>> {
393 let mut values = String::new();
394 for id in channel_ids {
395 if !values.is_empty() {
396 values.push_str(", ");
397 }
398 write!(&mut values, "({})", id).unwrap();
399 }
400
401 if values.is_empty() {
402 return Ok(Vec::default());
403 }
404
405 let sql = format!(
406 r#"
407 SELECT
408 *
409 FROM (
410 SELECT
411 *,
412 row_number() OVER (
413 PARTITION BY channel_id
414 ORDER BY id DESC
415 ) as row_number
416 FROM channel_messages
417 WHERE
418 channel_id in ({values})
419 ) AS messages
420 WHERE
421 row_number = 1
422 "#,
423 );
424
425 let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
426 let mut last_messages = channel_message::Model::find_by_statement(stmt)
427 .stream(&*tx)
428 .await?;
429
430 let mut results = Vec::new();
431 while let Some(result) = last_messages.next().await {
432 let message = result?;
433 results.push(proto::ChannelMessageId {
434 channel_id: message.channel_id.to_proto(),
435 message_id: message.id.to_proto(),
436 });
437 }
438
439 Ok(results)
440 }
441
442 /// Removes the channel message with the given ID.
443 pub async fn remove_channel_message(
444 &self,
445 channel_id: ChannelId,
446 message_id: MessageId,
447 user_id: UserId,
448 ) -> Result<Vec<ConnectionId>> {
449 self.transaction(|tx| async move {
450 let mut rows = channel_chat_participant::Entity::find()
451 .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
452 .stream(&*tx)
453 .await?;
454
455 let mut is_participant = false;
456 let mut participant_connection_ids = Vec::new();
457 while let Some(row) = rows.next().await {
458 let row = row?;
459 if row.user_id == user_id {
460 is_participant = true;
461 }
462 participant_connection_ids.push(row.connection());
463 }
464 drop(rows);
465
466 if !is_participant {
467 Err(anyhow!("not a chat participant"))?;
468 }
469
470 let result = channel_message::Entity::delete_by_id(message_id)
471 .filter(channel_message::Column::SenderId.eq(user_id))
472 .exec(&*tx)
473 .await?;
474
475 if result.rows_affected == 0 {
476 let channel = self.get_channel_internal(channel_id, &*tx).await?;
477 if self
478 .check_user_is_channel_admin(&channel, user_id, &*tx)
479 .await
480 .is_ok()
481 {
482 let result = channel_message::Entity::delete_by_id(message_id)
483 .exec(&*tx)
484 .await?;
485 if result.rows_affected == 0 {
486 Err(anyhow!("no such message"))?;
487 }
488 } else {
489 Err(anyhow!("operation could not be completed"))?;
490 }
491 }
492
493 Ok(participant_connection_ids)
494 })
495 .await
496 }
497}