1use super::*;
2use futures::Stream;
3use rpc::Notification;
4use sea_orm::TryInsertResult;
5use time::OffsetDateTime;
6
7impl Database {
8 pub async fn join_channel_chat(
9 &self,
10 channel_id: ChannelId,
11 connection_id: ConnectionId,
12 user_id: UserId,
13 ) -> Result<()> {
14 self.transaction(|tx| async move {
15 self.check_user_is_channel_member(channel_id, user_id, &*tx)
16 .await?;
17 channel_chat_participant::ActiveModel {
18 id: ActiveValue::NotSet,
19 channel_id: ActiveValue::Set(channel_id),
20 user_id: ActiveValue::Set(user_id),
21 connection_id: ActiveValue::Set(connection_id.id as i32),
22 connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
23 }
24 .insert(&*tx)
25 .await?;
26 Ok(())
27 })
28 .await
29 }
30
31 pub async fn channel_chat_connection_lost(
32 &self,
33 connection_id: ConnectionId,
34 tx: &DatabaseTransaction,
35 ) -> Result<()> {
36 channel_chat_participant::Entity::delete_many()
37 .filter(
38 Condition::all()
39 .add(
40 channel_chat_participant::Column::ConnectionServerId
41 .eq(connection_id.owner_id),
42 )
43 .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id)),
44 )
45 .exec(tx)
46 .await?;
47 Ok(())
48 }
49
50 pub async fn leave_channel_chat(
51 &self,
52 channel_id: ChannelId,
53 connection_id: ConnectionId,
54 _user_id: UserId,
55 ) -> Result<()> {
56 self.transaction(|tx| async move {
57 channel_chat_participant::Entity::delete_many()
58 .filter(
59 Condition::all()
60 .add(
61 channel_chat_participant::Column::ConnectionServerId
62 .eq(connection_id.owner_id),
63 )
64 .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id))
65 .add(channel_chat_participant::Column::ChannelId.eq(channel_id)),
66 )
67 .exec(&*tx)
68 .await?;
69
70 Ok(())
71 })
72 .await
73 }
74
75 pub async fn get_channel_messages(
76 &self,
77 channel_id: ChannelId,
78 user_id: UserId,
79 count: usize,
80 before_message_id: Option<MessageId>,
81 ) -> Result<Vec<proto::ChannelMessage>> {
82 self.transaction(|tx| async move {
83 self.check_user_is_channel_member(channel_id, user_id, &*tx)
84 .await?;
85
86 let mut condition =
87 Condition::all().add(channel_message::Column::ChannelId.eq(channel_id));
88
89 if let Some(before_message_id) = before_message_id {
90 condition = condition.add(channel_message::Column::Id.lt(before_message_id));
91 }
92
93 let rows = channel_message::Entity::find()
94 .filter(condition)
95 .order_by_desc(channel_message::Column::Id)
96 .limit(count as u64)
97 .stream(&*tx)
98 .await?;
99
100 self.load_channel_messages(rows, &*tx).await
101 })
102 .await
103 }
104
105 pub async fn get_channel_messages_by_id(
106 &self,
107 user_id: UserId,
108 message_ids: &[MessageId],
109 ) -> Result<Vec<proto::ChannelMessage>> {
110 self.transaction(|tx| async move {
111 let rows = channel_message::Entity::find()
112 .filter(channel_message::Column::Id.is_in(message_ids.iter().copied()))
113 .order_by_desc(channel_message::Column::Id)
114 .stream(&*tx)
115 .await?;
116
117 let mut channel_ids = HashSet::<ChannelId>::default();
118 let messages = self
119 .load_channel_messages(
120 rows.map(|row| {
121 row.map(|row| {
122 channel_ids.insert(row.channel_id);
123 row
124 })
125 }),
126 &*tx,
127 )
128 .await?;
129
130 for channel_id in channel_ids {
131 self.check_user_is_channel_member(channel_id, user_id, &*tx)
132 .await?;
133 }
134
135 Ok(messages)
136 })
137 .await
138 }
139
140 async fn load_channel_messages(
141 &self,
142 mut rows: impl Send + Unpin + Stream<Item = Result<channel_message::Model, sea_orm::DbErr>>,
143 tx: &DatabaseTransaction,
144 ) -> Result<Vec<proto::ChannelMessage>> {
145 let mut messages = Vec::new();
146 while let Some(row) = rows.next().await {
147 let row = row?;
148 let nonce = row.nonce.as_u64_pair();
149 messages.push(proto::ChannelMessage {
150 id: row.id.to_proto(),
151 sender_id: row.sender_id.to_proto(),
152 body: row.body,
153 timestamp: row.sent_at.assume_utc().unix_timestamp() as u64,
154 mentions: vec![],
155 nonce: Some(proto::Nonce {
156 upper_half: nonce.0,
157 lower_half: nonce.1,
158 }),
159 });
160 }
161 drop(rows);
162 messages.reverse();
163
164 let mut mentions = channel_message_mention::Entity::find()
165 .filter(channel_message_mention::Column::MessageId.is_in(messages.iter().map(|m| m.id)))
166 .order_by_asc(channel_message_mention::Column::MessageId)
167 .order_by_asc(channel_message_mention::Column::StartOffset)
168 .stream(&*tx)
169 .await?;
170
171 let mut message_ix = 0;
172 while let Some(mention) = mentions.next().await {
173 let mention = mention?;
174 let message_id = mention.message_id.to_proto();
175 while let Some(message) = messages.get_mut(message_ix) {
176 if message.id < message_id {
177 message_ix += 1;
178 } else {
179 if message.id == message_id {
180 message.mentions.push(proto::ChatMention {
181 range: Some(proto::Range {
182 start: mention.start_offset as u64,
183 end: mention.end_offset as u64,
184 }),
185 user_id: mention.user_id.to_proto(),
186 });
187 }
188 break;
189 }
190 }
191 }
192
193 Ok(messages)
194 }
195
196 pub async fn create_channel_message(
197 &self,
198 channel_id: ChannelId,
199 user_id: UserId,
200 body: &str,
201 mentions: &[proto::ChatMention],
202 timestamp: OffsetDateTime,
203 nonce: u128,
204 ) -> Result<CreatedChannelMessage> {
205 self.transaction(|tx| async move {
206 let mut rows = channel_chat_participant::Entity::find()
207 .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
208 .stream(&*tx)
209 .await?;
210
211 let mut is_participant = false;
212 let mut participant_connection_ids = Vec::new();
213 let mut participant_user_ids = Vec::new();
214 while let Some(row) = rows.next().await {
215 let row = row?;
216 if row.user_id == user_id {
217 is_participant = true;
218 }
219 participant_user_ids.push(row.user_id);
220 participant_connection_ids.push(row.connection());
221 }
222 drop(rows);
223
224 if !is_participant {
225 Err(anyhow!("not a chat participant"))?;
226 }
227
228 let timestamp = timestamp.to_offset(time::UtcOffset::UTC);
229 let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time());
230
231 let result = channel_message::Entity::insert(channel_message::ActiveModel {
232 channel_id: ActiveValue::Set(channel_id),
233 sender_id: ActiveValue::Set(user_id),
234 body: ActiveValue::Set(body.to_string()),
235 sent_at: ActiveValue::Set(timestamp),
236 nonce: ActiveValue::Set(Uuid::from_u128(nonce)),
237 id: ActiveValue::NotSet,
238 })
239 .on_conflict(
240 OnConflict::columns([
241 channel_message::Column::SenderId,
242 channel_message::Column::Nonce,
243 ])
244 .do_nothing()
245 .to_owned(),
246 )
247 .do_nothing()
248 .exec(&*tx)
249 .await?;
250
251 let message_id;
252 let mut notifications = Vec::new();
253 match result {
254 TryInsertResult::Inserted(result) => {
255 message_id = result.last_insert_id;
256 let mentioned_user_ids =
257 mentions.iter().map(|m| m.user_id).collect::<HashSet<_>>();
258 let mentions = mentions
259 .iter()
260 .filter_map(|mention| {
261 let range = mention.range.as_ref()?;
262 if !body.is_char_boundary(range.start as usize)
263 || !body.is_char_boundary(range.end as usize)
264 {
265 return None;
266 }
267 Some(channel_message_mention::ActiveModel {
268 message_id: ActiveValue::Set(message_id),
269 start_offset: ActiveValue::Set(range.start as i32),
270 end_offset: ActiveValue::Set(range.end as i32),
271 user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)),
272 })
273 })
274 .collect::<Vec<_>>();
275 if !mentions.is_empty() {
276 channel_message_mention::Entity::insert_many(mentions)
277 .exec(&*tx)
278 .await?;
279 }
280
281 for mentioned_user in mentioned_user_ids {
282 notifications.extend(
283 self.create_notification(
284 UserId::from_proto(mentioned_user),
285 rpc::Notification::ChannelMessageMention {
286 message_id: message_id.to_proto(),
287 sender_id: user_id.to_proto(),
288 channel_id: channel_id.to_proto(),
289 },
290 false,
291 &*tx,
292 )
293 .await?,
294 );
295 }
296
297 self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
298 .await?;
299 }
300 _ => {
301 message_id = channel_message::Entity::find()
302 .filter(channel_message::Column::Nonce.eq(Uuid::from_u128(nonce)))
303 .one(&*tx)
304 .await?
305 .ok_or_else(|| anyhow!("failed to insert message"))?
306 .id;
307 }
308 }
309
310 let mut channel_members = self
311 .get_channel_participants_internal(channel_id, &*tx)
312 .await?;
313 channel_members.retain(|member| !participant_user_ids.contains(member));
314
315 Ok(CreatedChannelMessage {
316 message_id,
317 participant_connection_ids,
318 channel_members,
319 notifications,
320 })
321 })
322 .await
323 }
324
325 pub async fn observe_channel_message(
326 &self,
327 channel_id: ChannelId,
328 user_id: UserId,
329 message_id: MessageId,
330 ) -> Result<NotificationBatch> {
331 self.transaction(|tx| async move {
332 self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
333 .await?;
334 let mut batch = NotificationBatch::default();
335 batch.extend(
336 self.mark_notification_as_read(
337 user_id,
338 &Notification::ChannelMessageMention {
339 message_id: message_id.to_proto(),
340 sender_id: Default::default(),
341 channel_id: Default::default(),
342 },
343 &*tx,
344 )
345 .await?,
346 );
347 Ok(batch)
348 })
349 .await
350 }
351
352 async fn observe_channel_message_internal(
353 &self,
354 channel_id: ChannelId,
355 user_id: UserId,
356 message_id: MessageId,
357 tx: &DatabaseTransaction,
358 ) -> Result<()> {
359 observed_channel_messages::Entity::insert(observed_channel_messages::ActiveModel {
360 user_id: ActiveValue::Set(user_id),
361 channel_id: ActiveValue::Set(channel_id),
362 channel_message_id: ActiveValue::Set(message_id),
363 })
364 .on_conflict(
365 OnConflict::columns([
366 observed_channel_messages::Column::ChannelId,
367 observed_channel_messages::Column::UserId,
368 ])
369 .update_column(observed_channel_messages::Column::ChannelMessageId)
370 .action_cond_where(observed_channel_messages::Column::ChannelMessageId.lt(message_id))
371 .to_owned(),
372 )
373 // TODO: Try to upgrade SeaORM so we don't have to do this hack around their bug
374 .exec_without_returning(&*tx)
375 .await?;
376 Ok(())
377 }
378
379 pub async fn unseen_channel_messages(
380 &self,
381 user_id: UserId,
382 channel_ids: &[ChannelId],
383 tx: &DatabaseTransaction,
384 ) -> Result<Vec<proto::UnseenChannelMessage>> {
385 let mut observed_messages_by_channel_id = HashMap::default();
386 let mut rows = observed_channel_messages::Entity::find()
387 .filter(observed_channel_messages::Column::UserId.eq(user_id))
388 .filter(observed_channel_messages::Column::ChannelId.is_in(channel_ids.iter().copied()))
389 .stream(&*tx)
390 .await?;
391
392 while let Some(row) = rows.next().await {
393 let row = row?;
394 observed_messages_by_channel_id.insert(row.channel_id, row);
395 }
396 drop(rows);
397 let mut values = String::new();
398 for id in channel_ids {
399 if !values.is_empty() {
400 values.push_str(", ");
401 }
402 write!(&mut values, "({})", id).unwrap();
403 }
404
405 if values.is_empty() {
406 return Ok(Default::default());
407 }
408
409 let sql = format!(
410 r#"
411 SELECT
412 *
413 FROM (
414 SELECT
415 *,
416 row_number() OVER (
417 PARTITION BY channel_id
418 ORDER BY id DESC
419 ) as row_number
420 FROM channel_messages
421 WHERE
422 channel_id in ({values})
423 ) AS messages
424 WHERE
425 row_number = 1
426 "#,
427 );
428
429 let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
430 let last_messages = channel_message::Model::find_by_statement(stmt)
431 .all(&*tx)
432 .await?;
433
434 let mut changes = Vec::new();
435 for last_message in last_messages {
436 if let Some(observed_message) =
437 observed_messages_by_channel_id.get(&last_message.channel_id)
438 {
439 if observed_message.channel_message_id == last_message.id {
440 continue;
441 }
442 }
443 changes.push(proto::UnseenChannelMessage {
444 channel_id: last_message.channel_id.to_proto(),
445 message_id: last_message.id.to_proto(),
446 });
447 }
448
449 Ok(changes)
450 }
451
452 pub async fn remove_channel_message(
453 &self,
454 channel_id: ChannelId,
455 message_id: MessageId,
456 user_id: UserId,
457 ) -> Result<Vec<ConnectionId>> {
458 self.transaction(|tx| async move {
459 let mut rows = channel_chat_participant::Entity::find()
460 .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
461 .stream(&*tx)
462 .await?;
463
464 let mut is_participant = false;
465 let mut participant_connection_ids = Vec::new();
466 while let Some(row) = rows.next().await {
467 let row = row?;
468 if row.user_id == user_id {
469 is_participant = true;
470 }
471 participant_connection_ids.push(row.connection());
472 }
473 drop(rows);
474
475 if !is_participant {
476 Err(anyhow!("not a chat participant"))?;
477 }
478
479 let result = channel_message::Entity::delete_by_id(message_id)
480 .filter(channel_message::Column::SenderId.eq(user_id))
481 .exec(&*tx)
482 .await?;
483
484 if result.rows_affected == 0 {
485 if self
486 .check_user_is_channel_admin(channel_id, user_id, &*tx)
487 .await
488 .is_ok()
489 {
490 let result = channel_message::Entity::delete_by_id(message_id)
491 .exec(&*tx)
492 .await?;
493 if result.rows_affected == 0 {
494 Err(anyhow!("no such message"))?;
495 }
496 } else {
497 Err(anyhow!("operation could not be completed"))?;
498 }
499 }
500
501 Ok(participant_connection_ids)
502 })
503 .await
504 }
505}