1use super::*;
2use rpc::Notification;
3use sea_orm::TryInsertResult;
4use time::OffsetDateTime;
5
6impl Database {
7 pub async fn join_channel_chat(
8 &self,
9 channel_id: ChannelId,
10 connection_id: ConnectionId,
11 user_id: UserId,
12 ) -> Result<()> {
13 self.transaction(|tx| async move {
14 let channel = self.get_channel_internal(channel_id, &*tx).await?;
15 self.check_user_is_channel_participant(&channel, user_id, &*tx)
16 .await?;
17 channel_chat_participant::ActiveModel {
18 id: ActiveValue::NotSet,
19 channel_id: ActiveValue::Set(channel_id),
20 user_id: ActiveValue::Set(user_id),
21 connection_id: ActiveValue::Set(connection_id.id as i32),
22 connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
23 }
24 .insert(&*tx)
25 .await?;
26 Ok(())
27 })
28 .await
29 }
30
31 pub async fn channel_chat_connection_lost(
32 &self,
33 connection_id: ConnectionId,
34 tx: &DatabaseTransaction,
35 ) -> Result<()> {
36 channel_chat_participant::Entity::delete_many()
37 .filter(
38 Condition::all()
39 .add(
40 channel_chat_participant::Column::ConnectionServerId
41 .eq(connection_id.owner_id),
42 )
43 .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id)),
44 )
45 .exec(tx)
46 .await?;
47 Ok(())
48 }
49
50 pub async fn leave_channel_chat(
51 &self,
52 channel_id: ChannelId,
53 connection_id: ConnectionId,
54 _user_id: UserId,
55 ) -> Result<()> {
56 self.transaction(|tx| async move {
57 channel_chat_participant::Entity::delete_many()
58 .filter(
59 Condition::all()
60 .add(
61 channel_chat_participant::Column::ConnectionServerId
62 .eq(connection_id.owner_id),
63 )
64 .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id))
65 .add(channel_chat_participant::Column::ChannelId.eq(channel_id)),
66 )
67 .exec(&*tx)
68 .await?;
69
70 Ok(())
71 })
72 .await
73 }
74
75 pub async fn get_channel_messages(
76 &self,
77 channel_id: ChannelId,
78 user_id: UserId,
79 count: usize,
80 before_message_id: Option<MessageId>,
81 ) -> Result<Vec<proto::ChannelMessage>> {
82 self.transaction(|tx| async move {
83 let channel = self.get_channel_internal(channel_id, &*tx).await?;
84 self.check_user_is_channel_participant(&channel, user_id, &*tx)
85 .await?;
86
87 let mut condition =
88 Condition::all().add(channel_message::Column::ChannelId.eq(channel_id));
89
90 if let Some(before_message_id) = before_message_id {
91 condition = condition.add(channel_message::Column::Id.lt(before_message_id));
92 }
93
94 let rows = channel_message::Entity::find()
95 .filter(condition)
96 .order_by_desc(channel_message::Column::Id)
97 .limit(count as u64)
98 .all(&*tx)
99 .await?;
100
101 self.load_channel_messages(rows, &*tx).await
102 })
103 .await
104 }
105
106 pub async fn get_channel_messages_by_id(
107 &self,
108 user_id: UserId,
109 message_ids: &[MessageId],
110 ) -> Result<Vec<proto::ChannelMessage>> {
111 self.transaction(|tx| async move {
112 let rows = channel_message::Entity::find()
113 .filter(channel_message::Column::Id.is_in(message_ids.iter().copied()))
114 .order_by_desc(channel_message::Column::Id)
115 .all(&*tx)
116 .await?;
117
118 let mut channels = HashMap::<ChannelId, channel::Model>::default();
119 for row in &rows {
120 channels.insert(
121 row.channel_id,
122 self.get_channel_internal(row.channel_id, &*tx).await?,
123 );
124 }
125
126 for (_, channel) in channels {
127 self.check_user_is_channel_participant(&channel, user_id, &*tx)
128 .await?;
129 }
130
131 let messages = self.load_channel_messages(rows, &*tx).await?;
132 Ok(messages)
133 })
134 .await
135 }
136
137 async fn load_channel_messages(
138 &self,
139 rows: Vec<channel_message::Model>,
140 tx: &DatabaseTransaction,
141 ) -> Result<Vec<proto::ChannelMessage>> {
142 let mut messages = rows
143 .into_iter()
144 .map(|row| {
145 let nonce = row.nonce.as_u64_pair();
146 proto::ChannelMessage {
147 id: row.id.to_proto(),
148 sender_id: row.sender_id.to_proto(),
149 body: row.body,
150 timestamp: row.sent_at.assume_utc().unix_timestamp() as u64,
151 mentions: vec![],
152 nonce: Some(proto::Nonce {
153 upper_half: nonce.0,
154 lower_half: nonce.1,
155 }),
156 }
157 })
158 .collect::<Vec<_>>();
159 messages.reverse();
160
161 let mut mentions = channel_message_mention::Entity::find()
162 .filter(channel_message_mention::Column::MessageId.is_in(messages.iter().map(|m| m.id)))
163 .order_by_asc(channel_message_mention::Column::MessageId)
164 .order_by_asc(channel_message_mention::Column::StartOffset)
165 .stream(&*tx)
166 .await?;
167
168 let mut message_ix = 0;
169 while let Some(mention) = mentions.next().await {
170 let mention = mention?;
171 let message_id = mention.message_id.to_proto();
172 while let Some(message) = messages.get_mut(message_ix) {
173 if message.id < message_id {
174 message_ix += 1;
175 } else {
176 if message.id == message_id {
177 message.mentions.push(proto::ChatMention {
178 range: Some(proto::Range {
179 start: mention.start_offset as u64,
180 end: mention.end_offset as u64,
181 }),
182 user_id: mention.user_id.to_proto(),
183 });
184 }
185 break;
186 }
187 }
188 }
189
190 Ok(messages)
191 }
192
193 pub async fn create_channel_message(
194 &self,
195 channel_id: ChannelId,
196 user_id: UserId,
197 body: &str,
198 mentions: &[proto::ChatMention],
199 timestamp: OffsetDateTime,
200 nonce: u128,
201 ) -> Result<CreatedChannelMessage> {
202 self.transaction(|tx| async move {
203 let channel = self.get_channel_internal(channel_id, &*tx).await?;
204 self.check_user_is_channel_participant(&channel, user_id, &*tx)
205 .await?;
206
207 let mut rows = channel_chat_participant::Entity::find()
208 .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
209 .stream(&*tx)
210 .await?;
211
212 let mut is_participant = false;
213 let mut participant_connection_ids = Vec::new();
214 let mut participant_user_ids = Vec::new();
215 while let Some(row) = rows.next().await {
216 let row = row?;
217 if row.user_id == user_id {
218 is_participant = true;
219 }
220 participant_user_ids.push(row.user_id);
221 participant_connection_ids.push(row.connection());
222 }
223 drop(rows);
224
225 if !is_participant {
226 Err(anyhow!("not a chat participant"))?;
227 }
228
229 let timestamp = timestamp.to_offset(time::UtcOffset::UTC);
230 let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time());
231
232 let result = channel_message::Entity::insert(channel_message::ActiveModel {
233 channel_id: ActiveValue::Set(channel_id),
234 sender_id: ActiveValue::Set(user_id),
235 body: ActiveValue::Set(body.to_string()),
236 sent_at: ActiveValue::Set(timestamp),
237 nonce: ActiveValue::Set(Uuid::from_u128(nonce)),
238 id: ActiveValue::NotSet,
239 })
240 .on_conflict(
241 OnConflict::columns([
242 channel_message::Column::SenderId,
243 channel_message::Column::Nonce,
244 ])
245 .do_nothing()
246 .to_owned(),
247 )
248 .do_nothing()
249 .exec(&*tx)
250 .await?;
251
252 let message_id;
253 let mut notifications = Vec::new();
254 match result {
255 TryInsertResult::Inserted(result) => {
256 message_id = result.last_insert_id;
257 let mentioned_user_ids =
258 mentions.iter().map(|m| m.user_id).collect::<HashSet<_>>();
259
260 let mentions = mentions
261 .iter()
262 .filter_map(|mention| {
263 let range = mention.range.as_ref()?;
264 if !body.is_char_boundary(range.start as usize)
265 || !body.is_char_boundary(range.end as usize)
266 {
267 return None;
268 }
269 Some(channel_message_mention::ActiveModel {
270 message_id: ActiveValue::Set(message_id),
271 start_offset: ActiveValue::Set(range.start as i32),
272 end_offset: ActiveValue::Set(range.end as i32),
273 user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)),
274 })
275 })
276 .collect::<Vec<_>>();
277 if !mentions.is_empty() {
278 channel_message_mention::Entity::insert_many(mentions)
279 .exec(&*tx)
280 .await?;
281 }
282
283 for mentioned_user in mentioned_user_ids {
284 notifications.extend(
285 self.create_notification(
286 UserId::from_proto(mentioned_user),
287 rpc::Notification::ChannelMessageMention {
288 message_id: message_id.to_proto(),
289 sender_id: user_id.to_proto(),
290 channel_id: channel_id.to_proto(),
291 },
292 false,
293 &*tx,
294 )
295 .await?,
296 );
297 }
298
299 self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
300 .await?;
301 }
302 _ => {
303 message_id = channel_message::Entity::find()
304 .filter(channel_message::Column::Nonce.eq(Uuid::from_u128(nonce)))
305 .one(&*tx)
306 .await?
307 .ok_or_else(|| anyhow!("failed to insert message"))?
308 .id;
309 }
310 }
311
312 let mut channel_members = self.get_channel_participants(&channel, &*tx).await?;
313 channel_members.retain(|member| !participant_user_ids.contains(member));
314
315 Ok(CreatedChannelMessage {
316 message_id,
317 participant_connection_ids,
318 channel_members,
319 notifications,
320 })
321 })
322 .await
323 }
324
325 pub async fn observe_channel_message(
326 &self,
327 channel_id: ChannelId,
328 user_id: UserId,
329 message_id: MessageId,
330 ) -> Result<NotificationBatch> {
331 self.transaction(|tx| async move {
332 self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
333 .await?;
334 let mut batch = NotificationBatch::default();
335 batch.extend(
336 self.mark_notification_as_read(
337 user_id,
338 &Notification::ChannelMessageMention {
339 message_id: message_id.to_proto(),
340 sender_id: Default::default(),
341 channel_id: Default::default(),
342 },
343 &*tx,
344 )
345 .await?,
346 );
347 Ok(batch)
348 })
349 .await
350 }
351
352 async fn observe_channel_message_internal(
353 &self,
354 channel_id: ChannelId,
355 user_id: UserId,
356 message_id: MessageId,
357 tx: &DatabaseTransaction,
358 ) -> Result<()> {
359 observed_channel_messages::Entity::insert(observed_channel_messages::ActiveModel {
360 user_id: ActiveValue::Set(user_id),
361 channel_id: ActiveValue::Set(channel_id),
362 channel_message_id: ActiveValue::Set(message_id),
363 })
364 .on_conflict(
365 OnConflict::columns([
366 observed_channel_messages::Column::ChannelId,
367 observed_channel_messages::Column::UserId,
368 ])
369 .update_column(observed_channel_messages::Column::ChannelMessageId)
370 .action_cond_where(observed_channel_messages::Column::ChannelMessageId.lt(message_id))
371 .to_owned(),
372 )
373 // TODO: Try to upgrade SeaORM so we don't have to do this hack around their bug
374 .exec_without_returning(&*tx)
375 .await?;
376 Ok(())
377 }
378
379 pub async fn unseen_channel_messages(
380 &self,
381 user_id: UserId,
382 channel_ids: &[ChannelId],
383 tx: &DatabaseTransaction,
384 ) -> Result<Vec<proto::UnseenChannelMessage>> {
385 let mut observed_messages_by_channel_id = HashMap::default();
386 let mut rows = observed_channel_messages::Entity::find()
387 .filter(observed_channel_messages::Column::UserId.eq(user_id))
388 .filter(observed_channel_messages::Column::ChannelId.is_in(channel_ids.iter().copied()))
389 .stream(&*tx)
390 .await?;
391
392 while let Some(row) = rows.next().await {
393 let row = row?;
394 observed_messages_by_channel_id.insert(row.channel_id, row);
395 }
396 drop(rows);
397 let mut values = String::new();
398 for id in channel_ids {
399 if !values.is_empty() {
400 values.push_str(", ");
401 }
402 write!(&mut values, "({})", id).unwrap();
403 }
404
405 if values.is_empty() {
406 return Ok(Default::default());
407 }
408
409 let sql = format!(
410 r#"
411 SELECT
412 *
413 FROM (
414 SELECT
415 *,
416 row_number() OVER (
417 PARTITION BY channel_id
418 ORDER BY id DESC
419 ) as row_number
420 FROM channel_messages
421 WHERE
422 channel_id in ({values})
423 ) AS messages
424 WHERE
425 row_number = 1
426 "#,
427 );
428
429 let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
430 let last_messages = channel_message::Model::find_by_statement(stmt)
431 .all(&*tx)
432 .await?;
433
434 let mut changes = Vec::new();
435 for last_message in last_messages {
436 if let Some(observed_message) =
437 observed_messages_by_channel_id.get(&last_message.channel_id)
438 {
439 if observed_message.channel_message_id == last_message.id {
440 continue;
441 }
442 }
443 changes.push(proto::UnseenChannelMessage {
444 channel_id: last_message.channel_id.to_proto(),
445 message_id: last_message.id.to_proto(),
446 });
447 }
448
449 Ok(changes)
450 }
451
452 pub async fn remove_channel_message(
453 &self,
454 channel_id: ChannelId,
455 message_id: MessageId,
456 user_id: UserId,
457 ) -> Result<Vec<ConnectionId>> {
458 self.transaction(|tx| async move {
459 let mut rows = channel_chat_participant::Entity::find()
460 .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
461 .stream(&*tx)
462 .await?;
463
464 let mut is_participant = false;
465 let mut participant_connection_ids = Vec::new();
466 while let Some(row) = rows.next().await {
467 let row = row?;
468 if row.user_id == user_id {
469 is_participant = true;
470 }
471 participant_connection_ids.push(row.connection());
472 }
473 drop(rows);
474
475 if !is_participant {
476 Err(anyhow!("not a chat participant"))?;
477 }
478
479 let result = channel_message::Entity::delete_by_id(message_id)
480 .filter(channel_message::Column::SenderId.eq(user_id))
481 .exec(&*tx)
482 .await?;
483
484 if result.rows_affected == 0 {
485 let channel = self.get_channel_internal(channel_id, &*tx).await?;
486 if self
487 .check_user_is_channel_admin(&channel, user_id, &*tx)
488 .await
489 .is_ok()
490 {
491 let result = channel_message::Entity::delete_by_id(message_id)
492 .exec(&*tx)
493 .await?;
494 if result.rows_affected == 0 {
495 Err(anyhow!("no such message"))?;
496 }
497 } else {
498 Err(anyhow!("operation could not be completed"))?;
499 }
500 }
501
502 Ok(participant_connection_ids)
503 })
504 .await
505 }
506}