channel_chat.rs

  1use crate::{Channel, ChannelStore};
  2use anyhow::{anyhow, Result};
  3use client::{
  4    proto,
  5    user::{User, UserStore},
  6    ChannelId, Client, Subscription, TypedEnvelope, UserId,
  7};
  8use collections::HashSet;
  9use futures::lock::Mutex;
 10use gpui::{
 11    AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Task, WeakModel,
 12};
 13use rand::prelude::*;
 14use std::{
 15    ops::{ControlFlow, Range},
 16    sync::Arc,
 17};
 18use sum_tree::{Bias, SumTree};
 19use time::OffsetDateTime;
 20use util::{post_inc, ResultExt as _, TryFutureExt};
 21
 22pub struct ChannelChat {
 23    pub channel_id: ChannelId,
 24    messages: SumTree<ChannelMessage>,
 25    acknowledged_message_ids: HashSet<u64>,
 26    channel_store: Model<ChannelStore>,
 27    loaded_all_messages: bool,
 28    last_acknowledged_id: Option<u64>,
 29    next_pending_message_id: usize,
 30    first_loaded_message_id: Option<u64>,
 31    user_store: Model<UserStore>,
 32    rpc: Arc<Client>,
 33    outgoing_messages_lock: Arc<Mutex<()>>,
 34    rng: StdRng,
 35    _subscription: Subscription,
 36}
 37
 38#[derive(Debug, PartialEq, Eq)]
 39pub struct MessageParams {
 40    pub text: String,
 41    pub mentions: Vec<(Range<usize>, UserId)>,
 42    pub reply_to_message_id: Option<u64>,
 43}
 44
 45#[derive(Clone, Debug)]
 46pub struct ChannelMessage {
 47    pub id: ChannelMessageId,
 48    pub body: String,
 49    pub timestamp: OffsetDateTime,
 50    pub sender: Arc<User>,
 51    pub nonce: u128,
 52    pub mentions: Vec<(Range<usize>, UserId)>,
 53    pub reply_to_message_id: Option<u64>,
 54}
 55
 56#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
 57pub enum ChannelMessageId {
 58    Saved(u64),
 59    Pending(usize),
 60}
 61
 62impl Into<Option<u64>> for ChannelMessageId {
 63    fn into(self) -> Option<u64> {
 64        match self {
 65            ChannelMessageId::Saved(id) => Some(id),
 66            ChannelMessageId::Pending(_) => None,
 67        }
 68    }
 69}
 70
 71#[derive(Clone, Debug, Default)]
 72pub struct ChannelMessageSummary {
 73    max_id: ChannelMessageId,
 74    count: usize,
 75}
 76
 77#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
 78struct Count(usize);
 79
 80#[derive(Clone, Debug, PartialEq)]
 81pub enum ChannelChatEvent {
 82    MessagesUpdated {
 83        old_range: Range<usize>,
 84        new_count: usize,
 85    },
 86    NewMessage {
 87        channel_id: ChannelId,
 88        message_id: u64,
 89    },
 90}
 91
 92impl EventEmitter<ChannelChatEvent> for ChannelChat {}
 93pub fn init(client: &Arc<Client>) {
 94    client.add_model_message_handler(ChannelChat::handle_message_sent);
 95    client.add_model_message_handler(ChannelChat::handle_message_removed);
 96}
 97
 98impl ChannelChat {
 99    pub async fn new(
100        channel: Arc<Channel>,
101        channel_store: Model<ChannelStore>,
102        user_store: Model<UserStore>,
103        client: Arc<Client>,
104        mut cx: AsyncAppContext,
105    ) -> Result<Model<Self>> {
106        let channel_id = channel.id;
107        let subscription = client.subscribe_to_entity(channel_id.0).unwrap();
108
109        let response = client
110            .request(proto::JoinChannelChat {
111                channel_id: channel_id.0,
112            })
113            .await?;
114
115        let handle = cx.new_model(|cx| {
116            cx.on_release(Self::release).detach();
117            Self {
118                channel_id: channel.id,
119                user_store: user_store.clone(),
120                channel_store,
121                rpc: client.clone(),
122                outgoing_messages_lock: Default::default(),
123                messages: Default::default(),
124                acknowledged_message_ids: Default::default(),
125                loaded_all_messages: false,
126                next_pending_message_id: 0,
127                last_acknowledged_id: None,
128                rng: StdRng::from_entropy(),
129                first_loaded_message_id: None,
130                _subscription: subscription.set_model(&cx.handle(), &mut cx.to_async()),
131            }
132        })?;
133        Self::handle_loaded_messages(
134            handle.downgrade(),
135            user_store,
136            client,
137            response.messages,
138            response.done,
139            &mut cx,
140        )
141        .await?;
142        Ok(handle)
143    }
144
145    fn release(&mut self, _: &mut AppContext) {
146        self.rpc
147            .send(proto::LeaveChannelChat {
148                channel_id: self.channel_id.0,
149            })
150            .log_err();
151    }
152
153    pub fn channel(&self, cx: &AppContext) -> Option<Arc<Channel>> {
154        self.channel_store
155            .read(cx)
156            .channel_for_id(self.channel_id)
157            .cloned()
158    }
159
160    pub fn client(&self) -> &Arc<Client> {
161        &self.rpc
162    }
163
164    pub fn send_message(
165        &mut self,
166        message: MessageParams,
167        cx: &mut ModelContext<Self>,
168    ) -> Result<Task<Result<u64>>> {
169        if message.text.trim().is_empty() {
170            Err(anyhow!("message body can't be empty"))?;
171        }
172
173        let current_user = self
174            .user_store
175            .read(cx)
176            .current_user()
177            .ok_or_else(|| anyhow!("current_user is not present"))?;
178
179        let channel_id = self.channel_id;
180        let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id));
181        let nonce = self.rng.gen();
182        self.insert_messages(
183            SumTree::from_item(
184                ChannelMessage {
185                    id: pending_id,
186                    body: message.text.clone(),
187                    sender: current_user,
188                    timestamp: OffsetDateTime::now_utc(),
189                    mentions: message.mentions.clone(),
190                    nonce,
191                    reply_to_message_id: message.reply_to_message_id,
192                },
193                &(),
194            ),
195            cx,
196        );
197        let user_store = self.user_store.clone();
198        let rpc = self.rpc.clone();
199        let outgoing_messages_lock = self.outgoing_messages_lock.clone();
200
201        // todo - handle messages that fail to send (e.g. >1024 chars)
202        Ok(cx.spawn(move |this, mut cx| async move {
203            let outgoing_message_guard = outgoing_messages_lock.lock().await;
204            let request = rpc.request(proto::SendChannelMessage {
205                channel_id: channel_id.0,
206                body: message.text,
207                nonce: Some(nonce.into()),
208                mentions: mentions_to_proto(&message.mentions),
209                reply_to_message_id: message.reply_to_message_id,
210            });
211            let response = request.await?;
212            drop(outgoing_message_guard);
213            let response = response.message.ok_or_else(|| anyhow!("invalid message"))?;
214            let id = response.id;
215            let message = ChannelMessage::from_proto(response, &user_store, &mut cx).await?;
216            this.update(&mut cx, |this, cx| {
217                this.insert_messages(SumTree::from_item(message, &()), cx);
218            })?;
219            Ok(id)
220        }))
221    }
222
223    pub fn remove_message(&mut self, id: u64, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
224        let response = self.rpc.request(proto::RemoveChannelMessage {
225            channel_id: self.channel_id.0,
226            message_id: id,
227        });
228        cx.spawn(move |this, mut cx| async move {
229            response.await?;
230            this.update(&mut cx, |this, cx| {
231                this.message_removed(id, cx);
232            })?;
233            Ok(())
234        })
235    }
236
237    pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Option<()>>> {
238        if self.loaded_all_messages {
239            return None;
240        }
241
242        let rpc = self.rpc.clone();
243        let user_store = self.user_store.clone();
244        let channel_id = self.channel_id;
245        let before_message_id = self.first_loaded_message_id()?;
246        Some(cx.spawn(move |this, mut cx| {
247            async move {
248                let response = rpc
249                    .request(proto::GetChannelMessages {
250                        channel_id: channel_id.0,
251                        before_message_id,
252                    })
253                    .await?;
254                Self::handle_loaded_messages(
255                    this,
256                    user_store,
257                    rpc,
258                    response.messages,
259                    response.done,
260                    &mut cx,
261                )
262                .await?;
263
264                anyhow::Ok(())
265            }
266            .log_err()
267        }))
268    }
269
270    pub fn first_loaded_message_id(&mut self) -> Option<u64> {
271        self.first_loaded_message_id
272    }
273
274    /// Load a message by its id, if it's already stored locally.
275    pub fn find_loaded_message(&self, id: u64) -> Option<&ChannelMessage> {
276        self.messages.iter().find(|message| match message.id {
277            ChannelMessageId::Saved(message_id) => message_id == id,
278            ChannelMessageId::Pending(_) => false,
279        })
280    }
281
282    /// Load all of the chat messages since a certain message id.
283    ///
284    /// For now, we always maintain a suffix of the channel's messages.
285    pub async fn load_history_since_message(
286        chat: Model<Self>,
287        message_id: u64,
288        mut cx: AsyncAppContext,
289    ) -> Option<usize> {
290        loop {
291            let step = chat
292                .update(&mut cx, |chat, cx| {
293                    if let Some(first_id) = chat.first_loaded_message_id() {
294                        if first_id <= message_id {
295                            let mut cursor = chat.messages.cursor::<(ChannelMessageId, Count)>();
296                            let message_id = ChannelMessageId::Saved(message_id);
297                            cursor.seek(&message_id, Bias::Left, &());
298                            return ControlFlow::Break(
299                                if cursor
300                                    .item()
301                                    .map_or(false, |message| message.id == message_id)
302                                {
303                                    Some(cursor.start().1 .0)
304                                } else {
305                                    None
306                                },
307                            );
308                        }
309                    }
310                    ControlFlow::Continue(chat.load_more_messages(cx))
311                })
312                .log_err()?;
313            match step {
314                ControlFlow::Break(ix) => return ix,
315                ControlFlow::Continue(task) => task?.await?,
316            }
317        }
318    }
319
320    pub fn acknowledge_last_message(&mut self, cx: &mut ModelContext<Self>) {
321        if let ChannelMessageId::Saved(latest_message_id) = self.messages.summary().max_id {
322            if self
323                .last_acknowledged_id
324                .map_or(true, |acknowledged_id| acknowledged_id < latest_message_id)
325            {
326                self.rpc
327                    .send(proto::AckChannelMessage {
328                        channel_id: self.channel_id.0,
329                        message_id: latest_message_id,
330                    })
331                    .ok();
332                self.last_acknowledged_id = Some(latest_message_id);
333                self.channel_store.update(cx, |store, cx| {
334                    store.acknowledge_message_id(self.channel_id, latest_message_id, cx);
335                });
336            }
337        }
338    }
339
340    async fn handle_loaded_messages(
341        this: WeakModel<Self>,
342        user_store: Model<UserStore>,
343        rpc: Arc<Client>,
344        proto_messages: Vec<proto::ChannelMessage>,
345        loaded_all_messages: bool,
346        cx: &mut AsyncAppContext,
347    ) -> Result<()> {
348        let loaded_messages = messages_from_proto(proto_messages, &user_store, cx).await?;
349
350        let first_loaded_message_id = loaded_messages.first().map(|m| m.id);
351        let loaded_message_ids = this.update(cx, |this, _| {
352            let mut loaded_message_ids: HashSet<u64> = HashSet::default();
353            for message in loaded_messages.iter() {
354                if let Some(saved_message_id) = message.id.into() {
355                    loaded_message_ids.insert(saved_message_id);
356                }
357            }
358            for message in this.messages.iter() {
359                if let Some(saved_message_id) = message.id.into() {
360                    loaded_message_ids.insert(saved_message_id);
361                }
362            }
363            loaded_message_ids
364        })?;
365
366        let missing_ancestors = loaded_messages
367            .iter()
368            .filter_map(|message| {
369                if let Some(ancestor_id) = message.reply_to_message_id {
370                    if !loaded_message_ids.contains(&ancestor_id) {
371                        return Some(ancestor_id);
372                    }
373                }
374                None
375            })
376            .collect::<Vec<_>>();
377
378        let loaded_ancestors = if missing_ancestors.is_empty() {
379            None
380        } else {
381            let response = rpc
382                .request(proto::GetChannelMessagesById {
383                    message_ids: missing_ancestors,
384                })
385                .await?;
386            Some(messages_from_proto(response.messages, &user_store, cx).await?)
387        };
388        this.update(cx, |this, cx| {
389            this.first_loaded_message_id = first_loaded_message_id.and_then(|msg_id| msg_id.into());
390            this.loaded_all_messages = loaded_all_messages;
391            this.insert_messages(loaded_messages, cx);
392            if let Some(loaded_ancestors) = loaded_ancestors {
393                this.insert_messages(loaded_ancestors, cx);
394            }
395        })?;
396
397        Ok(())
398    }
399
400    pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
401        let user_store = self.user_store.clone();
402        let rpc = self.rpc.clone();
403        let channel_id = self.channel_id;
404        cx.spawn(move |this, mut cx| {
405            async move {
406                let response = rpc
407                    .request(proto::JoinChannelChat {
408                        channel_id: channel_id.0,
409                    })
410                    .await?;
411                Self::handle_loaded_messages(
412                    this.clone(),
413                    user_store.clone(),
414                    rpc.clone(),
415                    response.messages,
416                    response.done,
417                    &mut cx,
418                )
419                .await?;
420
421                let pending_messages = this.update(&mut cx, |this, _| {
422                    this.pending_messages().cloned().collect::<Vec<_>>()
423                })?;
424
425                for pending_message in pending_messages {
426                    let request = rpc.request(proto::SendChannelMessage {
427                        channel_id: channel_id.0,
428                        body: pending_message.body,
429                        mentions: mentions_to_proto(&pending_message.mentions),
430                        nonce: Some(pending_message.nonce.into()),
431                        reply_to_message_id: pending_message.reply_to_message_id,
432                    });
433                    let response = request.await?;
434                    let message = ChannelMessage::from_proto(
435                        response.message.ok_or_else(|| anyhow!("invalid message"))?,
436                        &user_store,
437                        &mut cx,
438                    )
439                    .await?;
440                    this.update(&mut cx, |this, cx| {
441                        this.insert_messages(SumTree::from_item(message, &()), cx);
442                    })?;
443                }
444
445                anyhow::Ok(())
446            }
447            .log_err()
448        })
449        .detach();
450    }
451
452    pub fn message_count(&self) -> usize {
453        self.messages.summary().count
454    }
455
456    pub fn messages(&self) -> &SumTree<ChannelMessage> {
457        &self.messages
458    }
459
460    pub fn message(&self, ix: usize) -> &ChannelMessage {
461        let mut cursor = self.messages.cursor::<Count>();
462        cursor.seek(&Count(ix), Bias::Right, &());
463        cursor.item().unwrap()
464    }
465
466    pub fn acknowledge_message(&mut self, id: u64) {
467        if self.acknowledged_message_ids.insert(id) {
468            self.rpc
469                .send(proto::AckChannelMessage {
470                    channel_id: self.channel_id.0,
471                    message_id: id,
472                })
473                .ok();
474        }
475    }
476
477    pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> {
478        let mut cursor = self.messages.cursor::<Count>();
479        cursor.seek(&Count(range.start), Bias::Right, &());
480        cursor.take(range.len())
481    }
482
483    pub fn pending_messages(&self) -> impl Iterator<Item = &ChannelMessage> {
484        let mut cursor = self.messages.cursor::<ChannelMessageId>();
485        cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &());
486        cursor
487    }
488
489    async fn handle_message_sent(
490        this: Model<Self>,
491        message: TypedEnvelope<proto::ChannelMessageSent>,
492        _: Arc<Client>,
493        mut cx: AsyncAppContext,
494    ) -> Result<()> {
495        let user_store = this.update(&mut cx, |this, _| this.user_store.clone())?;
496        let message = message
497            .payload
498            .message
499            .ok_or_else(|| anyhow!("empty message"))?;
500        let message_id = message.id;
501
502        let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?;
503        this.update(&mut cx, |this, cx| {
504            this.insert_messages(SumTree::from_item(message, &()), cx);
505            cx.emit(ChannelChatEvent::NewMessage {
506                channel_id: this.channel_id,
507                message_id,
508            })
509        })?;
510
511        Ok(())
512    }
513
514    async fn handle_message_removed(
515        this: Model<Self>,
516        message: TypedEnvelope<proto::RemoveChannelMessage>,
517        _: Arc<Client>,
518        mut cx: AsyncAppContext,
519    ) -> Result<()> {
520        this.update(&mut cx, |this, cx| {
521            this.message_removed(message.payload.message_id, cx)
522        })?;
523        Ok(())
524    }
525
526    fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
527        if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
528            let nonces = messages
529                .cursor::<()>()
530                .map(|m| m.nonce)
531                .collect::<HashSet<_>>();
532
533            let mut old_cursor = self.messages.cursor::<(ChannelMessageId, Count)>();
534            let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
535            let start_ix = old_cursor.start().1 .0;
536            let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &());
537            let removed_count = removed_messages.summary().count;
538            let new_count = messages.summary().count;
539            let end_ix = start_ix + removed_count;
540
541            new_messages.append(messages, &());
542
543            let mut ranges = Vec::<Range<usize>>::new();
544            if new_messages.last().unwrap().is_pending() {
545                new_messages.append(old_cursor.suffix(&()), &());
546            } else {
547                new_messages.append(
548                    old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()),
549                    &(),
550                );
551
552                while let Some(message) = old_cursor.item() {
553                    let message_ix = old_cursor.start().1 .0;
554                    if nonces.contains(&message.nonce) {
555                        if ranges.last().map_or(false, |r| r.end == message_ix) {
556                            ranges.last_mut().unwrap().end += 1;
557                        } else {
558                            ranges.push(message_ix..message_ix + 1);
559                        }
560                    } else {
561                        new_messages.push(message.clone(), &());
562                    }
563                    old_cursor.next(&());
564                }
565            }
566
567            drop(old_cursor);
568            self.messages = new_messages;
569
570            for range in ranges.into_iter().rev() {
571                cx.emit(ChannelChatEvent::MessagesUpdated {
572                    old_range: range,
573                    new_count: 0,
574                });
575            }
576            cx.emit(ChannelChatEvent::MessagesUpdated {
577                old_range: start_ix..end_ix,
578                new_count,
579            });
580
581            cx.notify();
582        }
583    }
584
585    fn message_removed(&mut self, id: u64, cx: &mut ModelContext<Self>) {
586        let mut cursor = self.messages.cursor::<ChannelMessageId>();
587        let mut messages = cursor.slice(&ChannelMessageId::Saved(id), Bias::Left, &());
588        if let Some(item) = cursor.item() {
589            if item.id == ChannelMessageId::Saved(id) {
590                let ix = messages.summary().count;
591                cursor.next(&());
592                messages.append(cursor.suffix(&()), &());
593                drop(cursor);
594                self.messages = messages;
595                cx.emit(ChannelChatEvent::MessagesUpdated {
596                    old_range: ix..ix + 1,
597                    new_count: 0,
598                });
599            }
600        }
601    }
602}
603
604async fn messages_from_proto(
605    proto_messages: Vec<proto::ChannelMessage>,
606    user_store: &Model<UserStore>,
607    cx: &mut AsyncAppContext,
608) -> Result<SumTree<ChannelMessage>> {
609    let messages = ChannelMessage::from_proto_vec(proto_messages, user_store, cx).await?;
610    let mut result = SumTree::new();
611    result.extend(messages, &());
612    Ok(result)
613}
614
615impl ChannelMessage {
616    pub async fn from_proto(
617        message: proto::ChannelMessage,
618        user_store: &Model<UserStore>,
619        cx: &mut AsyncAppContext,
620    ) -> Result<Self> {
621        let sender = user_store
622            .update(cx, |user_store, cx| {
623                user_store.get_user(message.sender_id, cx)
624            })?
625            .await?;
626        Ok(ChannelMessage {
627            id: ChannelMessageId::Saved(message.id),
628            body: message.body,
629            mentions: message
630                .mentions
631                .into_iter()
632                .filter_map(|mention| {
633                    let range = mention.range?;
634                    Some((range.start as usize..range.end as usize, mention.user_id))
635                })
636                .collect(),
637            timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?,
638            sender,
639            nonce: message
640                .nonce
641                .ok_or_else(|| anyhow!("nonce is required"))?
642                .into(),
643            reply_to_message_id: message.reply_to_message_id,
644        })
645    }
646
647    pub fn is_pending(&self) -> bool {
648        matches!(self.id, ChannelMessageId::Pending(_))
649    }
650
651    pub async fn from_proto_vec(
652        proto_messages: Vec<proto::ChannelMessage>,
653        user_store: &Model<UserStore>,
654        cx: &mut AsyncAppContext,
655    ) -> Result<Vec<Self>> {
656        let unique_user_ids = proto_messages
657            .iter()
658            .map(|m| m.sender_id)
659            .collect::<HashSet<_>>()
660            .into_iter()
661            .collect();
662        user_store
663            .update(cx, |user_store, cx| {
664                user_store.get_users(unique_user_ids, cx)
665            })?
666            .await?;
667
668        let mut messages = Vec::with_capacity(proto_messages.len());
669        for message in proto_messages {
670            messages.push(ChannelMessage::from_proto(message, user_store, cx).await?);
671        }
672        Ok(messages)
673    }
674}
675
676pub fn mentions_to_proto(mentions: &[(Range<usize>, UserId)]) -> Vec<proto::ChatMention> {
677    mentions
678        .iter()
679        .map(|(range, user_id)| proto::ChatMention {
680            range: Some(proto::Range {
681                start: range.start as u64,
682                end: range.end as u64,
683            }),
684            user_id: *user_id,
685        })
686        .collect()
687}
688
689impl sum_tree::Item for ChannelMessage {
690    type Summary = ChannelMessageSummary;
691
692    fn summary(&self) -> Self::Summary {
693        ChannelMessageSummary {
694            max_id: self.id,
695            count: 1,
696        }
697    }
698}
699
700impl Default for ChannelMessageId {
701    fn default() -> Self {
702        Self::Saved(0)
703    }
704}
705
706impl sum_tree::Summary for ChannelMessageSummary {
707    type Context = ();
708
709    fn add_summary(&mut self, summary: &Self, _: &()) {
710        self.max_id = summary.max_id;
711        self.count += summary.count;
712    }
713}
714
715impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for ChannelMessageId {
716    fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
717        debug_assert!(summary.max_id > *self);
718        *self = summary.max_id;
719    }
720}
721
722impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count {
723    fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
724        self.0 += summary.count;
725    }
726}
727
728impl<'a> From<&'a str> for MessageParams {
729    fn from(value: &'a str) -> Self {
730        Self {
731            text: value.into(),
732            mentions: Vec::new(),
733            reply_to_message_id: None,
734        }
735    }
736}