1use super::*;
2use time::OffsetDateTime;
3
4impl Database {
5 pub async fn join_channel_chat(
6 &self,
7 channel_id: ChannelId,
8 connection_id: ConnectionId,
9 user_id: UserId,
10 ) -> Result<()> {
11 self.transaction(|tx| async move {
12 self.check_user_is_channel_member(channel_id, user_id, &*tx)
13 .await?;
14 channel_chat_participant::ActiveModel {
15 id: ActiveValue::NotSet,
16 channel_id: ActiveValue::Set(channel_id),
17 user_id: ActiveValue::Set(user_id),
18 connection_id: ActiveValue::Set(connection_id.id as i32),
19 connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
20 }
21 .insert(&*tx)
22 .await?;
23 Ok(())
24 })
25 .await
26 }
27
28 pub async fn channel_chat_connection_lost(
29 &self,
30 connection_id: ConnectionId,
31 tx: &DatabaseTransaction,
32 ) -> Result<()> {
33 channel_chat_participant::Entity::delete_many()
34 .filter(
35 Condition::all()
36 .add(
37 channel_chat_participant::Column::ConnectionServerId
38 .eq(connection_id.owner_id),
39 )
40 .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id)),
41 )
42 .exec(tx)
43 .await?;
44 Ok(())
45 }
46
47 pub async fn leave_channel_chat(
48 &self,
49 channel_id: ChannelId,
50 connection_id: ConnectionId,
51 _user_id: UserId,
52 ) -> Result<()> {
53 self.transaction(|tx| async move {
54 channel_chat_participant::Entity::delete_many()
55 .filter(
56 Condition::all()
57 .add(
58 channel_chat_participant::Column::ConnectionServerId
59 .eq(connection_id.owner_id),
60 )
61 .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id))
62 .add(channel_chat_participant::Column::ChannelId.eq(channel_id)),
63 )
64 .exec(&*tx)
65 .await?;
66
67 Ok(())
68 })
69 .await
70 }
71
72 pub async fn get_channel_messages(
73 &self,
74 channel_id: ChannelId,
75 user_id: UserId,
76 count: usize,
77 before_message_id: Option<MessageId>,
78 ) -> Result<Vec<proto::ChannelMessage>> {
79 self.transaction(|tx| async move {
80 self.check_user_is_channel_member(channel_id, user_id, &*tx)
81 .await?;
82
83 let mut condition =
84 Condition::all().add(channel_message::Column::ChannelId.eq(channel_id));
85
86 if let Some(before_message_id) = before_message_id {
87 condition = condition.add(channel_message::Column::Id.lt(before_message_id));
88 }
89
90 let mut rows = channel_message::Entity::find()
91 .filter(condition)
92 .order_by_desc(channel_message::Column::Id)
93 .limit(count as u64)
94 .stream(&*tx)
95 .await?;
96
97 let mut messages = Vec::new();
98 while let Some(row) = rows.next().await {
99 let row = row?;
100 let nonce = row.nonce.as_u64_pair();
101 messages.push(proto::ChannelMessage {
102 id: row.id.to_proto(),
103 sender_id: row.sender_id.to_proto(),
104 body: row.body,
105 timestamp: row.sent_at.assume_utc().unix_timestamp() as u64,
106 nonce: Some(proto::Nonce {
107 upper_half: nonce.0,
108 lower_half: nonce.1,
109 }),
110 });
111 }
112 drop(rows);
113 messages.reverse();
114 Ok(messages)
115 })
116 .await
117 }
118
119 pub async fn create_channel_message(
120 &self,
121 channel_id: ChannelId,
122 user_id: UserId,
123 body: &str,
124 timestamp: OffsetDateTime,
125 nonce: u128,
126 ) -> Result<(MessageId, Vec<ConnectionId>, Vec<UserId>)> {
127 self.transaction(|tx| async move {
128 let mut rows = channel_chat_participant::Entity::find()
129 .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
130 .stream(&*tx)
131 .await?;
132
133 let mut is_participant = false;
134 let mut participant_connection_ids = Vec::new();
135 let mut participant_user_ids = Vec::new();
136 while let Some(row) = rows.next().await {
137 let row = row?;
138 if row.user_id == user_id {
139 is_participant = true;
140 }
141 participant_user_ids.push(row.user_id);
142 participant_connection_ids.push(row.connection());
143 }
144 drop(rows);
145
146 if !is_participant {
147 Err(anyhow!("not a chat participant"))?;
148 }
149
150 let timestamp = timestamp.to_offset(time::UtcOffset::UTC);
151 let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time());
152
153 let message = channel_message::Entity::insert(channel_message::ActiveModel {
154 channel_id: ActiveValue::Set(channel_id),
155 sender_id: ActiveValue::Set(user_id),
156 body: ActiveValue::Set(body.to_string()),
157 sent_at: ActiveValue::Set(timestamp),
158 nonce: ActiveValue::Set(Uuid::from_u128(nonce)),
159 id: ActiveValue::NotSet,
160 })
161 .on_conflict(
162 OnConflict::column(channel_message::Column::Nonce)
163 .update_column(channel_message::Column::Nonce)
164 .to_owned(),
165 )
166 .exec(&*tx)
167 .await?;
168
169 #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
170 enum QueryConnectionId {
171 ConnectionId,
172 }
173
174 // Observe this message for the sender
175 self.observe_channel_message_internal(
176 channel_id,
177 user_id,
178 message.last_insert_id,
179 &*tx,
180 )
181 .await?;
182
183 let mut channel_members = self
184 .get_channel_participants_internal(channel_id, &*tx)
185 .await?;
186 channel_members.retain(|member| !participant_user_ids.contains(member));
187
188 Ok((
189 message.last_insert_id,
190 participant_connection_ids,
191 channel_members,
192 ))
193 })
194 .await
195 }
196
197 pub async fn observe_channel_message(
198 &self,
199 channel_id: ChannelId,
200 user_id: UserId,
201 message_id: MessageId,
202 ) -> Result<()> {
203 self.transaction(|tx| async move {
204 self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
205 .await?;
206 Ok(())
207 })
208 .await
209 }
210
211 async fn observe_channel_message_internal(
212 &self,
213 channel_id: ChannelId,
214 user_id: UserId,
215 message_id: MessageId,
216 tx: &DatabaseTransaction,
217 ) -> Result<()> {
218 observed_channel_messages::Entity::insert(observed_channel_messages::ActiveModel {
219 user_id: ActiveValue::Set(user_id),
220 channel_id: ActiveValue::Set(channel_id),
221 channel_message_id: ActiveValue::Set(message_id),
222 })
223 .on_conflict(
224 OnConflict::columns([
225 observed_channel_messages::Column::ChannelId,
226 observed_channel_messages::Column::UserId,
227 ])
228 .update_column(observed_channel_messages::Column::ChannelMessageId)
229 .action_cond_where(observed_channel_messages::Column::ChannelMessageId.lt(message_id))
230 .to_owned(),
231 )
232 // TODO: Try to upgrade SeaORM so we don't have to do this hack around their bug
233 .exec_without_returning(&*tx)
234 .await?;
235 Ok(())
236 }
237
238 pub async fn unseen_channel_messages(
239 &self,
240 user_id: UserId,
241 channel_ids: &[ChannelId],
242 tx: &DatabaseTransaction,
243 ) -> Result<Vec<proto::UnseenChannelMessage>> {
244 let mut observed_messages_by_channel_id = HashMap::default();
245 let mut rows = observed_channel_messages::Entity::find()
246 .filter(observed_channel_messages::Column::UserId.eq(user_id))
247 .filter(observed_channel_messages::Column::ChannelId.is_in(channel_ids.iter().copied()))
248 .stream(&*tx)
249 .await?;
250
251 while let Some(row) = rows.next().await {
252 let row = row?;
253 observed_messages_by_channel_id.insert(row.channel_id, row);
254 }
255 drop(rows);
256 let mut values = String::new();
257 for id in channel_ids {
258 if !values.is_empty() {
259 values.push_str(", ");
260 }
261 write!(&mut values, "({})", id).unwrap();
262 }
263
264 if values.is_empty() {
265 return Ok(Default::default());
266 }
267
268 let sql = format!(
269 r#"
270 SELECT
271 *
272 FROM (
273 SELECT
274 *,
275 row_number() OVER (
276 PARTITION BY channel_id
277 ORDER BY id DESC
278 ) as row_number
279 FROM channel_messages
280 WHERE
281 channel_id in ({values})
282 ) AS messages
283 WHERE
284 row_number = 1
285 "#,
286 );
287
288 let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
289 let last_messages = channel_message::Model::find_by_statement(stmt)
290 .all(&*tx)
291 .await?;
292
293 let mut changes = Vec::new();
294 for last_message in last_messages {
295 if let Some(observed_message) =
296 observed_messages_by_channel_id.get(&last_message.channel_id)
297 {
298 if observed_message.channel_message_id == last_message.id {
299 continue;
300 }
301 }
302 changes.push(proto::UnseenChannelMessage {
303 channel_id: last_message.channel_id.to_proto(),
304 message_id: last_message.id.to_proto(),
305 });
306 }
307
308 Ok(changes)
309 }
310
311 pub async fn remove_channel_message(
312 &self,
313 channel_id: ChannelId,
314 message_id: MessageId,
315 user_id: UserId,
316 ) -> Result<Vec<ConnectionId>> {
317 self.transaction(|tx| async move {
318 let mut rows = channel_chat_participant::Entity::find()
319 .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
320 .stream(&*tx)
321 .await?;
322
323 let mut is_participant = false;
324 let mut participant_connection_ids = Vec::new();
325 while let Some(row) = rows.next().await {
326 let row = row?;
327 if row.user_id == user_id {
328 is_participant = true;
329 }
330 participant_connection_ids.push(row.connection());
331 }
332 drop(rows);
333
334 if !is_participant {
335 Err(anyhow!("not a chat participant"))?;
336 }
337
338 let result = channel_message::Entity::delete_by_id(message_id)
339 .filter(channel_message::Column::SenderId.eq(user_id))
340 .exec(&*tx)
341 .await?;
342
343 if result.rows_affected == 0 {
344 if self
345 .check_user_is_channel_admin(channel_id, user_id, &*tx)
346 .await
347 .is_ok()
348 {
349 let result = channel_message::Entity::delete_by_id(message_id)
350 .exec(&*tx)
351 .await?;
352 if result.rows_affected == 0 {
353 Err(anyhow!("no such message"))?;
354 }
355 } else {
356 Err(anyhow!("operation could not be completed"))?;
357 }
358 }
359
360 Ok(participant_connection_ids)
361 })
362 .await
363 }
364}