1use crate::{Channel, ChannelId, ChannelStore};
2use anyhow::{anyhow, Result};
3use client::{
4 proto,
5 user::{User, UserStore},
6 Client, Subscription, TypedEnvelope, UserId,
7};
8use collections::HashSet;
9use futures::lock::Mutex;
10use gpui::{
11 AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Task, WeakModel,
12};
13use rand::prelude::*;
14use std::{
15 ops::{ControlFlow, Range},
16 sync::Arc,
17};
18use sum_tree::{Bias, SumTree};
19use time::OffsetDateTime;
20use util::{post_inc, ResultExt as _, TryFutureExt};
21
/// Client-side state for one channel's chat: the messages loaded so far (plus any
/// locally pending ones) and the handles needed to talk to the server.
pub struct ChannelChat {
    /// The channel this chat belongs to.
    pub channel_id: ChannelId,
    /// Loaded messages, kept in `ChannelMessageId` order (saved messages first, then
    /// locally pending ones).
    messages: SumTree<ChannelMessage>,
    /// Ids already passed to `acknowledge_message`, so each one is acknowledged only once.
    acknowledged_message_ids: HashSet<u64>,
    /// Used to look up the channel and to record acknowledged message ids.
    channel_store: Model<ChannelStore>,
    /// The server's `done` flag from the most recently loaded page; while set,
    /// `load_more_messages` requests nothing.
    loaded_all_messages: bool,
    /// The newest message id sent by `acknowledge_last_message`.
    last_acknowledged_id: Option<u64>,
    /// Counter used to allocate `ChannelMessageId::Pending` ids.
    next_pending_message_id: usize,
    /// Id of the first message in the most recently loaded page, if any; used as
    /// `before_message_id` when requesting older history.
    first_loaded_message_id: Option<u64>,
    /// Resolves message senders.
    user_store: Model<UserStore>,
    /// Connection used for all chat requests.
    rpc: Arc<Client>,
    /// Held while a `SendChannelMessage` request issued by `send_message` is in flight,
    /// so those requests go out one at a time.
    outgoing_messages_lock: Arc<Mutex<()>>,
    /// Source of nonces for locally sent messages.
    rng: StdRng,
    /// Keeps the server-message subscription for this channel alive for the chat's lifetime.
    _subscription: Subscription,
}
37
/// The user-supplied contents of a message to send with `ChannelChat::send_message`.
#[derive(Debug, PartialEq, Eq)]
pub struct MessageParams {
    /// Message text; `send_message` rejects text that is blank after trimming.
    pub text: String,
    /// Mentioned users, each paired with the range of `text` that mentions them.
    pub mentions: Vec<(Range<usize>, UserId)>,
    /// Id of the message this one replies to, if any.
    pub reply_to_message_id: Option<u64>,
}
44
/// A single chat message, either stored by the server or still pending locally.
#[derive(Clone, Debug)]
pub struct ChannelMessage {
    /// `Saved` once the server has stored the message, `Pending` while it is in flight.
    pub id: ChannelMessageId,
    /// The message text.
    pub body: String,
    /// The server's timestamp for saved messages; the local UTC clock for messages
    /// created by `send_message`.
    pub timestamp: OffsetDateTime,
    /// The user who sent the message.
    pub sender: Arc<User>,
    /// Nonce attached when the message was sent (drawn from the chat's RNG for local
    /// messages). It is used to match a server-confirmed message with its pending copy.
    pub nonce: u128,
    /// Mentioned users, each paired with the range of `body` that mentions them.
    pub mentions: Vec<(Range<usize>, UserId)>,
    /// Id of the message this one replies to, if any.
    pub reply_to_message_id: Option<u64>,
}
55
/// Identifies a message in a `ChannelChat`.
///
/// The derived ordering places every `Saved` id before every `Pending` id. This keeps
/// locally pending messages at the end of an id-ordered message tree.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ChannelMessageId {
    /// A message stored by the server, identified by its server-assigned id.
    Saved(u64),
    /// A locally sent message not yet confirmed, identified by a local counter.
    Pending(usize),
}

