1use crate::{Channel, ChannelId, ChannelStore};
2use anyhow::{anyhow, Result};
3use client::{
4 proto,
5 user::{User, UserStore},
6 Client, Subscription, TypedEnvelope,
7};
8use futures::lock::Mutex;
9use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
10use rand::prelude::*;
11use std::{collections::HashSet, mem, ops::Range, sync::Arc};
12use sum_tree::{Bias, SumTree};
13use time::OffsetDateTime;
14use util::{post_inc, ResultExt as _, TryFutureExt};
15
/// State for one channel's chat: the messages loaded so far plus the handles
/// used to exchange chat messages with the server.
pub struct ChannelChat {
    /// The channel whose chat this is; its id is sent with every chat request.
    channel: Arc<Channel>,
    /// Messages currently held, both server-saved and locally pending.
    messages: SumTree<ChannelMessage>,
    /// Store that is informed of the acknowledged message id in
    /// `acknowledge_last_message`.
    channel_store: ModelHandle<ChannelStore>,
    /// Copied from the server's `done` flag; while true, `load_more_messages`
    /// does nothing.
    loaded_all_messages: bool,
    /// Id of the most recent message acknowledged to the server, if any.
    last_acknowledged_id: Option<u64>,
    /// Counter used to mint the ids of `ChannelMessageId::Pending` messages.
    next_pending_message_id: usize,
    /// Used to resolve message senders and the current user.
    user_store: ModelHandle<UserStore>,
    /// RPC client used for all chat requests.
    rpc: Arc<Client>,
    /// Held while a `SendChannelMessage` request issued by `send_message` is
    /// in flight, so those requests are awaited one at a time.
    outgoing_messages_lock: Arc<Mutex<()>>,
    /// Source of the random nonces attached to outgoing messages.
    rng: StdRng,
    /// Entity subscription created in `new`. Presumably keeps server messages
    /// routed to this model for as long as it lives — TODO confirm.
    _subscription: Subscription,
}
29
/// A single chat message, either saved on the server or still pending locally.
#[derive(Clone, Debug)]
pub struct ChannelMessage {
    /// Server-assigned id for saved messages, or a local id for pending ones.
    pub id: ChannelMessageId,
    /// The message text.
    pub body: String,
    /// When the message was sent (local clock for pending messages, the
    /// server-provided unix timestamp for saved ones).
    pub timestamp: OffsetDateTime,
    /// The user who sent the message.
    pub sender: Arc<User>,
    /// Random value generated when the message is sent; `insert_messages`
    /// uses it to match a pending message with its saved counterpart.
    pub nonce: u128,
}
38
/// Identifier of a message in a chat.
///
/// The derived ordering follows declaration order, so every `Saved` id sorts
/// before every `Pending` id; the message tree relies on this to keep pending
/// messages at the end.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ChannelMessageId {
    /// Id assigned by the server to a stored message.
    Saved(u64),
    /// Locally generated id for a message that has not been confirmed yet.
    Pending(usize),
}
44
/// `SumTree` summary for a run of [`ChannelMessage`]s.
#[derive(Clone, Debug, Default)]
pub struct ChannelMessageSummary {
    /// Id of the last message covered by this summary.
    max_id: ChannelMessageId,
    /// Number of messages covered by this summary.
    count: usize,
}
50
/// Tree dimension that counts messages, used to seek to a message by index.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);
53
/// Events emitted by [`ChannelChat`].
#[derive(Clone, Debug, PartialEq)]
pub enum ChannelChatEvent {
    /// A span of the message list was replaced.
    MessagesUpdated {
        /// Index range, in the list as it was before the change, of the
        /// messages that were replaced or removed.
        old_range: Range<usize>,
        /// Number of messages inserted in place of `old_range` (zero for a
        /// pure removal).
        new_count: usize,
    },
    /// A message sent notification arrived from the server.
    NewMessage {
        /// Channel the message was posted to.
        channel_id: ChannelId,
        /// Server id of the new message.
        message_id: u64,
    },
}
65
/// Registers the chat message handlers (`handle_message_sent` and
/// `handle_message_removed`) on `client`.
pub fn init(client: &Arc<Client>) {
    client.add_model_message_handler(ChannelChat::handle_message_sent);
    client.add_model_message_handler(ChannelChat::handle_message_removed);
}
70
impl Entity for ChannelChat {
    type Event = ChannelChatEvent;

