1use crate::{Channel, ChannelStore};
2use anyhow::{anyhow, Result};
3use client::{
4 proto,
5 user::{User, UserStore},
6 ChannelId, Client, Subscription, TypedEnvelope, UserId,
7};
8use collections::HashSet;
9use futures::lock::Mutex;
10use gpui::{
11 AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Task, WeakModel,
12};
13use rand::prelude::*;
14use std::{
15 ops::{ControlFlow, Range},
16 sync::Arc,
17};
18use sum_tree::{Bias, SumTree};
19use time::OffsetDateTime;
20use util::{post_inc, ResultExt as _, TryFutureExt};
21
/// Client-side state for one channel's chat: the locally loaded messages
/// plus the bookkeeping used to page in history, send messages, and
/// acknowledge what has been read.
pub struct ChannelChat {
    /// The channel this chat belongs to.
    pub channel_id: ChannelId,
    /// Loaded messages ordered by [`ChannelMessageId`]: saved messages by
    /// server id, followed by locally pending (unconfirmed) messages.
    messages: SumTree<ChannelMessage>,
    /// Ids already acknowledged through `acknowledge_message`, so each id is
    /// sent to the server at most once.
    acknowledged_message_ids: HashSet<u64>,
    /// Store used to look up this chat's channel and record acknowledgements.
    channel_store: Model<ChannelStore>,
    /// Copied from the server's `done` flag on the most recent page load.
    loaded_all_messages: bool,
    /// Highest message id acknowledged through `acknowledge_last_message`.
    last_acknowledged_id: Option<u64>,
    /// Counter used to assign ids to locally pending messages.
    next_pending_message_id: usize,
    /// Id of the first message in the most recently loaded page; used as the
    /// `before_message_id` when requesting older messages.
    first_loaded_message_id: Option<u64>,
    /// Store used to resolve message senders.
    user_store: Model<UserStore>,
    /// RPC client used for every request this chat makes.
    rpc: Arc<Client>,
    /// Held while a `SendChannelMessage` request from `send_message` is in
    /// flight, so those requests are issued one at a time.
    outgoing_messages_lock: Arc<Mutex<()>>,
    /// Source of the random nonces attached to outgoing messages.
    rng: StdRng,
    /// Entity subscription created in `new`. It is stored but never read;
    /// presumably it is kept so server messages keep reaching this model.
    _subscription: Subscription,
}
37
38#[derive(Debug, PartialEq, Eq)]
39pub struct MessageParams {
40 pub text: String,
41 pub mentions: Vec<(Range<usize>, UserId)>,
42 pub reply_to_message_id: Option<u64>,
43}
44
45#[derive(Clone, Debug)]
46pub struct ChannelMessage {
47 pub id: ChannelMessageId,
48 pub body: String,
49 pub timestamp: OffsetDateTime,
50 pub sender: Arc<User>,
51 pub nonce: u128,
52 pub mentions: Vec<(Range<usize>, UserId)>,
53 pub reply_to_message_id: Option<u64>,
54}
55
/// Identifies a message in a chat.
///
/// Variant order is significant: the derived `Ord` sorts every `Saved` id
/// before every `Pending` id, which keeps pending messages at the end of the
/// message tree.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum ChannelMessageId {
    /// A message the server has stored, identified by its server id.
    Saved(u64),
    /// A locally sent message awaiting confirmation, identified by a local counter.
    Pending(usize),
}
61
62impl Into<Option<u64>> for ChannelMessageId {
63 fn into(self) -> Option<u64> {
64 match self {
65 ChannelMessageId::Saved(id) => Some(id),
66 ChannelMessageId::Pending(_) => None,
67 }
68 }
69}
70
71#[derive(Clone, Debug, Default)]
72pub struct ChannelMessageSummary {
73 max_id: ChannelMessageId,
74 count: usize,
75}
76
/// Sum-tree dimension that counts messages, i.e. a message index.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);
79
80#[derive(Clone, Debug, PartialEq)]
81pub enum ChannelChatEvent {
82 MessagesUpdated {
83 old_range: Range<usize>,
84 new_count: usize,
85 },
86 NewMessage {
87 channel_id: ChannelId,
88 message_id: u64,
89 },
90}
91
// `ChannelChat` models emit `ChannelChatEvent`s.
impl EventEmitter<ChannelChatEvent> for ChannelChat {}

