1use super::*;
2use rpc::Notification;
3use sea_orm::TryInsertResult;
4use time::OffsetDateTime;
5
6impl Database {
7 pub async fn join_channel_chat(
8 &self,
9 channel_id: ChannelId,
10 connection_id: ConnectionId,
11 user_id: UserId,
12 ) -> Result<()> {
13 self.transaction(|tx| async move {
14 let channel = self.get_channel_internal(channel_id, &*tx).await?;
15 self.check_user_is_channel_participant(&channel, user_id, &*tx)
16 .await?;
17 channel_chat_participant::ActiveModel {
18 id: ActiveValue::NotSet,
19 channel_id: ActiveValue::Set(channel_id),
20 user_id: ActiveValue::Set(user_id),
21 connection_id: ActiveValue::Set(connection_id.id as i32),
22 connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
23 }
24 .insert(&*tx)
25 .await?;
26 Ok(())
27 })
28 .await
29 }
30
31 pub async fn channel_chat_connection_lost(
32 &self,
33 connection_id: ConnectionId,
34 tx: &DatabaseTransaction,
35 ) -> Result<()> {
36 channel_chat_participant::Entity::delete_many()
37 .filter(
38 Condition::all()
39 .add(
40 channel_chat_participant::Column::ConnectionServerId
41 .eq(connection_id.owner_id),
42 )
43 .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id)),
44 )
45 .exec(tx)
46 .await?;
47 Ok(())
48 }
49
50 pub async fn leave_channel_chat(
51 &self,
52 channel_id: ChannelId,
53 connection_id: ConnectionId,
54 _user_id: UserId,
55 ) -> Result<()> {
56 self.transaction(|tx| async move {
57 channel_chat_participant::Entity::delete_many()
58 .filter(
59 Condition::all()
60 .add(
61 channel_chat_participant::Column::ConnectionServerId
62 .eq(connection_id.owner_id),
63 )
64 .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id))
65 .add(channel_chat_participant::Column::ChannelId.eq(channel_id)),
66 )
67 .exec(&*tx)
68 .await?;
69
70 Ok(())
71 })
72 .await
73 }
74
75 pub async fn get_channel_messages(
76 &self,
77 channel_id: ChannelId,
78 user_id: UserId,
79 count: usize,
80 before_message_id: Option<MessageId>,
81 ) -> Result<Vec<proto::ChannelMessage>> {
82 self.transaction(|tx| async move {
83 let channel = self.get_channel_internal(channel_id, &*tx).await?;
84 self.check_user_is_channel_participant(&channel, user_id, &*tx)
85 .await?;
86
87 let mut condition =
88 Condition::all().add(channel_message::Column::ChannelId.eq(channel_id));
89
90 if let Some(before_message_id) = before_message_id {
91 condition = condition.add(channel_message::Column::Id.lt(before_message_id));
92 }
93
94 let rows = channel_message::Entity::find()
95 .filter(condition)
96 .order_by_desc(channel_message::Column::Id)
97 .limit(count as u64)
98 .all(&*tx)
99 .await?;
100
101 self.load_channel_messages(rows, &*tx).await
102 })
103 .await
104 }
105
106 pub async fn get_channel_messages_by_id(
107 &self,
108 user_id: UserId,
109 message_ids: &[MessageId],
110 ) -> Result<Vec<proto::ChannelMessage>> {
111 self.transaction(|tx| async move {
112 let rows = channel_message::Entity::find()
113 .filter(channel_message::Column::Id.is_in(message_ids.iter().copied()))
114 .order_by_desc(channel_message::Column::Id)
115 .all(&*tx)
116 .await?;
117
118 let mut channels = HashMap::<ChannelId, channel::Model>::default();
119 for row in &rows {
120 channels.insert(
121 row.channel_id,
122 self.get_channel_internal(row.channel_id, &*tx).await?,
123 );
124 }
125
126 for (_, channel) in channels {
127 self.check_user_is_channel_participant(&channel, user_id, &*tx)
128 .await?;
129 }
130
131 let messages = self.load_channel_messages(rows, &*tx).await?;
132 Ok(messages)
133 })
134 .await
135 }
136
137 async fn load_channel_messages(
138 &self,
139 rows: Vec<channel_message::Model>,
140 tx: &DatabaseTransaction,
141 ) -> Result<Vec<proto::ChannelMessage>> {
142 let mut messages = rows
143 .into_iter()
144 .map(|row| {
145 let nonce = row.nonce.as_u64_pair();
146 proto::ChannelMessage {
147 id: row.id.to_proto(),
148 sender_id: row.sender_id.to_proto(),
149 body: row.body,
150 timestamp: row.sent_at.assume_utc().unix_timestamp() as u64,
151 mentions: vec![],
152 nonce: Some(proto::Nonce {
153 upper_half: nonce.0,
154 lower_half: nonce.1,
155 }),
156 }
157 })
158 .collect::<Vec<_>>();
159 messages.reverse();
160
161 let mut mentions = channel_message_mention::Entity::find()
162 .filter(channel_message_mention::Column::MessageId.is_in(messages.iter().map(|m| m.id)))
163 .order_by_asc(channel_message_mention::Column::MessageId)
164 .order_by_asc(channel_message_mention::Column::StartOffset)
165 .stream(&*tx)
166 .await?;
167
168 let mut message_ix = 0;
169 while let Some(mention) = mentions.next().await {
170 let mention = mention?;
171 let message_id = mention.message_id.to_proto();
172 while let Some(message) = messages.get_mut(message_ix) {
173 if message.id < message_id {
174 message_ix += 1;
175 } else {
176 if message.id == message_id {
177 message.mentions.push(proto::ChatMention {
178 range: Some(proto::Range {
179 start: mention.start_offset as u64,
180 end: mention.end_offset as u64,
181 }),
182 user_id: mention.user_id.to_proto(),
183 });
184 }
185 break;
186 }
187 }
188 }
189
190 Ok(messages)
191 }
192
193 pub async fn create_channel_message(
194 &self,
195 channel_id: ChannelId,
196 user_id: UserId,
197 body: &str,
198 mentions: &[proto::ChatMention],
199 timestamp: OffsetDateTime,
200 nonce: u128,
201 ) -> Result<CreatedChannelMessage> {
202 self.transaction(|tx| async move {
203 let channel = self.get_channel_internal(channel_id, &*tx).await?;
204 self.check_user_is_channel_participant(&channel, user_id, &*tx)
205 .await?;
206
207 let mut rows = channel_chat_participant::Entity::find()
208 .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
209 .stream(&*tx)
210 .await?;
211
212 let mut is_participant = false;
213 let mut participant_connection_ids = Vec::new();
214 let mut participant_user_ids = Vec::new();
215 while let Some(row) = rows.next().await {
216 let row = row?;
217 if row.user_id == user_id {
218 is_participant = true;
219 }
220 participant_user_ids.push(row.user_id);
221 participant_connection_ids.push(row.connection());
222 }
223 drop(rows);
224
225 if !is_participant {
226 Err(anyhow!("not a chat participant"))?;
227 }
228
229 let timestamp = timestamp.to_offset(time::UtcOffset::UTC);
230 let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time());
231
232 let result = channel_message::Entity::insert(channel_message::ActiveModel {
233 channel_id: ActiveValue::Set(channel_id),
234 sender_id: ActiveValue::Set(user_id),
235 body: ActiveValue::Set(body.to_string()),
236 sent_at: ActiveValue::Set(timestamp),
237 nonce: ActiveValue::Set(Uuid::from_u128(nonce)),
238 id: ActiveValue::NotSet,
239 })
240 .on_conflict(
241 OnConflict::columns([
242 channel_message::Column::SenderId,
243 channel_message::Column::Nonce,
244 ])
245 .do_nothing()
246 .to_owned(),
247 )
248 .do_nothing()
249 .exec(&*tx)
250 .await?;
251
252 let message_id;
253 let mut notifications = Vec::new();
254 match result {
255 TryInsertResult::Inserted(result) => {
256 message_id = result.last_insert_id;
257 let mentioned_user_ids =
258 mentions.iter().map(|m| m.user_id).collect::<HashSet<_>>();
259 let mentions = mentions
260 .iter()
261 .filter_map(|mention| {
262 let range = mention.range.as_ref()?;
263 if !body.is_char_boundary(range.start as usize)
264 || !body.is_char_boundary(range.end as usize)
265 {
266 return None;
267 }
268 Some(channel_message_mention::ActiveModel {
269 message_id: ActiveValue::Set(message_id),
270 start_offset: ActiveValue::Set(range.start as i32),
271 end_offset: ActiveValue::Set(range.end as i32),
272 user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)),
273 })
274 })
275 .collect::<Vec<_>>();
276 if !mentions.is_empty() {
277 channel_message_mention::Entity::insert_many(mentions)
278 .exec(&*tx)
279 .await?;
280 }
281
282 for mentioned_user in mentioned_user_ids {
283 notifications.extend(
284 self.create_notification(
285 UserId::from_proto(mentioned_user),
286 rpc::Notification::ChannelMessageMention {
287 message_id: message_id.to_proto(),
288 sender_id: user_id.to_proto(),
289 channel_id: channel_id.to_proto(),
290 },
291 false,
292 &*tx,
293 )
294 .await?,
295 );
296 }
297
298 self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
299 .await?;
300 }
301 _ => {
302 message_id = channel_message::Entity::find()
303 .filter(channel_message::Column::Nonce.eq(Uuid::from_u128(nonce)))
304 .one(&*tx)
305 .await?
306 .ok_or_else(|| anyhow!("failed to insert message"))?
307 .id;
308 }
309 }
310
311 let mut channel_members = self.get_channel_participants(&channel, &*tx).await?;
312 channel_members.retain(|member| !participant_user_ids.contains(member));
313
314 Ok(CreatedChannelMessage {
315 message_id,
316 participant_connection_ids,
317 channel_members,
318 notifications,
319 })
320 })
321 .await
322 }
323
324 pub async fn observe_channel_message(
325 &self,
326 channel_id: ChannelId,
327 user_id: UserId,
328 message_id: MessageId,
329 ) -> Result<NotificationBatch> {
330 self.transaction(|tx| async move {
331 self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
332 .await?;
333 let mut batch = NotificationBatch::default();
334 batch.extend(
335 self.mark_notification_as_read(
336 user_id,
337 &Notification::ChannelMessageMention {
338 message_id: message_id.to_proto(),
339 sender_id: Default::default(),
340 channel_id: Default::default(),
341 },
342 &*tx,
343 )
344 .await?,
345 );
346 Ok(batch)
347 })
348 .await
349 }
350
351 async fn observe_channel_message_internal(
352 &self,
353 channel_id: ChannelId,
354 user_id: UserId,
355 message_id: MessageId,
356 tx: &DatabaseTransaction,
357 ) -> Result<()> {
358 observed_channel_messages::Entity::insert(observed_channel_messages::ActiveModel {
359 user_id: ActiveValue::Set(user_id),
360 channel_id: ActiveValue::Set(channel_id),
361 channel_message_id: ActiveValue::Set(message_id),
362 })
363 .on_conflict(
364 OnConflict::columns([
365 observed_channel_messages::Column::ChannelId,
366 observed_channel_messages::Column::UserId,
367 ])
368 .update_column(observed_channel_messages::Column::ChannelMessageId)
369 .action_cond_where(observed_channel_messages::Column::ChannelMessageId.lt(message_id))
370 .to_owned(),
371 )
372 // TODO: Try to upgrade SeaORM so we don't have to do this hack around their bug
373 .exec_without_returning(&*tx)
374 .await?;
375 Ok(())
376 }
377
378 pub async fn unseen_channel_messages(
379 &self,
380 user_id: UserId,
381 channel_ids: &[ChannelId],
382 tx: &DatabaseTransaction,
383 ) -> Result<Vec<proto::UnseenChannelMessage>> {
384 let mut observed_messages_by_channel_id = HashMap::default();
385 let mut rows = observed_channel_messages::Entity::find()
386 .filter(observed_channel_messages::Column::UserId.eq(user_id))
387 .filter(observed_channel_messages::Column::ChannelId.is_in(channel_ids.iter().copied()))
388 .stream(&*tx)
389 .await?;
390
391 while let Some(row) = rows.next().await {
392 let row = row?;
393 observed_messages_by_channel_id.insert(row.channel_id, row);
394 }
395 drop(rows);
396 let mut values = String::new();
397 for id in channel_ids {
398 if !values.is_empty() {
399 values.push_str(", ");
400 }
401 write!(&mut values, "({})", id).unwrap();
402 }
403
404 if values.is_empty() {
405 return Ok(Default::default());
406 }
407
408 let sql = format!(
409 r#"
410 SELECT
411 *
412 FROM (
413 SELECT
414 *,
415 row_number() OVER (
416 PARTITION BY channel_id
417 ORDER BY id DESC
418 ) as row_number
419 FROM channel_messages
420 WHERE
421 channel_id in ({values})
422 ) AS messages
423 WHERE
424 row_number = 1
425 "#,
426 );
427
428 let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
429 let last_messages = channel_message::Model::find_by_statement(stmt)
430 .all(&*tx)
431 .await?;
432
433 let mut changes = Vec::new();
434 for last_message in last_messages {
435 if let Some(observed_message) =
436 observed_messages_by_channel_id.get(&last_message.channel_id)
437 {
438 if observed_message.channel_message_id == last_message.id {
439 continue;
440 }
441 }
442 changes.push(proto::UnseenChannelMessage {
443 channel_id: last_message.channel_id.to_proto(),
444 message_id: last_message.id.to_proto(),
445 });
446 }
447
448 Ok(changes)
449 }
450
451 pub async fn remove_channel_message(
452 &self,
453 channel_id: ChannelId,
454 message_id: MessageId,
455 user_id: UserId,
456 ) -> Result<Vec<ConnectionId>> {
457 self.transaction(|tx| async move {
458 let mut rows = channel_chat_participant::Entity::find()
459 .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
460 .stream(&*tx)
461 .await?;
462
463 let mut is_participant = false;
464 let mut participant_connection_ids = Vec::new();
465 while let Some(row) = rows.next().await {
466 let row = row?;
467 if row.user_id == user_id {
468 is_participant = true;
469 }
470 participant_connection_ids.push(row.connection());
471 }
472 drop(rows);
473
474 if !is_participant {
475 Err(anyhow!("not a chat participant"))?;
476 }
477
478 let result = channel_message::Entity::delete_by_id(message_id)
479 .filter(channel_message::Column::SenderId.eq(user_id))
480 .exec(&*tx)
481 .await?;
482
483 if result.rows_affected == 0 {
484 let channel = self.get_channel_internal(channel_id, &*tx).await?;
485 if self
486 .check_user_is_channel_admin(&channel, user_id, &*tx)
487 .await
488 .is_ok()
489 {
490 let result = channel_message::Entity::delete_by_id(message_id)
491 .exec(&*tx)
492 .await?;
493 if result.rows_affected == 0 {
494 Err(anyhow!("no such message"))?;
495 }
496 } else {
497 Err(anyhow!("operation could not be completed"))?;
498 }
499 }
500
501 Ok(participant_connection_ids)
502 })
503 .await
504 }
505}