/// Extracts the server-assigned id, or `None` for a pending message.
///
/// Implemented as `From` rather than `Into`: the standard blanket impl then provides
/// `Into<Option<u64>>` too, so existing `.into()` call sites keep working.
impl From<ChannelMessageId> for Option<u64> {
    fn from(id: ChannelMessageId) -> Self {
        match id {
            ChannelMessageId::Saved(id) => Some(id),
            ChannelMessageId::Pending(_) => None,
        }
    }
}
70
/// Sum-tree summary of a run of messages.
#[derive(Clone, Debug, Default)]
pub struct ChannelMessageSummary {
    /// Id of the last message in the run (`add_summary` takes the right-hand run's value).
    /// Because the tree is kept ordered by id, this is also the largest id.
    max_id: ChannelMessageId,
    /// Number of messages in the run.
    count: usize,
}
76
/// Sum-tree dimension that counts messages; used to seek by message position.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);
79
/// Events emitted by `ChannelChat`.
#[derive(Clone, Debug, PartialEq)]
pub enum ChannelChatEvent {
    /// The messages previously at indices `old_range` were replaced by `new_count`
    /// messages (`new_count == 0` means they were removed).
    MessagesUpdated {
        old_range: Range<usize>,
        new_count: usize,
    },
    /// A message pushed by the server (`ChannelMessageSent`) was added to the chat.
    NewMessage {
        channel_id: ChannelId,
        message_id: u64,
    },
}
91
// `ChannelChat` emits `ChannelChatEvent`s to its subscribers.
impl EventEmitter<ChannelChatEvent> for ChannelChat {}