    /// Tells the server we are leaving this channel's chat when the model is
    /// released. A failure to send is only logged.
    fn release(&mut self, _: &mut AppContext) {
        self.rpc
            .send(proto::LeaveChannelChat {
                channel_id: self.channel.id,
            })
            .log_err();
    }
}
82
83impl ChannelChat {
84 pub async fn new(
85 channel: Arc<Channel>,
86 channel_store: ModelHandle<ChannelStore>,
87 user_store: ModelHandle<UserStore>,
88 client: Arc<Client>,
89 mut cx: AsyncAppContext,
90 ) -> Result<ModelHandle<Self>> {
91 let channel_id = channel.id;
92 let subscription = client.subscribe_to_entity(channel_id).unwrap();
93
94 let response = client
95 .request(proto::JoinChannelChat { channel_id })
96 .await?;
97 let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
98 let loaded_all_messages = response.done;
99
100 Ok(cx.add_model(|cx| {
101 let mut this = Self {
102 channel,
103 user_store,
104 channel_store,
105 rpc: client,
106 outgoing_messages_lock: Default::default(),
107 messages: Default::default(),
108 loaded_all_messages,
109 next_pending_message_id: 0,
110 last_acknowledged_id: None,
111 rng: StdRng::from_entropy(),
112 _subscription: subscription.set_model(&cx.handle(), &mut cx.to_async()),
113 };
114 this.insert_messages(messages, cx);
115 this
116 }))
117 }
118
    /// Returns the channel this chat belongs to.
    pub fn channel(&self) -> &Arc<Channel> {
        &self.channel
    }
122
123 pub fn send_message(
124 &mut self,
125 body: String,
126 cx: &mut ModelContext<Self>,
127 ) -> Result<Task<Result<()>>> {
128 if body.is_empty() {
129 Err(anyhow!("message body can't be empty"))?;
130 }
131
132 let current_user = self
133 .user_store
134 .read(cx)
135 .current_user()
136 .ok_or_else(|| anyhow!("current_user is not present"))?;
137
138 let channel_id = self.channel.id;
139 let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id));
140 let nonce = self.rng.gen();
141 self.insert_messages(
142 SumTree::from_item(
143 ChannelMessage {
144 id: pending_id,
145 body: body.clone(),
146 sender: current_user,
147 timestamp: OffsetDateTime::now_utc(),
148 nonce,
149 },
150 &(),
151 ),
152 cx,
153 );
154 let user_store = self.user_store.clone();
155 let rpc = self.rpc.clone();
156 let outgoing_messages_lock = self.outgoing_messages_lock.clone();
157 Ok(cx.spawn(|this, mut cx| async move {
158 let outgoing_message_guard = outgoing_messages_lock.lock().await;
159 let request = rpc.request(proto::SendChannelMessage {
160 channel_id,
161 body,
162 nonce: Some(nonce.into()),
163 });
164 let response = request.await?;
165 drop(outgoing_message_guard);
166 let message = ChannelMessage::from_proto(
167 response.message.ok_or_else(|| anyhow!("invalid message"))?,
168 &user_store,
169 &mut cx,
170 )
171 .await?;
172 this.update(&mut cx, |this, cx| {
173 this.insert_messages(SumTree::from_item(message, &()), cx);
174 Ok(())
175 })
176 }))
177 }
178
179 pub fn remove_message(&mut self, id: u64, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
180 let response = self.rpc.request(proto::RemoveChannelMessage {
181 channel_id: self.channel.id,
182 message_id: id,
183 });
184 cx.spawn(|this, mut cx| async move {
185 response.await?;
186
187 this.update(&mut cx, |this, cx| {
188 this.message_removed(id, cx);
189 Ok(())
190 })
191 })
192 }
193
194 pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> bool {
195 if !self.loaded_all_messages {
196 let rpc = self.rpc.clone();
197 let user_store = self.user_store.clone();
198 let channel_id = self.channel.id;
199 if let Some(before_message_id) =
200 self.messages.first().and_then(|message| match message.id {
201 ChannelMessageId::Saved(id) => Some(id),
202 ChannelMessageId::Pending(_) => None,
203 })
204 {
205 cx.spawn(|this, mut cx| {
206 async move {
207 let response = rpc
208 .request(proto::GetChannelMessages {
209 channel_id,
210 before_message_id,
211 })
212 .await?;
213 let loaded_all_messages = response.done;
214 let messages =
215 messages_from_proto(response.messages, &user_store, &mut cx).await?;
216 this.update(&mut cx, |this, cx| {
217 this.loaded_all_messages = loaded_all_messages;
218 this.insert_messages(messages, cx);
219 });
220 anyhow::Ok(())
221 }
222 .log_err()
223 })
224 .detach();
225 return true;
226 }
227 }
228 false
229 }
230
231 pub fn acknowledge_last_message(&mut self, cx: &mut ModelContext<Self>) {
232 if let ChannelMessageId::Saved(latest_message_id) = self.messages.summary().max_id {
233 if self
234 .last_acknowledged_id
235 .map_or(true, |acknowledged_id| acknowledged_id < latest_message_id)
236 {
237 self.rpc
238 .send(proto::AckChannelMessage {
239 channel_id: self.channel.id,
240 message_id: latest_message_id,
241 })
242 .ok();
243 self.last_acknowledged_id = Some(latest_message_id);
244 self.channel_store.update(cx, |store, cx| {
245 store.acknowledge_message_id(self.channel.id, latest_message_id, cx);
246 });
247 }
248 }
249 }
250
    /// Re-joins the channel chat: fetches the latest messages again, merges
    /// them into the list, and then re-sends every message that is still
    /// pending.
    ///
    /// Runs in a detached task; any error is logged and stops the remaining
    /// work (including the re-sending of later pending messages).
    pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
        let user_store = self.user_store.clone();
        let rpc = self.rpc.clone();
        let channel_id = self.channel.id;
        cx.spawn(|this, mut cx| {
            async move {
                let response = rpc.request(proto::JoinChannelChat { channel_id }).await?;
                let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
                let loaded_all_messages = response.done;

                let pending_messages = this.update(&mut cx, |this, cx| {
                    if let Some((first_new_message, last_old_message)) =
                        messages.first().zip(this.messages.last())
                    {
                        // Every fetched message is newer than everything we
                        // hold, so discard the old list and report it as
                        // removed. Presumably this avoids keeping a list with
                        // a gap in it — TODO confirm.
                        if first_new_message.id > last_old_message.id {
                            let old_messages = mem::take(&mut this.messages);
                            cx.emit(ChannelChatEvent::MessagesUpdated {
                                old_range: 0..old_messages.summary().count,
                                new_count: 0,
                            });
                            this.loaded_all_messages = loaded_all_messages;
                        }
                    }

                    this.insert_messages(messages, cx);
                    // Only ever upgrade the flag here; a `false` response does
                    // not reset it unless the list was discarded above.
                    if loaded_all_messages {
                        this.loaded_all_messages = loaded_all_messages;
                    }

                    // Snapshot the pending messages so they can be re-sent
                    // outside of the model update.
                    this.pending_messages().cloned().collect::<Vec<_>>()
                });

                // Re-send pending messages one after another, reusing their
                // original nonces so `insert_messages` can replace the pending
                // entries with the saved ones.
                for pending_message in pending_messages {
                    let request = rpc.request(proto::SendChannelMessage {
                        channel_id,
                        body: pending_message.body,
                        nonce: Some(pending_message.nonce.into()),
                    });
                    let response = request.await?;
                    let message = ChannelMessage::from_proto(
                        response.message.ok_or_else(|| anyhow!("invalid message"))?,
                        &user_store,
                        &mut cx,
                    )
                    .await?;
                    this.update(&mut cx, |this, cx| {
                        this.insert_messages(SumTree::from_item(message, &()), cx);
                    });
                }

                anyhow::Ok(())
            }
            .log_err()
        })
        .detach();
    }
