// channel_chat.rs

  1use crate::{Channel, ChannelId, ChannelStore};
  2use anyhow::{anyhow, Result};
  3use client::{
  4    proto,
  5    user::{User, UserStore},
  6    Client, Subscription, TypedEnvelope, UserId,
  7};
  8use collections::HashSet;
  9use futures::lock::Mutex;
 10use gpui::{
 11    AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Task, WeakModel,
 12};
 13use rand::prelude::*;
 14use std::{
 15    ops::{ControlFlow, Range},
 16    sync::Arc,
 17};
 18use sum_tree::{Bias, SumTree};
 19use time::OffsetDateTime;
 20use util::{post_inc, ResultExt as _, TryFutureExt};
 21
/// Client-side state for one channel's chat: the loaded messages plus the
/// bookkeeping needed to page in history, send messages, and acknowledge them.
pub struct ChannelChat {
    pub channel_id: ChannelId,
    /// Loaded messages ordered by `ChannelMessageId`; because of the derived
    /// ordering, pending (unsent) messages sort after every saved one.
    messages: SumTree<ChannelMessage>,
    /// Ids already acknowledged through `acknowledge_message`, so each id is
    /// sent to the server at most once by that method.
    acknowledged_message_ids: HashSet<u64>,
    channel_store: Model<ChannelStore>,
    /// Server-provided `done` flag from the most recent page of history.
    loaded_all_messages: bool,
    /// Highest id sent by `acknowledge_last_message`.
    last_acknowledged_id: Option<u64>,
    /// Counter used to allocate `ChannelMessageId::Pending` ids.
    next_pending_message_id: usize,
    /// Id of the first message in the most recently loaded page; used as
    /// `before_message_id` when requesting older history.
    first_loaded_message_id: Option<u64>,
    user_store: Model<UserStore>,
    rpc: Arc<Client>,
    /// Held while a `SendChannelMessage` request is in flight so outgoing
    /// requests are issued one at a time.
    outgoing_messages_lock: Arc<Mutex<()>>,
    /// Source of nonces for pending messages.
    rng: StdRng,
    /// Only held, never read — presumably keeps the entity subscription
    /// alive for the model's lifetime (TODO confirm drop semantics).
    _subscription: Subscription,
}
 37
/// Input for `ChannelChat::send_message`.
#[derive(Debug, PartialEq, Eq)]
pub struct MessageParams {
    /// Message body; `send_message` rejects bodies that are blank after trimming.
    pub text: String,
    /// Mentioned users and their ranges within `text` (units not established
    /// here — presumably byte offsets; confirm against callers).
    pub mentions: Vec<(Range<usize>, UserId)>,
    /// Saved id of the message being replied to, if any.
    pub reply_to_message_id: Option<u64>,
}
 44
/// A chat message, either saved on the server or still pending locally.
#[derive(Clone, Debug)]
pub struct ChannelMessage {
    pub id: ChannelMessageId,
    pub body: String,
    pub timestamp: OffsetDateTime,
    pub sender: Arc<User>,
    /// Random value chosen by the sender; a saved message carrying the same
    /// nonce as a pending one replaces it when inserted.
    pub nonce: u128,
    pub mentions: Vec<(Range<usize>, UserId)>,
    /// Saved id of the message this one replies to, if any.
    pub reply_to_message_id: Option<u64>,
}
 55
/// Identity of a message within a `ChannelChat`.
///
/// Variant order matters: the derived `Ord` places every `Saved` id before
/// every `Pending` id, which keeps pending messages at the end of the
/// message tree.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ChannelMessageId {
    /// Id assigned by the server.
    Saved(u64),
    /// Local id of a message whose send has not completed yet.
    Pending(usize),
}
 61
 62impl Into<Option<u64>> for ChannelMessageId {
 63    fn into(self) -> Option<u64> {
 64        match self {
 65            ChannelMessageId::Saved(id) => Some(id),
 66            ChannelMessageId::Pending(_) => None,
 67        }
 68    }
 69}
 70
/// `SumTree` summary of a run of consecutive messages.
#[derive(Clone, Debug, Default)]
pub struct ChannelMessageSummary {
    /// Id of the last message in the run (the greatest, given id ordering).
    max_id: ChannelMessageId,
    /// Number of messages in the run.
    count: usize,
}
 76
