1use super::*;
2use futures::Stream;
3use sea_orm::TryInsertResult;
4use time::OffsetDateTime;
5
6impl Database {
7 pub async fn join_channel_chat(
8 &self,
9 channel_id: ChannelId,
10 connection_id: ConnectionId,
11 user_id: UserId,
12 ) -> Result<()> {
13 self.transaction(|tx| async move {
14 self.check_user_is_channel_member(channel_id, user_id, &*tx)
15 .await?;
16 channel_chat_participant::ActiveModel {
17 id: ActiveValue::NotSet,
18 channel_id: ActiveValue::Set(channel_id),
19 user_id: ActiveValue::Set(user_id),
20 connection_id: ActiveValue::Set(connection_id.id as i32),
21 connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
22 }
23 .insert(&*tx)
24 .await?;
25 Ok(())
26 })
27 .await
28 }
29
30 pub async fn channel_chat_connection_lost(
31 &self,
32 connection_id: ConnectionId,
33 tx: &DatabaseTransaction,
34 ) -> Result<()> {
35 channel_chat_participant::Entity::delete_many()
36 .filter(
37 Condition::all()
38 .add(
39 channel_chat_participant::Column::ConnectionServerId
40 .eq(connection_id.owner_id),
41 )
42 .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id)),
43 )
44 .exec(tx)
45 .await?;
46 Ok(())
47 }
48
49 pub async fn leave_channel_chat(
50 &self,
51 channel_id: ChannelId,
52 connection_id: ConnectionId,
53 _user_id: UserId,
54 ) -> Result<()> {
55 self.transaction(|tx| async move {
56 channel_chat_participant::Entity::delete_many()
57 .filter(
58 Condition::all()
59 .add(
60 channel_chat_participant::Column::ConnectionServerId
61 .eq(connection_id.owner_id),
62 )
63 .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id))
64 .add(channel_chat_participant::Column::ChannelId.eq(channel_id)),
65 )
66 .exec(&*tx)
67 .await?;
68
69 Ok(())
70 })
71 .await
72 }
73
74 pub async fn get_channel_messages(
75 &self,
76 channel_id: ChannelId,
77 user_id: UserId,
78 count: usize,
79 before_message_id: Option<MessageId>,
80 ) -> Result<Vec<proto::ChannelMessage>> {
81 self.transaction(|tx| async move {
82 self.check_user_is_channel_member(channel_id, user_id, &*tx)
83 .await?;
84
85 let mut condition =
86 Condition::all().add(channel_message::Column::ChannelId.eq(channel_id));
87
88 if let Some(before_message_id) = before_message_id {
89 condition = condition.add(channel_message::Column::Id.lt(before_message_id));
90 }
91
92 let rows = channel_message::Entity::find()
93 .filter(condition)
94 .order_by_desc(channel_message::Column::Id)
95 .limit(count as u64)
96 .stream(&*tx)
97 .await?;
98
99 self.load_channel_messages(rows, &*tx).await
100 })
101 .await
102 }
103
104 pub async fn get_channel_messages_by_id(
105 &self,
106 user_id: UserId,
107 message_ids: &[MessageId],
108 ) -> Result<Vec<proto::ChannelMessage>> {
109 self.transaction(|tx| async move {
110 let rows = channel_message::Entity::find()
111 .filter(channel_message::Column::Id.is_in(message_ids.iter().copied()))
112 .order_by_desc(channel_message::Column::Id)
113 .stream(&*tx)
114 .await?;
115
116 let mut channel_ids = HashSet::<ChannelId>::default();
117 let messages = self
118 .load_channel_messages(
119 rows.map(|row| {
120 row.map(|row| {
121 channel_ids.insert(row.channel_id);
122 row
123 })
124 }),
125 &*tx,
126 )
127 .await?;
128
129 for channel_id in channel_ids {
130 self.check_user_is_channel_member(channel_id, user_id, &*tx)
131 .await?;
132 }
133
134 Ok(messages)
135 })
136 .await
137 }
138
139 async fn load_channel_messages(
140 &self,
141 mut rows: impl Send + Unpin + Stream<Item = Result<channel_message::Model, sea_orm::DbErr>>,
142 tx: &DatabaseTransaction,
143 ) -> Result<Vec<proto::ChannelMessage>> {
144 let mut messages = Vec::new();
145 while let Some(row) = rows.next().await {
146 let row = row?;
147 let nonce = row.nonce.as_u64_pair();
148 messages.push(proto::ChannelMessage {
149 id: row.id.to_proto(),
150 sender_id: row.sender_id.to_proto(),
151 body: row.body,
152 timestamp: row.sent_at.assume_utc().unix_timestamp() as u64,
153 mentions: vec![],
154 nonce: Some(proto::Nonce {
155 upper_half: nonce.0,
156 lower_half: nonce.1,
157 }),
158 });
159 }
160 drop(rows);
161 messages.reverse();
162
163 let mut mentions = channel_message_mention::Entity::find()
164 .filter(channel_message_mention::Column::MessageId.is_in(messages.iter().map(|m| m.id)))
165 .order_by_asc(channel_message_mention::Column::MessageId)
166 .order_by_asc(channel_message_mention::Column::StartOffset)
167 .stream(&*tx)
168 .await?;
169
170 let mut message_ix = 0;
171 while let Some(mention) = mentions.next().await {
172 let mention = mention?;
173 let message_id = mention.message_id.to_proto();
174 while let Some(message) = messages.get_mut(message_ix) {
175 if message.id < message_id {
176 message_ix += 1;
177 } else {
178 if message.id == message_id {
179 message.mentions.push(proto::ChatMention {
180 range: Some(proto::Range {
181 start: mention.start_offset as u64,
182 end: mention.end_offset as u64,
183 }),
184 user_id: mention.user_id.to_proto(),
185 });
186 }
187 break;
188 }
189 }
190 }
191
192 Ok(messages)
193 }
194
195 pub async fn create_channel_message(
196 &self,
197 channel_id: ChannelId,
198 user_id: UserId,
199 body: &str,
200 mentions: &[proto::ChatMention],
201 timestamp: OffsetDateTime,
202 nonce: u128,
203 ) -> Result<CreatedChannelMessage> {
204 self.transaction(|tx| async move {
205 let mut rows = channel_chat_participant::Entity::find()
206 .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
207 .stream(&*tx)
208 .await?;
209
210 let mut is_participant = false;
211 let mut participant_connection_ids = Vec::new();
212 let mut participant_user_ids = Vec::new();
213 while let Some(row) = rows.next().await {
214 let row = row?;
215 if row.user_id == user_id {
216 is_participant = true;
217 }
218 participant_user_ids.push(row.user_id);
219 participant_connection_ids.push(row.connection());
220 }
221 drop(rows);
222
223 if !is_participant {
224 Err(anyhow!("not a chat participant"))?;
225 }
226
227 let timestamp = timestamp.to_offset(time::UtcOffset::UTC);
228 let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time());
229
230 let result = channel_message::Entity::insert(channel_message::ActiveModel {
231 channel_id: ActiveValue::Set(channel_id),
232 sender_id: ActiveValue::Set(user_id),
233 body: ActiveValue::Set(body.to_string()),
234 sent_at: ActiveValue::Set(timestamp),
235 nonce: ActiveValue::Set(Uuid::from_u128(nonce)),
236 id: ActiveValue::NotSet,
237 })
238 .on_conflict(
239 OnConflict::columns([
240 channel_message::Column::SenderId,
241 channel_message::Column::Nonce,
242 ])
243 .do_nothing()
244 .to_owned(),
245 )
246 .do_nothing()
247 .exec(&*tx)
248 .await?;
249
250 let message_id;
251 let mut notifications = Vec::new();
252 match result {
253 TryInsertResult::Inserted(result) => {
254 message_id = result.last_insert_id;
255 let mentioned_user_ids =
256 mentions.iter().map(|m| m.user_id).collect::<HashSet<_>>();
257 let mentions = mentions
258 .iter()
259 .filter_map(|mention| {
260 let range = mention.range.as_ref()?;
261 if !body.is_char_boundary(range.start as usize)
262 || !body.is_char_boundary(range.end as usize)
263 {
264 return None;
265 }
266 Some(channel_message_mention::ActiveModel {
267 message_id: ActiveValue::Set(message_id),
268 start_offset: ActiveValue::Set(range.start as i32),
269 end_offset: ActiveValue::Set(range.end as i32),
270 user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)),
271 })
272 })
273 .collect::<Vec<_>>();
274 if !mentions.is_empty() {
275 channel_message_mention::Entity::insert_many(mentions)
276 .exec(&*tx)
277 .await?;
278 }
279
280 for mentioned_user in mentioned_user_ids {
281 notifications.extend(
282 self.create_notification(
283 UserId::from_proto(mentioned_user),
284 rpc::Notification::ChannelMessageMention {
285 message_id: message_id.to_proto(),
286 sender_id: user_id.to_proto(),
287 channel_id: channel_id.to_proto(),
288 },
289 false,
290 &*tx,
291 )
292 .await?,
293 );
294 }
295
296 self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
297 .await?;
298 }
299 _ => {
300 message_id = channel_message::Entity::find()
301 .filter(channel_message::Column::Nonce.eq(Uuid::from_u128(nonce)))
302 .one(&*tx)
303 .await?
304 .ok_or_else(|| anyhow!("failed to insert message"))?
305 .id;
306 }
307 }
308
309 let mut channel_members = self
310 .get_channel_participants_internal(channel_id, &*tx)
311 .await?;
312 channel_members.retain(|member| !participant_user_ids.contains(member));
313
314 Ok(CreatedChannelMessage {
315 message_id,
316 participant_connection_ids,
317 channel_members,
318 notifications,
319 })
320 })
321 .await
322 }
323
324 pub async fn observe_channel_message(
325 &self,
326 channel_id: ChannelId,
327 user_id: UserId,
328 message_id: MessageId,
329 ) -> Result<()> {
330 self.transaction(|tx| async move {
331 self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
332 .await?;
333 Ok(())
334 })
335 .await
336 }
337
338 async fn observe_channel_message_internal(
339 &self,
340 channel_id: ChannelId,
341 user_id: UserId,
342 message_id: MessageId,
343 tx: &DatabaseTransaction,
344 ) -> Result<()> {
345 observed_channel_messages::Entity::insert(observed_channel_messages::ActiveModel {
346 user_id: ActiveValue::Set(user_id),
347 channel_id: ActiveValue::Set(channel_id),
348 channel_message_id: ActiveValue::Set(message_id),
349 })
350 .on_conflict(
351 OnConflict::columns([
352 observed_channel_messages::Column::ChannelId,
353 observed_channel_messages::Column::UserId,
354 ])
355 .update_column(observed_channel_messages::Column::ChannelMessageId)
356 .action_cond_where(observed_channel_messages::Column::ChannelMessageId.lt(message_id))
357 .to_owned(),
358 )
359 // TODO: Try to upgrade SeaORM so we don't have to do this hack around their bug
360 .exec_without_returning(&*tx)
361 .await?;
362 Ok(())
363 }
364
365 pub async fn unseen_channel_messages(
366 &self,
367 user_id: UserId,
368 channel_ids: &[ChannelId],
369 tx: &DatabaseTransaction,
370 ) -> Result<Vec<proto::UnseenChannelMessage>> {
371 let mut observed_messages_by_channel_id = HashMap::default();
372 let mut rows = observed_channel_messages::Entity::find()
373 .filter(observed_channel_messages::Column::UserId.eq(user_id))
374 .filter(observed_channel_messages::Column::ChannelId.is_in(channel_ids.iter().copied()))
375 .stream(&*tx)
376 .await?;
377
378 while let Some(row) = rows.next().await {
379 let row = row?;
380 observed_messages_by_channel_id.insert(row.channel_id, row);
381 }
382 drop(rows);
383 let mut values = String::new();
384 for id in channel_ids {
385 if !values.is_empty() {
386 values.push_str(", ");
387 }
388 write!(&mut values, "({})", id).unwrap();
389 }
390
391 if values.is_empty() {
392 return Ok(Default::default());
393 }
394
395 let sql = format!(
396 r#"
397 SELECT
398 *
399 FROM (
400 SELECT
401 *,
402 row_number() OVER (
403 PARTITION BY channel_id
404 ORDER BY id DESC
405 ) as row_number
406 FROM channel_messages
407 WHERE
408 channel_id in ({values})
409 ) AS messages
410 WHERE
411 row_number = 1
412 "#,
413 );
414
415 let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
416 let last_messages = channel_message::Model::find_by_statement(stmt)
417 .all(&*tx)
418 .await?;
419
420 let mut changes = Vec::new();
421 for last_message in last_messages {
422 if let Some(observed_message) =
423 observed_messages_by_channel_id.get(&last_message.channel_id)
424 {
425 if observed_message.channel_message_id == last_message.id {
426 continue;
427 }
428 }
429 changes.push(proto::UnseenChannelMessage {
430 channel_id: last_message.channel_id.to_proto(),
431 message_id: last_message.id.to_proto(),
432 });
433 }
434
435 Ok(changes)
436 }
437
438 pub async fn remove_channel_message(
439 &self,
440 channel_id: ChannelId,
441 message_id: MessageId,
442 user_id: UserId,
443 ) -> Result<Vec<ConnectionId>> {
444 self.transaction(|tx| async move {
445 let mut rows = channel_chat_participant::Entity::find()
446 .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
447 .stream(&*tx)
448 .await?;
449
450 let mut is_participant = false;
451 let mut participant_connection_ids = Vec::new();
452 while let Some(row) = rows.next().await {
453 let row = row?;
454 if row.user_id == user_id {
455 is_participant = true;
456 }
457 participant_connection_ids.push(row.connection());
458 }
459 drop(rows);
460
461 if !is_participant {
462 Err(anyhow!("not a chat participant"))?;
463 }
464
465 let result = channel_message::Entity::delete_by_id(message_id)
466 .filter(channel_message::Column::SenderId.eq(user_id))
467 .exec(&*tx)
468 .await?;
469
470 if result.rows_affected == 0 {
471 if self
472 .check_user_is_channel_admin(channel_id, user_id, &*tx)
473 .await
474 .is_ok()
475 {
476 let result = channel_message::Entity::delete_by_id(message_id)
477 .exec(&*tx)
478 .await?;
479 if result.rows_affected == 0 {
480 Err(anyhow!("no such message"))?;
481 }
482 } else {
483 Err(anyhow!("operation could not be completed"))?;
484 }
485 }
486
487 Ok(participant_connection_ids)
488 })
489 .await
490 }
491}