307
    /// Returns the number of messages currently held (saved and pending).
    pub fn message_count(&self) -> usize {
        self.messages.summary().count
    }
311
    /// Returns the underlying tree of messages.
    pub fn messages(&self) -> &SumTree<ChannelMessage> {
        &self.messages
    }
315
    /// Returns the message at index `ix`.
    ///
    /// # Panics
    ///
    /// Panics if the cursor finds no message at `ix` (for example when `ix`
    /// is past the end of the list).
    pub fn message(&self, ix: usize) -> &ChannelMessage {
        let mut cursor = self.messages.cursor::<Count>();
        cursor.seek(&Count(ix), Bias::Right, &());
        cursor.item().unwrap()
    }
321
    /// Iterates over the messages starting at index `range.start`, yielding
    /// at most `range.len()` of them.
    pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> {
        let mut cursor = self.messages.cursor::<Count>();
        cursor.seek(&Count(range.start), Bias::Right, &());
        cursor.take(range.len())
    }
327
    /// Iterates over the messages from the position of
    /// `ChannelMessageId::Pending(0)` to the end of the list. Because saved
    /// ids sort before pending ids, these are expected to be exactly the
    /// pending messages.
    pub fn pending_messages(&self) -> impl Iterator<Item = &ChannelMessage> {
        let mut cursor = self.messages.cursor::<ChannelMessageId>();
        cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &());
        cursor
    }