/// `SumTree` dimension counting messages; used to seek by message index.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);
 79
/// Events emitted by `ChannelChat`.
#[derive(Clone, Debug, PartialEq)]
pub enum ChannelChatEvent {
    /// The messages that occupied `old_range` (indices before the change)
    /// were replaced by `new_count` messages; `new_count == 0` is a removal.
    MessagesUpdated {
        old_range: Range<usize>,
        new_count: usize,
    },
    /// Emitted when the server pushes a `ChannelMessageSent` for this chat.
    NewMessage {
        channel_id: ChannelId,
        message_id: u64,
    },
}
 91
// Allows `ChannelChat` to `cx.emit(ChannelChatEvent::..)` from its model context.
impl EventEmitter<ChannelChatEvent> for ChannelChat {}

/// Registers the handlers for chat messages pushed by the server
/// (new messages and removals).
pub fn init(client: &Arc<Client>) {
    client.add_model_message_handler(ChannelChat::handle_message_sent);
    client.add_model_message_handler(ChannelChat::handle_message_removed);
}
 97
 98impl ChannelChat {
 99    pub async fn new(
100        channel: Arc<Channel>,
101        channel_store: Model<ChannelStore>,
102        user_store: Model<UserStore>,
103        client: Arc<Client>,
104        mut cx: AsyncAppContext,
105    ) -> Result<Model<Self>> {
106        let channel_id = channel.id;
107        let subscription = client.subscribe_to_entity(channel_id).unwrap();
108
109        let response = client
110            .request(proto::JoinChannelChat { channel_id })
111            .await?;
112
113        let handle = cx.new_model(|cx| {
114            cx.on_release(Self::release).detach();
115            Self {
116                channel_id: channel.id,
117                user_store: user_store.clone(),
118                channel_store,
119                rpc: client.clone(),
120                outgoing_messages_lock: Default::default(),
121                messages: Default::default(),
122                acknowledged_message_ids: Default::default(),
123                loaded_all_messages: false,
124                next_pending_message_id: 0,
125                last_acknowledged_id: None,
126                rng: StdRng::from_entropy(),
127                first_loaded_message_id: None,
128                _subscription: subscription.set_model(&cx.handle(), &mut cx.to_async()),
129            }
130        })?;
131        Self::handle_loaded_messages(
132            handle.downgrade(),
133            user_store,
134            client,
135            response.messages,
136            response.done,
137            &mut cx,
138        )
139        .await?;
140        Ok(handle)
141    }
142
143    fn release(&mut self, _: &mut AppContext) {
144        self.rpc
145            .send(proto::LeaveChannelChat {
146                channel_id: self.channel_id,
147            })
148            .log_err();
149    }
150
151    pub fn channel(&self, cx: &AppContext) -> Option<Arc<Channel>> {
152        self.channel_store
153            .read(cx)
154            .channel_for_id(self.channel_id)
155            .cloned()
156    }
157
158    pub fn client(&self) -> &Arc<Client> {
159        &self.rpc
160    }
161
162    pub fn send_message(
163        &mut self,
164        message: MessageParams,
165        cx: &mut ModelContext<Self>,
166    ) -> Result<Task<Result<u64>>> {
167        if message.text.trim().is_empty() {
168            Err(anyhow!("message body can't be empty"))?;
169        }
170
171        let current_user = self
172            .user_store
173            .read(cx)
174            .current_user()
175            .ok_or_else(|| anyhow!("current_user is not present"))?;
176
177        let channel_id = self.channel_id;
178        let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id));
179        let nonce = self.rng.gen();
180        self.insert_messages(
181            SumTree::from_item(
182                ChannelMessage {
183                    id: pending_id,
184                    body: message.text.clone(),
185                    sender: current_user,
186                    timestamp: OffsetDateTime::now_utc(),
187                    mentions: message.mentions.clone(),
188                    nonce,
189                    reply_to_message_id: message.reply_to_message_id,
190                },
191                &(),
192            ),
193            cx,
194        );
195        let user_store = self.user_store.clone();
196        let rpc = self.rpc.clone();
197        let outgoing_messages_lock = self.outgoing_messages_lock.clone();
198
199        // todo - handle messages that fail to send (e.g. >1024 chars)
200        Ok(cx.spawn(move |this, mut cx| async move {
201            let outgoing_message_guard = outgoing_messages_lock.lock().await;
202            let request = rpc.request(proto::SendChannelMessage {
203                channel_id,
204                body: message.text,
205                nonce: Some(nonce.into()),
206                mentions: mentions_to_proto(&message.mentions),
207                reply_to_message_id: message.reply_to_message_id,
208            });
209            let response = request.await?;
210            drop(outgoing_message_guard);
211            let response = response.message.ok_or_else(|| anyhow!("invalid message"))?;
212            let id = response.id;
213            let message = ChannelMessage::from_proto(response, &user_store, &mut cx).await?;
214            this.update(&mut cx, |this, cx| {
215                this.insert_messages(SumTree::from_item(message, &()), cx);
216            })?;
217            Ok(id)
218        }))
219    }
220
221    pub fn remove_message(&mut self, id: u64, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
222        let response = self.rpc.request(proto::RemoveChannelMessage {
223            channel_id: self.channel_id,
224            message_id: id,
225        });
226        cx.spawn(move |this, mut cx| async move {
227            response.await?;
228            this.update(&mut cx, |this, cx| {
229                this.message_removed(id, cx);
230            })?;
231            Ok(())
232        })
233    }
234
235    pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Option<()>>> {
236        if self.loaded_all_messages {
237            return None;
238        }
239
240        let rpc = self.rpc.clone();
241        let user_store = self.user_store.clone();
242        let channel_id = self.channel_id;
243        let before_message_id = self.first_loaded_message_id()?;
244        Some(cx.spawn(move |this, mut cx| {
245            async move {
246                let response = rpc
247                    .request(proto::GetChannelMessages {
248                        channel_id,
249                        before_message_id,
250                    })
251                    .await?;
252                Self::handle_loaded_messages(
253                    this,
254                    user_store,
255                    rpc,
256                    response.messages,
257                    response.done,
258                    &mut cx,
259                )
260                .await?;
261
262                anyhow::Ok(())
263            }
264            .log_err()
265        }))
266    }
267
268    pub fn first_loaded_message_id(&mut self) -> Option<u64> {
269        self.first_loaded_message_id
270    }
271
272    /// Load a message by its id, if it's already stored locally.
273    pub fn find_loaded_message(&self, id: u64) -> Option<&ChannelMessage> {
274        self.messages.iter().find(|message| match message.id {
275            ChannelMessageId::Saved(message_id) => message_id == id,
276            ChannelMessageId::Pending(_) => false,
277        })
278    }
279
280    /// Load all of the chat messages since a certain message id.
281    ///
282    /// For now, we always maintain a suffix of the channel's messages.
283    pub async fn load_history_since_message(
284        chat: Model<Self>,
285        message_id: u64,
286        mut cx: AsyncAppContext,
287    ) -> Option<usize> {
288        loop {
289            let step = chat
290                .update(&mut cx, |chat, cx| {
291                    if let Some(first_id) = chat.first_loaded_message_id() {
292                        if first_id <= message_id {
293                            let mut cursor = chat.messages.cursor::<(ChannelMessageId, Count)>();
294                            let message_id = ChannelMessageId::Saved(message_id);
295                            cursor.seek(&message_id, Bias::Left, &());
296                            return ControlFlow::Break(
297                                if cursor
298                                    .item()
299                                    .map_or(false, |message| message.id == message_id)
300                                {
301                                    Some(cursor.start().1 .0)
302                                } else {
303                                    None
304                                },
305                            );
306                        }
307                    }
308                    ControlFlow::Continue(chat.load_more_messages(cx))
309                })
310                .log_err()?;
311            match step {
312                ControlFlow::Break(ix) => return ix,
313                ControlFlow::Continue(task) => task?.await?,
314            }
315        }
316    }
317
318    pub fn acknowledge_last_message(&mut self, cx: &mut ModelContext<Self>) {
319        if let ChannelMessageId::Saved(latest_message_id) = self.messages.summary().max_id {
320            if self
321                .last_acknowledged_id
322                .map_or(true, |acknowledged_id| acknowledged_id < latest_message_id)
323            {
324                self.rpc
325                    .send(proto::AckChannelMessage {
326                        channel_id: self.channel_id,
327                        message_id: latest_message_id,
328                    })
329                    .ok();
330                self.last_acknowledged_id = Some(latest_message_id);
331                self.channel_store.update(cx, |store, cx| {
332                    store.acknowledge_message_id(self.channel_id, latest_message_id, cx);
333                });
334            }
335        }
336    }
337
338    async fn handle_loaded_messages(
339        this: WeakModel<Self>,
340        user_store: Model<UserStore>,
341        rpc: Arc<Client>,
342        proto_messages: Vec<proto::ChannelMessage>,
343        loaded_all_messages: bool,
344        cx: &mut AsyncAppContext,
345    ) -> Result<()> {
346        let loaded_messages = messages_from_proto(proto_messages, &user_store, cx).await?;
347
348        let first_loaded_message_id = loaded_messages.first().map(|m| m.id);
349        let loaded_message_ids = this.update(cx, |this, _| {
350            let mut loaded_message_ids: HashSet<u64> = HashSet::default();
351            for message in loaded_messages.iter() {
352                if let Some(saved_message_id) = message.id.into() {
353                    loaded_message_ids.insert(saved_message_id);
354                }
355            }
356            for message in this.messages.iter() {
357                if let Some(saved_message_id) = message.id.into() {
358                    loaded_message_ids.insert(saved_message_id);
359                }
360            }
361            loaded_message_ids
362        })?;
363
364        let missing_ancestors = loaded_messages
365            .iter()
366            .filter_map(|message| {
367                if let Some(ancestor_id) = message.reply_to_message_id {
368                    if !loaded_message_ids.contains(&ancestor_id) {
369                        return Some(ancestor_id);
370                    }
371                }
372                None
373            })
374            .collect::<Vec<_>>();
375
376        let loaded_ancestors = if missing_ancestors.is_empty() {
377            None
378        } else {
379            let response = rpc
380                .request(proto::GetChannelMessagesById {
381                    message_ids: missing_ancestors,
382                })
383                .await?;
384            Some(messages_from_proto(response.messages, &user_store, cx).await?)
385        };
386        this.update(cx, |this, cx| {
387            this.first_loaded_message_id = first_loaded_message_id.and_then(|msg_id| msg_id.into());
388            this.loaded_all_messages = loaded_all_messages;
389            this.insert_messages(loaded_messages, cx);
390            if let Some(loaded_ancestors) = loaded_ancestors {
391                this.insert_messages(loaded_ancestors, cx);
392            }
393        })?;
394
395        Ok(())
396    }
397
398    pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
399        let user_store = self.user_store.clone();
400        let rpc = self.rpc.clone();
401        let channel_id = self.channel_id;
402        cx.spawn(move |this, mut cx| {
403            async move {
404                let response = rpc.request(proto::JoinChannelChat { channel_id }).await?;
405                Self::handle_loaded_messages(
406                    this.clone(),
407                    user_store.clone(),
408                    rpc.clone(),
409                    response.messages,
410                    response.done,
411                    &mut cx,
412                )
413                .await?;
414
415                let pending_messages = this.update(&mut cx, |this, _| {
416                    this.pending_messages().cloned().collect::<Vec<_>>()
417                })?;
418
419                for pending_message in pending_messages {
420                    let request = rpc.request(proto::SendChannelMessage {
421                        channel_id,
422                        body: pending_message.body,
423                        mentions: mentions_to_proto(&pending_message.mentions),
424                        nonce: Some(pending_message.nonce.into()),
425                        reply_to_message_id: pending_message.reply_to_message_id,
426                    });
427                    let response = request.await?;
428                    let message = ChannelMessage::from_proto(
429                        response.message.ok_or_else(|| anyhow!("invalid message"))?,
430                        &user_store,
431                        &mut cx,
432                    )
433                    .await?;
434                    this.update(&mut cx, |this, cx| {
435                        this.insert_messages(SumTree::from_item(message, &()), cx);
436                    })?;
437                }
438
439                anyhow::Ok(())
440            }
441            .log_err()
442        })
443        .detach();
444    }
445
446    pub fn message_count(&self) -> usize {
447        self.messages.summary().count
448    }
449
450    pub fn messages(&self) -> &SumTree<ChannelMessage> {
451        &self.messages
452    }
453
454    pub fn message(&self, ix: usize) -> &ChannelMessage {
455        let mut cursor = self.messages.cursor::<Count>();
456        cursor.seek(&Count(ix), Bias::Right, &());
457        cursor.item().unwrap()
458    }
459
460    pub fn acknowledge_message(&mut self, id: u64) {
461        if self.acknowledged_message_ids.insert(id) {
462            self.rpc
463                .send(proto::AckChannelMessage {
464                    channel_id: self.channel_id,
465                    message_id: id,
466                })
467                .ok();
468        }
469    }
470
471    pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> {
472        let mut cursor = self.messages.cursor::<Count>();
473        cursor.seek(&Count(range.start), Bias::Right, &());
474        cursor.take(range.len())
475    }
476
477    pub fn pending_messages(&self) -> impl Iterator<Item = &ChannelMessage> {
478        let mut cursor = self.messages.cursor::<ChannelMessageId>();
479        cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &());
480        cursor
481    }
482
483    async fn handle_message_sent(
484        this: Model<Self>,
485        message: TypedEnvelope<proto::ChannelMessageSent>,
486        _: Arc<Client>,
487        mut cx: AsyncAppContext,
488    ) -> Result<()> {
489        let user_store = this.update(&mut cx, |this, _| this.user_store.clone())?;
490        let message = message
491            .payload
492            .message
493            .ok_or_else(|| anyhow!("empty message"))?;
494        let message_id = message.id;
495
496        let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?;
497        this.update(&mut cx, |this, cx| {
498            this.insert_messages(SumTree::from_item(message, &()), cx);
499            cx.emit(ChannelChatEvent::NewMessage {
500                channel_id: this.channel_id,
501                message_id,
502            })
503        })?;
504
505        Ok(())
506    }
507
508    async fn handle_message_removed(
509        this: Model<Self>,
510        message: TypedEnvelope<proto::RemoveChannelMessage>,
511        _: Arc<Client>,
512        mut cx: AsyncAppContext,
513    ) -> Result<()> {
514        this.update(&mut cx, |this, cx| {
515            this.message_removed(message.payload.message_id, cx)
516        })?;
517        Ok(())
518    }
519
520    fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
521        if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
522            let nonces = messages
523                .cursor::<()>()
524                .map(|m| m.nonce)
525                .collect::<HashSet<_>>();
526
527            let mut old_cursor = self.messages.cursor::<(ChannelMessageId, Count)>();
528            let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
529            let start_ix = old_cursor.start().1 .0;
530            let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &());
531            let removed_count = removed_messages.summary().count;
532            let new_count = messages.summary().count;
533            let end_ix = start_ix + removed_count;
534
535            new_messages.append(messages, &());
536
537            let mut ranges = Vec::<Range<usize>>::new();
538            if new_messages.last().unwrap().is_pending() {
539                new_messages.append(old_cursor.suffix(&()), &());
540            } else {
541                new_messages.append(
542                    old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()),
543                    &(),
544                );
545
546                while let Some(message) = old_cursor.item() {
547                    let message_ix = old_cursor.start().1 .0;
548                    if nonces.contains(&message.nonce) {
549                        if ranges.last().map_or(false, |r| r.end == message_ix) {
550                            ranges.last_mut().unwrap().end += 1;
551                        } else {
552                            ranges.push(message_ix..message_ix + 1);
553                        }
554                    } else {
555                        new_messages.push(message.clone(), &());
556                    }
557                    old_cursor.next(&());
558                }
559            }
560
561            drop(old_cursor);
562            self.messages = new_messages;
563
564            for range in ranges.into_iter().rev() {
565                cx.emit(ChannelChatEvent::MessagesUpdated {
566                    old_range: range,
567                    new_count: 0,
568                });
569            }
570            cx.emit(ChannelChatEvent::MessagesUpdated {
571                old_range: start_ix..end_ix,
572                new_count,
573            });
574
575            cx.notify();
576        }
577    }
578
579    fn message_removed(&mut self, id: u64, cx: &mut ModelContext<Self>) {
580        let mut cursor = self.messages.cursor::<ChannelMessageId>();
581        let mut messages = cursor.slice(&ChannelMessageId::Saved(id), Bias::Left, &());
582        if let Some(item) = cursor.item() {
583            if item.id == ChannelMessageId::Saved(id) {
584                let ix = messages.summary().count;
585                cursor.next(&());
586                messages.append(cursor.suffix(&()), &());
587                drop(cursor);
588                self.messages = messages;
589                cx.emit(ChannelChatEvent::MessagesUpdated {
590                    old_range: ix..ix + 1,
591                    new_count: 0,
592                });
593            }
594        }
595    }
596}
597
598async fn messages_from_proto(
599    proto_messages: Vec<proto::ChannelMessage>,
600    user_store: &Model<UserStore>,
601    cx: &mut AsyncAppContext,
602) -> Result<SumTree<ChannelMessage>> {
603    let messages = ChannelMessage::from_proto_vec(proto_messages, user_store, cx).await?;
604    let mut result = SumTree::new();
605    result.extend(messages, &());
606    Ok(result)
607}
608
609impl ChannelMessage {
610    pub async fn from_proto(
611        message: proto::ChannelMessage,
612        user_store: &Model<UserStore>,
613        cx: &mut AsyncAppContext,
614    ) -> Result<Self> {
615        let sender = user_store
616            .update(cx, |user_store, cx| {
617                user_store.get_user(message.sender_id, cx)
618            })?
619            .await?;
620        Ok(ChannelMessage {
621            id: ChannelMessageId::Saved(message.id),
622            body: message.body,
623            mentions: message
624                .mentions
625                .into_iter()
626                .filter_map(|mention| {
627                    let range = mention.range?;
628                    Some((range.start as usize..range.end as usize, mention.user_id))
629                })
630                .collect(),
631            timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?,
632            sender,
633            nonce: message
634                .nonce
635                .ok_or_else(|| anyhow!("nonce is required"))?
636                .into(),
637            reply_to_message_id: message.reply_to_message_id,
638        })
639    }
640
641    pub fn is_pending(&self) -> bool {
642        matches!(self.id, ChannelMessageId::Pending(_))
643    }
644
645    pub async fn from_proto_vec(
646        proto_messages: Vec<proto::ChannelMessage>,
647        user_store: &Model<UserStore>,
648        cx: &mut AsyncAppContext,
649    ) -> Result<Vec<Self>> {
650        let unique_user_ids = proto_messages
651            .iter()
652            .map(|m| m.sender_id)
653            .collect::<HashSet<_>>()
654            .into_iter()
655            .collect();
656        user_store
657            .update(cx, |user_store, cx| {
658                user_store.get_users(unique_user_ids, cx)
659            })?
660            .await?;
661
662        let mut messages = Vec::with_capacity(proto_messages.len());
663        for message in proto_messages {
664            messages.push(ChannelMessage::from_proto(message, user_store, cx).await?);
665        }
666        Ok(messages)
667    }
668}
669
670pub fn mentions_to_proto(mentions: &[(Range<usize>, UserId)]) -> Vec<proto::ChatMention> {
671    mentions
672        .iter()
673        .map(|(range, user_id)| proto::ChatMention {
674            range: Some(proto::Range {
675                start: range.start as u64,
676                end: range.end as u64,
677            }),
678            user_id: *user_id as u64,
679        })
680        .collect()
681}
682
683impl sum_tree::Item for ChannelMessage {
684    type Summary = ChannelMessageSummary;
685
686    fn summary(&self) -> Self::Summary {
687        ChannelMessageSummary {
688            max_id: self.id,
689            count: 1,
690        }
691    }
692}
693
694impl Default for ChannelMessageId {
695    fn default() -> Self {
696        Self::Saved(0)
697    }
698}
699
700impl sum_tree::Summary for ChannelMessageSummary {
701    type Context = ();
702
703    fn add_summary(&mut self, summary: &Self, _: &()) {
704        self.max_id = summary.max_id;
705        self.count += summary.count;
706    }
707}
708
/// Lets cursors seek by message id: the dimension's value is the id of the
/// last message accumulated so far.
impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for ChannelMessageId {
    fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
        // Ids must strictly increase as summaries are accumulated, i.e. the
        // tree must be kept sorted by id (checked in debug builds only).
        debug_assert!(summary.max_id > *self);
        *self = summary.max_id;
    }
}
715
716impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count {
717    fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
718        self.0 += summary.count;
719    }
720}
721
722impl<'a> From<&'a str> for MessageParams {
723    fn from(value: &'a str) -> Self {
724        Self {
725            text: value.into(),
726            mentions: Vec::new(),
727            reply_to_message_id: None,
728        }
729    }
730}