/// Registers `ChannelChat`'s handlers for server-sent `ChannelMessageSent`
/// and `RemoveChannelMessage` messages with `client`.
pub fn init(client: &Arc<Client>) {
    client.add_model_message_handler(ChannelChat::handle_message_sent);
    client.add_model_message_handler(ChannelChat::handle_message_removed);
}
97
98impl ChannelChat {
99 pub async fn new(
100 channel: Arc<Channel>,
101 channel_store: Model<ChannelStore>,
102 user_store: Model<UserStore>,
103 client: Arc<Client>,
104 mut cx: AsyncAppContext,
105 ) -> Result<Model<Self>> {
106 let channel_id = channel.id;
107 let subscription = client.subscribe_to_entity(channel_id.0).unwrap();
108
109 let response = client
110 .request(proto::JoinChannelChat {
111 channel_id: channel_id.0,
112 })
113 .await?;
114
115 let handle = cx.new_model(|cx| {
116 cx.on_release(Self::release).detach();
117 Self {
118 channel_id: channel.id,
119 user_store: user_store.clone(),
120 channel_store,
121 rpc: client.clone(),
122 outgoing_messages_lock: Default::default(),
123 messages: Default::default(),
124 acknowledged_message_ids: Default::default(),
125 loaded_all_messages: false,
126 next_pending_message_id: 0,
127 last_acknowledged_id: None,
128 rng: StdRng::from_entropy(),
129 first_loaded_message_id: None,
130 _subscription: subscription.set_model(&cx.handle(), &mut cx.to_async()),
131 }
132 })?;
133 Self::handle_loaded_messages(
134 handle.downgrade(),
135 user_store,
136 client,
137 response.messages,
138 response.done,
139 &mut cx,
140 )
141 .await?;
142 Ok(handle)
143 }
144
145 fn release(&mut self, _: &mut AppContext) {
146 self.rpc
147 .send(proto::LeaveChannelChat {
148 channel_id: self.channel_id.0,
149 })
150 .log_err();
151 }
152
153 pub fn channel(&self, cx: &AppContext) -> Option<Arc<Channel>> {
154 self.channel_store
155 .read(cx)
156 .channel_for_id(self.channel_id)
157 .cloned()
158 }
159
160 pub fn client(&self) -> &Arc<Client> {
161 &self.rpc
162 }
163
164 pub fn send_message(
165 &mut self,
166 message: MessageParams,
167 cx: &mut ModelContext<Self>,
168 ) -> Result<Task<Result<u64>>> {
169 if message.text.trim().is_empty() {
170 Err(anyhow!("message body can't be empty"))?;
171 }
172
173 let current_user = self
174 .user_store
175 .read(cx)
176 .current_user()
177 .ok_or_else(|| anyhow!("current_user is not present"))?;
178
179 let channel_id = self.channel_id;
180 let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id));
181 let nonce = self.rng.gen();
182 self.insert_messages(
183 SumTree::from_item(
184 ChannelMessage {
185 id: pending_id,
186 body: message.text.clone(),
187 sender: current_user,
188 timestamp: OffsetDateTime::now_utc(),
189 mentions: message.mentions.clone(),
190 nonce,
191 reply_to_message_id: message.reply_to_message_id,
192 },
193 &(),
194 ),
195 cx,
196 );
197 let user_store = self.user_store.clone();
198 let rpc = self.rpc.clone();
199 let outgoing_messages_lock = self.outgoing_messages_lock.clone();
200
201 // todo - handle messages that fail to send (e.g. >1024 chars)
202 Ok(cx.spawn(move |this, mut cx| async move {
203 let outgoing_message_guard = outgoing_messages_lock.lock().await;
204 let request = rpc.request(proto::SendChannelMessage {
205 channel_id: channel_id.0,
206 body: message.text,
207 nonce: Some(nonce.into()),
208 mentions: mentions_to_proto(&message.mentions),
209 reply_to_message_id: message.reply_to_message_id,
210 });
211 let response = request.await?;
212 drop(outgoing_message_guard);
213 let response = response.message.ok_or_else(|| anyhow!("invalid message"))?;
214 let id = response.id;
215 let message = ChannelMessage::from_proto(response, &user_store, &mut cx).await?;
216 this.update(&mut cx, |this, cx| {
217 this.insert_messages(SumTree::from_item(message, &()), cx);
218 })?;
219 Ok(id)
220 }))
221 }
222
223 pub fn remove_message(&mut self, id: u64, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
224 let response = self.rpc.request(proto::RemoveChannelMessage {
225 channel_id: self.channel_id.0,
226 message_id: id,
227 });
228 cx.spawn(move |this, mut cx| async move {
229 response.await?;
230 this.update(&mut cx, |this, cx| {
231 this.message_removed(id, cx);
232 })?;
233 Ok(())
234 })
235 }
236
237 pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Option<()>>> {
238 if self.loaded_all_messages {
239 return None;
240 }
241
242 let rpc = self.rpc.clone();
243 let user_store = self.user_store.clone();
244 let channel_id = self.channel_id;
245 let before_message_id = self.first_loaded_message_id()?;
246 Some(cx.spawn(move |this, mut cx| {
247 async move {
248 let response = rpc
249 .request(proto::GetChannelMessages {
250 channel_id: channel_id.0,
251 before_message_id,
252 })
253 .await?;
254 Self::handle_loaded_messages(
255 this,
256 user_store,
257 rpc,
258 response.messages,
259 response.done,
260 &mut cx,
261 )
262 .await?;
263
264 anyhow::Ok(())
265 }
266 .log_err()
267 }))
268 }
269
270 pub fn first_loaded_message_id(&mut self) -> Option<u64> {
271 self.first_loaded_message_id
272 }
273
274 /// Load a message by its id, if it's already stored locally.
275 pub fn find_loaded_message(&self, id: u64) -> Option<&ChannelMessage> {
276 self.messages.iter().find(|message| match message.id {
277 ChannelMessageId::Saved(message_id) => message_id == id,
278 ChannelMessageId::Pending(_) => false,
279 })
280 }
281
282 /// Load all of the chat messages since a certain message id.
283 ///
284 /// For now, we always maintain a suffix of the channel's messages.
285 pub async fn load_history_since_message(
286 chat: Model<Self>,
287 message_id: u64,
288 mut cx: AsyncAppContext,
289 ) -> Option<usize> {
290 loop {
291 let step = chat
292 .update(&mut cx, |chat, cx| {
293 if let Some(first_id) = chat.first_loaded_message_id() {
294 if first_id <= message_id {
295 let mut cursor = chat.messages.cursor::<(ChannelMessageId, Count)>();
296 let message_id = ChannelMessageId::Saved(message_id);
297 cursor.seek(&message_id, Bias::Left, &());
298 return ControlFlow::Break(
299 if cursor
300 .item()
301 .map_or(false, |message| message.id == message_id)
302 {
303 Some(cursor.start().1 .0)
304 } else {
305 None
306 },
307 );
308 }
309 }
310 ControlFlow::Continue(chat.load_more_messages(cx))
311 })
312 .log_err()?;
313 match step {
314 ControlFlow::Break(ix) => return ix,
315 ControlFlow::Continue(task) => task?.await?,
316 }
317 }
318 }
319
320 pub fn acknowledge_last_message(&mut self, cx: &mut ModelContext<Self>) {
321 if let ChannelMessageId::Saved(latest_message_id) = self.messages.summary().max_id {
322 if self
323 .last_acknowledged_id
324 .map_or(true, |acknowledged_id| acknowledged_id < latest_message_id)
325 {
326 self.rpc
327 .send(proto::AckChannelMessage {
328 channel_id: self.channel_id.0,
329 message_id: latest_message_id,
330 })
331 .ok();
332 self.last_acknowledged_id = Some(latest_message_id);
333 self.channel_store.update(cx, |store, cx| {
334 store.acknowledge_message_id(self.channel_id, latest_message_id, cx);
335 });
336 }
337 }
338 }
339
340 async fn handle_loaded_messages(
341 this: WeakModel<Self>,
342 user_store: Model<UserStore>,
343 rpc: Arc<Client>,
344 proto_messages: Vec<proto::ChannelMessage>,
345 loaded_all_messages: bool,
346 cx: &mut AsyncAppContext,
347 ) -> Result<()> {
348 let loaded_messages = messages_from_proto(proto_messages, &user_store, cx).await?;
349
350 let first_loaded_message_id = loaded_messages.first().map(|m| m.id);
351 let loaded_message_ids = this.update(cx, |this, _| {
352 let mut loaded_message_ids: HashSet<u64> = HashSet::default();
353 for message in loaded_messages.iter() {
354 if let Some(saved_message_id) = message.id.into() {
355 loaded_message_ids.insert(saved_message_id);
356 }
357 }
358 for message in this.messages.iter() {
359 if let Some(saved_message_id) = message.id.into() {
360 loaded_message_ids.insert(saved_message_id);
361 }
362 }
363 loaded_message_ids
364 })?;
365
366 let missing_ancestors = loaded_messages
367 .iter()
368 .filter_map(|message| {
369 if let Some(ancestor_id) = message.reply_to_message_id {
370 if !loaded_message_ids.contains(&ancestor_id) {
371 return Some(ancestor_id);
372 }
373 }
374 None
375 })
376 .collect::<Vec<_>>();
377
378 let loaded_ancestors = if missing_ancestors.is_empty() {
379 None
380 } else {
381 let response = rpc
382 .request(proto::GetChannelMessagesById {
383 message_ids: missing_ancestors,
384 })
385 .await?;
386 Some(messages_from_proto(response.messages, &user_store, cx).await?)
387 };
388 this.update(cx, |this, cx| {
389 this.first_loaded_message_id = first_loaded_message_id.and_then(|msg_id| msg_id.into());
390 this.loaded_all_messages = loaded_all_messages;
391 this.insert_messages(loaded_messages, cx);
392 if let Some(loaded_ancestors) = loaded_ancestors {
393 this.insert_messages(loaded_ancestors, cx);
394 }
395 })?;
396
397 Ok(())
398 }
399
400 pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
401 let user_store = self.user_store.clone();
402 let rpc = self.rpc.clone();
403 let channel_id = self.channel_id;
404 cx.spawn(move |this, mut cx| {
405 async move {
406 let response = rpc
407 .request(proto::JoinChannelChat {
408 channel_id: channel_id.0,
409 })
410 .await?;
411 Self::handle_loaded_messages(
412 this.clone(),
413 user_store.clone(),
414 rpc.clone(),
415 response.messages,
416 response.done,
417 &mut cx,
418 )
419 .await?;
420
421 let pending_messages = this.update(&mut cx, |this, _| {
422 this.pending_messages().cloned().collect::<Vec<_>>()
423 })?;
424
425 for pending_message in pending_messages {
426 let request = rpc.request(proto::SendChannelMessage {
427 channel_id: channel_id.0,
428 body: pending_message.body,
429 mentions: mentions_to_proto(&pending_message.mentions),
430 nonce: Some(pending_message.nonce.into()),
431 reply_to_message_id: pending_message.reply_to_message_id,
432 });
433 let response = request.await?;
434 let message = ChannelMessage::from_proto(
435 response.message.ok_or_else(|| anyhow!("invalid message"))?,
436 &user_store,
437 &mut cx,
438 )
439 .await?;
440 this.update(&mut cx, |this, cx| {
441 this.insert_messages(SumTree::from_item(message, &()), cx);
442 })?;
443 }
444
445 anyhow::Ok(())
446 }
447 .log_err()
448 })
449 .detach();
450 }
451
452 pub fn message_count(&self) -> usize {
453 self.messages.summary().count
454 }
455
456 pub fn messages(&self) -> &SumTree<ChannelMessage> {
457 &self.messages
458 }
459
460 pub fn message(&self, ix: usize) -> &ChannelMessage {
461 let mut cursor = self.messages.cursor::<Count>();
462 cursor.seek(&Count(ix), Bias::Right, &());
463 cursor.item().unwrap()
464 }
465
466 pub fn acknowledge_message(&mut self, id: u64) {
467 if self.acknowledged_message_ids.insert(id) {
468 self.rpc
469 .send(proto::AckChannelMessage {
470 channel_id: self.channel_id.0,
471 message_id: id,
472 })
473 .ok();
474 }
475 }
476
477 pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> {
478 let mut cursor = self.messages.cursor::<Count>();
479 cursor.seek(&Count(range.start), Bias::Right, &());
480 cursor.take(range.len())
481 }
482
483 pub fn pending_messages(&self) -> impl Iterator<Item = &ChannelMessage> {
484 let mut cursor = self.messages.cursor::<ChannelMessageId>();
485 cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &());
486 cursor
487 }
488
489 async fn handle_message_sent(
490 this: Model<Self>,
491 message: TypedEnvelope<proto::ChannelMessageSent>,
492 _: Arc<Client>,
493 mut cx: AsyncAppContext,
494 ) -> Result<()> {
495 let user_store = this.update(&mut cx, |this, _| this.user_store.clone())?;
496 let message = message
497 .payload
498 .message
499 .ok_or_else(|| anyhow!("empty message"))?;
500 let message_id = message.id;
501
502 let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?;
503 this.update(&mut cx, |this, cx| {
504 this.insert_messages(SumTree::from_item(message, &()), cx);
505 cx.emit(ChannelChatEvent::NewMessage {
506 channel_id: this.channel_id,
507 message_id,
508 })
509 })?;
510
511 Ok(())
512 }
513
514 async fn handle_message_removed(
515 this: Model<Self>,
516 message: TypedEnvelope<proto::RemoveChannelMessage>,
517 _: Arc<Client>,
518 mut cx: AsyncAppContext,
519 ) -> Result<()> {
520 this.update(&mut cx, |this, cx| {
521 this.message_removed(message.payload.message_id, cx)
522 })?;
523 Ok(())
524 }
525
526 fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
527 if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
528 let nonces = messages
529 .cursor::<()>()
530 .map(|m| m.nonce)
531 .collect::<HashSet<_>>();
532
533 let mut old_cursor = self.messages.cursor::<(ChannelMessageId, Count)>();
534 let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
535 let start_ix = old_cursor.start().1 .0;
536 let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &());
537 let removed_count = removed_messages.summary().count;
538 let new_count = messages.summary().count;
539 let end_ix = start_ix + removed_count;
540
541 new_messages.append(messages, &());
542
543 let mut ranges = Vec::<Range<usize>>::new();
544 if new_messages.last().unwrap().is_pending() {
545 new_messages.append(old_cursor.suffix(&()), &());
546 } else {
547 new_messages.append(
548 old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()),
549 &(),
550 );
551
552 while let Some(message) = old_cursor.item() {
553 let message_ix = old_cursor.start().1 .0;
554 if nonces.contains(&message.nonce) {
555 if ranges.last().map_or(false, |r| r.end == message_ix) {
556 ranges.last_mut().unwrap().end += 1;
557 } else {
558 ranges.push(message_ix..message_ix + 1);
559 }
560 } else {
561 new_messages.push(message.clone(), &());
562 }
563 old_cursor.next(&());
564 }
565 }
566
567 drop(old_cursor);
568 self.messages = new_messages;
569
570 for range in ranges.into_iter().rev() {
571 cx.emit(ChannelChatEvent::MessagesUpdated {
572 old_range: range,
573 new_count: 0,
574 });
575 }
576 cx.emit(ChannelChatEvent::MessagesUpdated {
577 old_range: start_ix..end_ix,
578 new_count,
579 });
580
581 cx.notify();
582 }
583 }
584
585 fn message_removed(&mut self, id: u64, cx: &mut ModelContext<Self>) {
586 let mut cursor = self.messages.cursor::<ChannelMessageId>();
587 let mut messages = cursor.slice(&ChannelMessageId::Saved(id), Bias::Left, &());
588 if let Some(item) = cursor.item() {
589 if item.id == ChannelMessageId::Saved(id) {
590 let ix = messages.summary().count;
591 cursor.next(&());
592 messages.append(cursor.suffix(&()), &());
593 drop(cursor);
594 self.messages = messages;
595 cx.emit(ChannelChatEvent::MessagesUpdated {
596 old_range: ix..ix + 1,
597 new_count: 0,
598 });
599 }
600 }
601 }
602}
603
604async fn messages_from_proto(
605 proto_messages: Vec<proto::ChannelMessage>,
606 user_store: &Model<UserStore>,
607 cx: &mut AsyncAppContext,
608) -> Result<SumTree<ChannelMessage>> {
609 let messages = ChannelMessage::from_proto_vec(proto_messages, user_store, cx).await?;
610 let mut result = SumTree::new();
611 result.extend(messages, &());
612 Ok(result)
613}
614
615impl ChannelMessage {
616 pub async fn from_proto(
617 message: proto::ChannelMessage,
618 user_store: &Model<UserStore>,
619 cx: &mut AsyncAppContext,
620 ) -> Result<Self> {
621 let sender = user_store
622 .update(cx, |user_store, cx| {
623 user_store.get_user(message.sender_id, cx)
624 })?
625 .await?;
626 Ok(ChannelMessage {
627 id: ChannelMessageId::Saved(message.id),
628 body: message.body,
629 mentions: message
630 .mentions
631 .into_iter()
632 .filter_map(|mention| {
633 let range = mention.range?;
634 Some((range.start as usize..range.end as usize, mention.user_id))
635 })
636 .collect(),
637 timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?,
638 sender,
639 nonce: message
640 .nonce
641 .ok_or_else(|| anyhow!("nonce is required"))?
642 .into(),
643 reply_to_message_id: message.reply_to_message_id,
644 })
645 }
646
647 pub fn is_pending(&self) -> bool {
648 matches!(self.id, ChannelMessageId::Pending(_))
649 }
650
651 pub async fn from_proto_vec(
652 proto_messages: Vec<proto::ChannelMessage>,
653 user_store: &Model<UserStore>,
654 cx: &mut AsyncAppContext,
655 ) -> Result<Vec<Self>> {
656 let unique_user_ids = proto_messages
657 .iter()
658 .map(|m| m.sender_id)
659 .collect::<HashSet<_>>()
660 .into_iter()
661 .collect();
662 user_store
663 .update(cx, |user_store, cx| {
664 user_store.get_users(unique_user_ids, cx)
665 })?
666 .await?;
667
668 let mut messages = Vec::with_capacity(proto_messages.len());
669 for message in proto_messages {
670 messages.push(ChannelMessage::from_proto(message, user_store, cx).await?);
671 }
672 Ok(messages)
673 }
674}
675
676pub fn mentions_to_proto(mentions: &[(Range<usize>, UserId)]) -> Vec<proto::ChatMention> {
677 mentions
678 .iter()
679 .map(|(range, user_id)| proto::ChatMention {
680 range: Some(proto::Range {
681 start: range.start as u64,
682 end: range.end as u64,
683 }),
684 user_id: *user_id,
685 })
686 .collect()
687}
688
689impl sum_tree::Item for ChannelMessage {
690 type Summary = ChannelMessageSummary;
691
692 fn summary(&self) -> Self::Summary {
693 ChannelMessageSummary {
694 max_id: self.id,
695 count: 1,
696 }
697 }
698}
699
700impl Default for ChannelMessageId {
701 fn default() -> Self {
702 Self::Saved(0)
703 }
704}
705
706impl sum_tree::Summary for ChannelMessageSummary {
707 type Context = ();
708
709 fn add_summary(&mut self, summary: &Self, _: &()) {
710 self.max_id = summary.max_id;
711 self.count += summary.count;
712 }
713}
714
715impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for ChannelMessageId {
716 fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
717 debug_assert!(summary.max_id > *self);
718 *self = summary.max_id;
719 }
720}
721
722impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count {
723 fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
724 self.0 += summary.count;
725 }
726}
727
728impl<'a> From<&'a str> for MessageParams {
729 fn from(value: &'a str) -> Self {
730 Self {
731 text: value.into(),
732 mentions: Vec::new(),
733 reply_to_message_id: None,
734 }
735 }
736}