333
334 async fn handle_message_sent(
335 this: ModelHandle<Self>,
336 message: TypedEnvelope<proto::ChannelMessageSent>,
337 _: Arc<Client>,
338 mut cx: AsyncAppContext,
339 ) -> Result<()> {
340 let user_store = this.read_with(&cx, |this, _| this.user_store.clone());
341 let message = message
342 .payload
343 .message
344 .ok_or_else(|| anyhow!("empty message"))?;
345 let message_id = message.id;
346
347 let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?;
348 this.update(&mut cx, |this, cx| {
349 this.insert_messages(SumTree::from_item(message, &()), cx);
350 cx.emit(ChannelChatEvent::NewMessage {
351 channel_id: this.channel.id,
352 message_id,
353 })
354 });
355
356 Ok(())
357 }
358
359 async fn handle_message_removed(
360 this: ModelHandle<Self>,
361 message: TypedEnvelope<proto::RemoveChannelMessage>,
362 _: Arc<Client>,
363 mut cx: AsyncAppContext,
364 ) -> Result<()> {
365 this.update(&mut cx, |this, cx| {
366 this.message_removed(message.payload.message_id, cx)
367 });
368 Ok(())
369 }
370
    /// Merges `messages` into the held list and emits the matching
    /// [`ChannelChatEvent::MessagesUpdated`] events.
    ///
    /// Held messages whose ids fall within the id span of `messages` are
    /// replaced by the incoming ones. When the incoming batch ends with a
    /// saved message, held pending messages whose nonce matches an incoming
    /// message are dropped as well, since the incoming message supersedes
    /// them. Does nothing when `messages` is empty.
    fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
        if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
            // Nonces of the incoming messages, used below to recognize pending
            // messages that the incoming ones replace.
            let nonces = messages
                .cursor::<()>()
                .map(|m| m.nonce)
                .collect::<HashSet<_>>();

            let mut old_cursor = self.messages.cursor::<(ChannelMessageId, Count)>();
            // Keep the held messages that precede the first incoming id.
            let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
            // Index in the old list at which the incoming messages start.
            let start_ix = old_cursor.start().1 .0;
            // Held messages up to the last incoming id are superseded.
            let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &());
            let removed_count = removed_messages.summary().count;
            let new_count = messages.summary().count;
            let end_ix = start_ix + removed_count;

            new_messages.append(messages, &());

            // Old-list index ranges of pending messages dropped because an
            // incoming message carries the same nonce.
            let mut ranges = Vec::<Range<usize>>::new();
            if new_messages.last().unwrap().is_pending() {
                // The batch ended with a pending message: keep everything
                // that followed in the old list unchanged.
                new_messages.append(old_cursor.suffix(&()), &());
            } else {
                // Keep the remaining saved messages up to the first pending
                // one.
                new_messages.append(
                    old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()),
                    &(),
                );

                // Walk the remaining (pending) messages, dropping those that
                // an incoming message replaces.
                while let Some(message) = old_cursor.item() {
                    let message_ix = old_cursor.start().1 .0;
                    if nonces.contains(&message.nonce) {
                        // Extend the previous range when this index is
                        // adjacent to it; otherwise start a new range.
                        if ranges.last().map_or(false, |r| r.end == message_ix) {
                            ranges.last_mut().unwrap().end += 1;
                        } else {
                            ranges.push(message_ix..message_ix + 1);
                        }
                    } else {
                        new_messages.push(message.clone(), &());
                    }
                    old_cursor.next(&());
                }
            }

            // The cursor borrows `self.messages`; release it before replacing
            // the tree.
            drop(old_cursor);
            self.messages = new_messages;

            // Report the dropped pending ranges from the back first, so that
            // earlier indices remain valid for listeners that apply the
            // events in order.
            for range in ranges.into_iter().rev() {
                cx.emit(ChannelChatEvent::MessagesUpdated {
                    old_range: range,
                    new_count: 0,
                });
            }
            cx.emit(ChannelChatEvent::MessagesUpdated {
                old_range: start_ix..end_ix,
                new_count,
            });

            cx.notify();
        }
    }