/// Registers the handlers for chat messages pushed by the server (new messages and
/// message removals).
pub fn init(client: &Arc<Client>) {
    client.add_model_message_handler(ChannelChat::handle_message_sent);
    client.add_model_message_handler(ChannelChat::handle_message_removed);
}
97
98impl ChannelChat {
99 pub async fn new(
100 channel: Arc<Channel>,
101 channel_store: Model<ChannelStore>,
102 user_store: Model<UserStore>,
103 client: Arc<Client>,
104 mut cx: AsyncAppContext,
105 ) -> Result<Model<Self>> {
106 let channel_id = channel.id;
107 let subscription = client.subscribe_to_entity(channel_id).unwrap();
108
109 let response = client
110 .request(proto::JoinChannelChat { channel_id })
111 .await?;
112
113 let handle = cx.new_model(|cx| {
114 cx.on_release(Self::release).detach();
115 Self {
116 channel_id: channel.id,
117 user_store: user_store.clone(),
118 channel_store,
119 rpc: client.clone(),
120 outgoing_messages_lock: Default::default(),
121 messages: Default::default(),
122 acknowledged_message_ids: Default::default(),
123 loaded_all_messages: false,
124 next_pending_message_id: 0,
125 last_acknowledged_id: None,
126 rng: StdRng::from_entropy(),
127 first_loaded_message_id: None,
128 _subscription: subscription.set_model(&cx.handle(), &mut cx.to_async()),
129 }
130 })?;
131 Self::handle_loaded_messages(
132 handle.downgrade(),
133 user_store,
134 client,
135 response.messages,
136 response.done,
137 &mut cx,
138 )
139 .await?;
140 Ok(handle)
141 }
142
143 fn release(&mut self, _: &mut AppContext) {
144 self.rpc
145 .send(proto::LeaveChannelChat {
146 channel_id: self.channel_id,
147 })
148 .log_err();
149 }
150
151 pub fn channel(&self, cx: &AppContext) -> Option<Arc<Channel>> {
152 self.channel_store
153 .read(cx)
154 .channel_for_id(self.channel_id)
155 .cloned()
156 }
157
158 pub fn client(&self) -> &Arc<Client> {
159 &self.rpc
160 }
161
162 pub fn send_message(
163 &mut self,
164 message: MessageParams,
165 cx: &mut ModelContext<Self>,
166 ) -> Result<Task<Result<u64>>> {
167 if message.text.trim().is_empty() {
168 Err(anyhow!("message body can't be empty"))?;
169 }
170
171 let current_user = self
172 .user_store
173 .read(cx)
174 .current_user()
175 .ok_or_else(|| anyhow!("current_user is not present"))?;
176
177 let channel_id = self.channel_id;
178 let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id));
179 let nonce = self.rng.gen();
180 self.insert_messages(
181 SumTree::from_item(
182 ChannelMessage {
183 id: pending_id,
184 body: message.text.clone(),
185 sender: current_user,
186 timestamp: OffsetDateTime::now_utc(),
187 mentions: message.mentions.clone(),
188 nonce,
189 reply_to_message_id: message.reply_to_message_id,
190 },
191 &(),
192 ),
193 cx,
194 );
195 let user_store = self.user_store.clone();
196 let rpc = self.rpc.clone();
197 let outgoing_messages_lock = self.outgoing_messages_lock.clone();
198
199 // todo - handle messages that fail to send (e.g. >1024 chars)
200 Ok(cx.spawn(move |this, mut cx| async move {
201 let outgoing_message_guard = outgoing_messages_lock.lock().await;
202 let request = rpc.request(proto::SendChannelMessage {
203 channel_id,
204 body: message.text,
205 nonce: Some(nonce.into()),
206 mentions: mentions_to_proto(&message.mentions),
207 reply_to_message_id: message.reply_to_message_id,
208 });
209 let response = request.await?;
210 drop(outgoing_message_guard);
211 let response = response.message.ok_or_else(|| anyhow!("invalid message"))?;
212 let id = response.id;
213 let message = ChannelMessage::from_proto(response, &user_store, &mut cx).await?;
214 this.update(&mut cx, |this, cx| {
215 this.insert_messages(SumTree::from_item(message, &()), cx);
216 })?;
217 Ok(id)
218 }))
219 }
220
221 pub fn remove_message(&mut self, id: u64, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
222 let response = self.rpc.request(proto::RemoveChannelMessage {
223 channel_id: self.channel_id,
224 message_id: id,
225 });
226 cx.spawn(move |this, mut cx| async move {
227 response.await?;
228 this.update(&mut cx, |this, cx| {
229 this.message_removed(id, cx);
230 })?;
231 Ok(())
232 })
233 }
234
235 pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Option<()>>> {
236 if self.loaded_all_messages {
237 return None;
238 }
239
240 let rpc = self.rpc.clone();
241 let user_store = self.user_store.clone();
242 let channel_id = self.channel_id;
243 let before_message_id = self.first_loaded_message_id()?;
244 Some(cx.spawn(move |this, mut cx| {
245 async move {
246 let response = rpc
247 .request(proto::GetChannelMessages {
248 channel_id,
249 before_message_id,
250 })
251 .await?;
252 Self::handle_loaded_messages(
253 this,
254 user_store,
255 rpc,
256 response.messages,
257 response.done,
258 &mut cx,
259 )
260 .await?;
261
262 anyhow::Ok(())
263 }
264 .log_err()
265 }))
266 }
267
268 pub fn first_loaded_message_id(&mut self) -> Option<u64> {
269 self.first_loaded_message_id
270 }
271
272 /// Load a message by its id, if it's already stored locally.
273 pub fn find_loaded_message(&self, id: u64) -> Option<&ChannelMessage> {
274 self.messages.iter().find(|message| match message.id {
275 ChannelMessageId::Saved(message_id) => message_id == id,
276 ChannelMessageId::Pending(_) => false,
277 })
278 }
279
280 /// Load all of the chat messages since a certain message id.
281 ///
282 /// For now, we always maintain a suffix of the channel's messages.
283 pub async fn load_history_since_message(
284 chat: Model<Self>,
285 message_id: u64,
286 mut cx: AsyncAppContext,
287 ) -> Option<usize> {
288 loop {
289 let step = chat
290 .update(&mut cx, |chat, cx| {
291 if let Some(first_id) = chat.first_loaded_message_id() {
292 if first_id <= message_id {
293 let mut cursor = chat.messages.cursor::<(ChannelMessageId, Count)>();
294 let message_id = ChannelMessageId::Saved(message_id);
295 cursor.seek(&message_id, Bias::Left, &());
296 return ControlFlow::Break(
297 if cursor
298 .item()
299 .map_or(false, |message| message.id == message_id)
300 {
301 Some(cursor.start().1 .0)
302 } else {
303 None
304 },
305 );
306 }
307 }
308 ControlFlow::Continue(chat.load_more_messages(cx))
309 })
310 .log_err()?;
311 match step {
312 ControlFlow::Break(ix) => return ix,
313 ControlFlow::Continue(task) => task?.await?,
314 }
315 }
316 }
317
318 pub fn acknowledge_last_message(&mut self, cx: &mut ModelContext<Self>) {
319 if let ChannelMessageId::Saved(latest_message_id) = self.messages.summary().max_id {
320 if self
321 .last_acknowledged_id
322 .map_or(true, |acknowledged_id| acknowledged_id < latest_message_id)
323 {
324 self.rpc
325 .send(proto::AckChannelMessage {
326 channel_id: self.channel_id,
327 message_id: latest_message_id,
328 })
329 .ok();
330 self.last_acknowledged_id = Some(latest_message_id);
331 self.channel_store.update(cx, |store, cx| {
332 store.acknowledge_message_id(self.channel_id, latest_message_id, cx);
333 });
334 }
335 }
336 }
337
338 async fn handle_loaded_messages(
339 this: WeakModel<Self>,
340 user_store: Model<UserStore>,
341 rpc: Arc<Client>,
342 proto_messages: Vec<proto::ChannelMessage>,
343 loaded_all_messages: bool,
344 cx: &mut AsyncAppContext,
345 ) -> Result<()> {
346 let loaded_messages = messages_from_proto(proto_messages, &user_store, cx).await?;
347
348 let first_loaded_message_id = loaded_messages.first().map(|m| m.id);
349 let loaded_message_ids = this.update(cx, |this, _| {
350 let mut loaded_message_ids: HashSet<u64> = HashSet::default();
351 for message in loaded_messages.iter() {
352 if let Some(saved_message_id) = message.id.into() {
353 loaded_message_ids.insert(saved_message_id);
354 }
355 }
356 for message in this.messages.iter() {
357 if let Some(saved_message_id) = message.id.into() {
358 loaded_message_ids.insert(saved_message_id);
359 }
360 }
361 loaded_message_ids
362 })?;
363
364 let missing_ancestors = loaded_messages
365 .iter()
366 .filter_map(|message| {
367 if let Some(ancestor_id) = message.reply_to_message_id {
368 if !loaded_message_ids.contains(&ancestor_id) {
369 return Some(ancestor_id);
370 }
371 }
372 None
373 })
374 .collect::<Vec<_>>();
375
376 let loaded_ancestors = if missing_ancestors.is_empty() {
377 None
378 } else {
379 let response = rpc
380 .request(proto::GetChannelMessagesById {
381 message_ids: missing_ancestors,
382 })
383 .await?;
384 Some(messages_from_proto(response.messages, &user_store, cx).await?)
385 };
386 this.update(cx, |this, cx| {
387 this.first_loaded_message_id = first_loaded_message_id.and_then(|msg_id| msg_id.into());
388 this.loaded_all_messages = loaded_all_messages;
389 this.insert_messages(loaded_messages, cx);
390 if let Some(loaded_ancestors) = loaded_ancestors {
391 this.insert_messages(loaded_ancestors, cx);
392 }
393 })?;
394
395 Ok(())
396 }
397
398 pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
399 let user_store = self.user_store.clone();
400 let rpc = self.rpc.clone();
401 let channel_id = self.channel_id;
402 cx.spawn(move |this, mut cx| {
403 async move {
404 let response = rpc.request(proto::JoinChannelChat { channel_id }).await?;
405 Self::handle_loaded_messages(
406 this.clone(),
407 user_store.clone(),
408 rpc.clone(),
409 response.messages,
410 response.done,
411 &mut cx,
412 )
413 .await?;
414
415 let pending_messages = this.update(&mut cx, |this, _| {
416 this.pending_messages().cloned().collect::<Vec<_>>()
417 })?;
418
419 for pending_message in pending_messages {
420 let request = rpc.request(proto::SendChannelMessage {
421 channel_id,
422 body: pending_message.body,
423 mentions: mentions_to_proto(&pending_message.mentions),
424 nonce: Some(pending_message.nonce.into()),
425 reply_to_message_id: pending_message.reply_to_message_id,
426 });
427 let response = request.await?;
428 let message = ChannelMessage::from_proto(
429 response.message.ok_or_else(|| anyhow!("invalid message"))?,
430 &user_store,
431 &mut cx,
432 )
433 .await?;
434 this.update(&mut cx, |this, cx| {
435 this.insert_messages(SumTree::from_item(message, &()), cx);
436 })?;
437 }
438
439 anyhow::Ok(())
440 }
441 .log_err()
442 })
443 .detach();
444 }
445
446 pub fn message_count(&self) -> usize {
447 self.messages.summary().count
448 }
449
450 pub fn messages(&self) -> &SumTree<ChannelMessage> {
451 &self.messages
452 }
453
454 pub fn message(&self, ix: usize) -> &ChannelMessage {
455 let mut cursor = self.messages.cursor::<Count>();
456 cursor.seek(&Count(ix), Bias::Right, &());
457 cursor.item().unwrap()
458 }
459
460 pub fn acknowledge_message(&mut self, id: u64) {
461 if self.acknowledged_message_ids.insert(id) {
462 self.rpc
463 .send(proto::AckChannelMessage {
464 channel_id: self.channel_id,
465 message_id: id,
466 })
467 .ok();
468 }
469 }
470
471 pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> {
472 let mut cursor = self.messages.cursor::<Count>();
473 cursor.seek(&Count(range.start), Bias::Right, &());
474 cursor.take(range.len())
475 }
476
477 pub fn pending_messages(&self) -> impl Iterator<Item = &ChannelMessage> {
478 let mut cursor = self.messages.cursor::<ChannelMessageId>();
479 cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &());
480 cursor
481 }
482
483 async fn handle_message_sent(
484 this: Model<Self>,
485 message: TypedEnvelope<proto::ChannelMessageSent>,
486 _: Arc<Client>,
487 mut cx: AsyncAppContext,
488 ) -> Result<()> {
489 let user_store = this.update(&mut cx, |this, _| this.user_store.clone())?;
490 let message = message
491 .payload
492 .message
493 .ok_or_else(|| anyhow!("empty message"))?;
494 let message_id = message.id;
495
496 let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?;
497 this.update(&mut cx, |this, cx| {
498 this.insert_messages(SumTree::from_item(message, &()), cx);
499 cx.emit(ChannelChatEvent::NewMessage {
500 channel_id: this.channel_id,
501 message_id,
502 })
503 })?;
504
505 Ok(())
506 }
507
508 async fn handle_message_removed(
509 this: Model<Self>,
510 message: TypedEnvelope<proto::RemoveChannelMessage>,
511 _: Arc<Client>,
512 mut cx: AsyncAppContext,
513 ) -> Result<()> {
514 this.update(&mut cx, |this, cx| {
515 this.message_removed(message.payload.message_id, cx)
516 })?;
517 Ok(())
518 }
519
520 fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
521 if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
522 let nonces = messages
523 .cursor::<()>()
524 .map(|m| m.nonce)
525 .collect::<HashSet<_>>();
526
527 let mut old_cursor = self.messages.cursor::<(ChannelMessageId, Count)>();
528 let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
529 let start_ix = old_cursor.start().1 .0;
530 let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &());
531 let removed_count = removed_messages.summary().count;
532 let new_count = messages.summary().count;
533 let end_ix = start_ix + removed_count;
534
535 new_messages.append(messages, &());
536
537 let mut ranges = Vec::<Range<usize>>::new();
538 if new_messages.last().unwrap().is_pending() {
539 new_messages.append(old_cursor.suffix(&()), &());
540 } else {
541 new_messages.append(
542 old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()),
543 &(),
544 );
545
546 while let Some(message) = old_cursor.item() {
547 let message_ix = old_cursor.start().1 .0;
548 if nonces.contains(&message.nonce) {
549 if ranges.last().map_or(false, |r| r.end == message_ix) {
550 ranges.last_mut().unwrap().end += 1;
551 } else {
552 ranges.push(message_ix..message_ix + 1);
553 }
554 } else {
555 new_messages.push(message.clone(), &());
556 }
557 old_cursor.next(&());
558 }
559 }
560
561 drop(old_cursor);
562 self.messages = new_messages;
563
564 for range in ranges.into_iter().rev() {
565 cx.emit(ChannelChatEvent::MessagesUpdated {
566 old_range: range,
567 new_count: 0,
568 });
569 }
570 cx.emit(ChannelChatEvent::MessagesUpdated {
571 old_range: start_ix..end_ix,
572 new_count,
573 });
574
575 cx.notify();
576 }
577 }
578
579 fn message_removed(&mut self, id: u64, cx: &mut ModelContext<Self>) {
580 let mut cursor = self.messages.cursor::<ChannelMessageId>();
581 let mut messages = cursor.slice(&ChannelMessageId::Saved(id), Bias::Left, &());
582 if let Some(item) = cursor.item() {
583 if item.id == ChannelMessageId::Saved(id) {
584 let ix = messages.summary().count;
585 cursor.next(&());
586 messages.append(cursor.suffix(&()), &());
587 drop(cursor);
588 self.messages = messages;
589 cx.emit(ChannelChatEvent::MessagesUpdated {
590 old_range: ix..ix + 1,
591 new_count: 0,
592 });
593 }
594 }
595 }
596}
597
598async fn messages_from_proto(
599 proto_messages: Vec<proto::ChannelMessage>,
600 user_store: &Model<UserStore>,
601 cx: &mut AsyncAppContext,
602) -> Result<SumTree<ChannelMessage>> {
603 let messages = ChannelMessage::from_proto_vec(proto_messages, user_store, cx).await?;
604 let mut result = SumTree::new();
605 result.extend(messages, &());
606 Ok(result)
607}
608
609impl ChannelMessage {
610 pub async fn from_proto(
611 message: proto::ChannelMessage,
612 user_store: &Model<UserStore>,
613 cx: &mut AsyncAppContext,
614 ) -> Result<Self> {
615 let sender = user_store
616 .update(cx, |user_store, cx| {
617 user_store.get_user(message.sender_id, cx)
618 })?
619 .await?;
620 Ok(ChannelMessage {
621 id: ChannelMessageId::Saved(message.id),
622 body: message.body,
623 mentions: message
624 .mentions
625 .into_iter()
626 .filter_map(|mention| {
627 let range = mention.range?;
628 Some((range.start as usize..range.end as usize, mention.user_id))
629 })
630 .collect(),
631 timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?,
632 sender,
633 nonce: message
634 .nonce
635 .ok_or_else(|| anyhow!("nonce is required"))?
636 .into(),
637 reply_to_message_id: message.reply_to_message_id,
638 })
639 }
640
641 pub fn is_pending(&self) -> bool {
642 matches!(self.id, ChannelMessageId::Pending(_))
643 }
644
645 pub async fn from_proto_vec(
646 proto_messages: Vec<proto::ChannelMessage>,
647 user_store: &Model<UserStore>,
648 cx: &mut AsyncAppContext,
649 ) -> Result<Vec<Self>> {
650 let unique_user_ids = proto_messages
651 .iter()
652 .map(|m| m.sender_id)
653 .collect::<HashSet<_>>()
654 .into_iter()
655 .collect();
656 user_store
657 .update(cx, |user_store, cx| {
658 user_store.get_users(unique_user_ids, cx)
659 })?
660 .await?;
661
662 let mut messages = Vec::with_capacity(proto_messages.len());
663 for message in proto_messages {
664 messages.push(ChannelMessage::from_proto(message, user_store, cx).await?);
665 }
666 Ok(messages)
667 }
668}
669
670pub fn mentions_to_proto(mentions: &[(Range<usize>, UserId)]) -> Vec<proto::ChatMention> {
671 mentions
672 .iter()
673 .map(|(range, user_id)| proto::ChatMention {
674 range: Some(proto::Range {
675 start: range.start as u64,
676 end: range.end as u64,
677 }),
678 user_id: *user_id as u64,
679 })
680 .collect()
681}
682
683impl sum_tree::Item for ChannelMessage {
684 type Summary = ChannelMessageSummary;
685
686 fn summary(&self) -> Self::Summary {
687 ChannelMessageSummary {
688 max_id: self.id,
689 count: 1,
690 }
691 }
692}
693
694impl Default for ChannelMessageId {
695 fn default() -> Self {
696 Self::Saved(0)
697 }
698}
699
700impl sum_tree::Summary for ChannelMessageSummary {
701 type Context = ();
702
703 fn add_summary(&mut self, summary: &Self, _: &()) {
704 self.max_id = summary.max_id;
705 self.count += summary.count;
706 }
707}
708
/// Lets cursors seek by `ChannelMessageId`: the position advances to the largest id
/// seen so far.
impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for ChannelMessageId {
    fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
        // The tree is expected to be ordered by id; in debug builds, check that each run
        // moves the position strictly forward.
        debug_assert!(summary.max_id > *self);
        *self = summary.max_id;
    }
}
715
716impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count {
717 fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
718 self.0 += summary.count;
719 }
720}
721
722impl<'a> From<&'a str> for MessageParams {
723 fn from(value: &'a str) -> Self {
724 Self {
725 text: value.into(),
726 mentions: Vec::new(),
727 reply_to_message_id: None,
728 }
729 }
730}