429
430 fn message_removed(&mut self, id: u64, cx: &mut ModelContext<Self>) {
431 let mut cursor = self.messages.cursor::<ChannelMessageId>();
432 let mut messages = cursor.slice(&ChannelMessageId::Saved(id), Bias::Left, &());
433 if let Some(item) = cursor.item() {
434 if item.id == ChannelMessageId::Saved(id) {
435 let ix = messages.summary().count;
436 cursor.next(&());
437 messages.append(cursor.suffix(&()), &());
438 drop(cursor);
439 self.messages = messages;
440 cx.emit(ChannelChatEvent::MessagesUpdated {
441 old_range: ix..ix + 1,
442 new_count: 0,
443 });
444 }
445 }
446 }
447}
448
449async fn messages_from_proto(
450 proto_messages: Vec<proto::ChannelMessage>,
451 user_store: &ModelHandle<UserStore>,
452 cx: &mut AsyncAppContext,
453) -> Result<SumTree<ChannelMessage>> {
454 let unique_user_ids = proto_messages
455 .iter()
456 .map(|m| m.sender_id)
457 .collect::<HashSet<_>>()
458 .into_iter()
459 .collect();
460 user_store
461 .update(cx, |user_store, cx| {
462 user_store.get_users(unique_user_ids, cx)
463 })
464 .await?;
465
466 let mut messages = Vec::with_capacity(proto_messages.len());
467 for message in proto_messages {
468 messages.push(ChannelMessage::from_proto(message, user_store, cx).await?);
469 }
470 let mut result = SumTree::new();
471 result.extend(messages, &());
472 Ok(result)
473}
474
475impl ChannelMessage {
476 pub async fn from_proto(
477 message: proto::ChannelMessage,
478 user_store: &ModelHandle<UserStore>,
479 cx: &mut AsyncAppContext,
480 ) -> Result<Self> {
481 let sender = user_store
482 .update(cx, |user_store, cx| {
483 user_store.get_user(message.sender_id, cx)
484 })
485 .await?;
486 Ok(ChannelMessage {
487 id: ChannelMessageId::Saved(message.id),
488 body: message.body,
489 timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?,
490 sender,
491 nonce: message
492 .nonce
493 .ok_or_else(|| anyhow!("nonce is required"))?
494 .into(),
495 })
496 }
497
498 pub fn is_pending(&self) -> bool {
499 matches!(self.id, ChannelMessageId::Pending(_))
500 }
501}
502
impl sum_tree::Item for ChannelMessage {
    type Summary = ChannelMessageSummary;

    /// Summarizes a single message: its own id as the maximum, and a count
    /// of one.
    fn summary(&self) -> Self::Summary {
        ChannelMessageSummary {
            max_id: self.id,
            count: 1,
        }
    }
}
513
impl Default for ChannelMessageId {
    /// The default id is `Saved(0)`, the smallest value in the derived
    /// ordering.
    fn default() -> Self {
        Self::Saved(0)
    }
}
519
impl sum_tree::Summary for ChannelMessageSummary {
    type Context = ();

    /// Folds `summary` (covering the items that follow) into `self`.
    ///
    /// `max_id` is simply overwritten rather than compared, which is only
    /// correct while messages are stored in ascending id order.
    fn add_summary(&mut self, summary: &Self, _: &()) {
        self.max_id = summary.max_id;
        self.count += summary.count;
    }
}
528
impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for ChannelMessageId {
    /// Advances this dimension to the maximum id of `summary`.
    ///
    /// In debug builds, asserts that ids are strictly increasing along the
    /// tree.
    fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
        debug_assert!(summary.max_id > *self);
        *self = summary.max_id;
    }
}
535
impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count {
    /// Advances this dimension by the number of messages in `summary`.
    fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
        self.0 += summary.count;
    }
}