channel_chat.rs

  1use crate::{Channel, ChannelId, ChannelStore};
  2use anyhow::{anyhow, Result};
  3use client::{
  4    proto,
  5    user::{User, UserStore},
  6    Client, Subscription, TypedEnvelope, UserId,
  7};
  8use futures::lock::Mutex;
  9use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
 10use rand::prelude::*;
 11use std::{
 12    collections::HashSet,
 13    mem,
 14    ops::{ControlFlow, Range},
 15    sync::Arc,
 16};
 17use sum_tree::{Bias, SumTree};
 18use time::OffsetDateTime;
 19use util::{post_inc, ResultExt as _, TryFutureExt};
 20
 21pub struct ChannelChat {
 22    channel: Arc<Channel>,
 23    messages: SumTree<ChannelMessage>,
 24    acknowledged_message_ids: HashSet<u64>,
 25    channel_store: ModelHandle<ChannelStore>,
 26    loaded_all_messages: bool,
 27    last_acknowledged_id: Option<u64>,
 28    next_pending_message_id: usize,
 29    user_store: ModelHandle<UserStore>,
 30    rpc: Arc<Client>,
 31    outgoing_messages_lock: Arc<Mutex<()>>,
 32    rng: StdRng,
 33    _subscription: Subscription,
 34}
 35
 36#[derive(Debug, PartialEq, Eq)]
 37pub struct MessageParams {
 38    pub text: String,
 39    pub mentions: Vec<(Range<usize>, UserId)>,
 40}
 41
 42#[derive(Clone, Debug)]
 43pub struct ChannelMessage {
 44    pub id: ChannelMessageId,
 45    pub body: String,
 46    pub timestamp: OffsetDateTime,
 47    pub sender: Arc<User>,
 48    pub nonce: u128,
 49    pub mentions: Vec<(Range<usize>, UserId)>,
 50}
 51
 52#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
 53pub enum ChannelMessageId {
 54    Saved(u64),
 55    Pending(usize),
 56}
 57
 58#[derive(Clone, Debug, Default)]
 59pub struct ChannelMessageSummary {
 60    max_id: ChannelMessageId,
 61    count: usize,
 62}
 63
 64#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
 65struct Count(usize);
 66
 67#[derive(Clone, Debug, PartialEq)]
 68pub enum ChannelChatEvent {
 69    MessagesUpdated {
 70        old_range: Range<usize>,
 71        new_count: usize,
 72    },
 73    NewMessage {
 74        channel_id: ChannelId,
 75        message_id: u64,
 76    },
 77}
 78
 79pub fn init(client: &Arc<Client>) {
 80    client.add_model_message_handler(ChannelChat::handle_message_sent);
 81    client.add_model_message_handler(ChannelChat::handle_message_removed);
 82}
 83
 84impl Entity for ChannelChat {
 85    type Event = ChannelChatEvent;
 86
 87    fn release(&mut self, _: &mut AppContext) {
 88        self.rpc
 89            .send(proto::LeaveChannelChat {
 90                channel_id: self.channel.id,
 91            })
 92            .log_err();
 93    }
 94}
 95
 96impl ChannelChat {
 97    pub async fn new(
 98        channel: Arc<Channel>,
 99        channel_store: ModelHandle<ChannelStore>,
100        user_store: ModelHandle<UserStore>,
101        client: Arc<Client>,
102        mut cx: AsyncAppContext,
103    ) -> Result<ModelHandle<Self>> {
104        let channel_id = channel.id;
105        let subscription = client.subscribe_to_entity(channel_id).unwrap();
106
107        let response = client
108            .request(proto::JoinChannelChat { channel_id })
109            .await?;
110        let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
111        let loaded_all_messages = response.done;
112
113        Ok(cx.add_model(|cx| {
114            let mut this = Self {
115                channel,
116                user_store,
117                channel_store,
118                rpc: client,
119                outgoing_messages_lock: Default::default(),
120                messages: Default::default(),
121                acknowledged_message_ids: Default::default(),
122                loaded_all_messages,
123                next_pending_message_id: 0,
124                last_acknowledged_id: None,
125                rng: StdRng::from_entropy(),
126                _subscription: subscription.set_model(&cx.handle(), &mut cx.to_async()),
127            };
128            this.insert_messages(messages, cx);
129            this
130        }))
131    }
132
133    pub fn channel(&self) -> &Arc<Channel> {
134        &self.channel
135    }
136
137    pub fn client(&self) -> &Arc<Client> {
138        &self.rpc
139    }
140
141    pub fn send_message(
142        &mut self,
143        message: MessageParams,
144        cx: &mut ModelContext<Self>,
145    ) -> Result<Task<Result<u64>>> {
146        if message.text.is_empty() {
147            Err(anyhow!("message body can't be empty"))?;
148        }
149
150        let current_user = self
151            .user_store
152            .read(cx)
153            .current_user()
154            .ok_or_else(|| anyhow!("current_user is not present"))?;
155
156        let channel_id = self.channel.id;
157        let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id));
158        let nonce = self.rng.gen();
159        self.insert_messages(
160            SumTree::from_item(
161                ChannelMessage {
162                    id: pending_id,
163                    body: message.text.clone(),
164                    sender: current_user,
165                    timestamp: OffsetDateTime::now_utc(),
166                    mentions: message.mentions.clone(),
167                    nonce,
168                },
169                &(),
170            ),
171            cx,
172        );
173        let user_store = self.user_store.clone();
174        let rpc = self.rpc.clone();
175        let outgoing_messages_lock = self.outgoing_messages_lock.clone();
176        Ok(cx.spawn(|this, mut cx| async move {
177            let outgoing_message_guard = outgoing_messages_lock.lock().await;
178            let request = rpc.request(proto::SendChannelMessage {
179                channel_id,
180                body: message.text,
181                nonce: Some(nonce.into()),
182                mentions: mentions_to_proto(&message.mentions),
183            });
184            let response = request.await?;
185            drop(outgoing_message_guard);
186            let response = response.message.ok_or_else(|| anyhow!("invalid message"))?;
187            let id = response.id;
188            let message = ChannelMessage::from_proto(response, &user_store, &mut cx).await?;
189            this.update(&mut cx, |this, cx| {
190                this.insert_messages(SumTree::from_item(message, &()), cx);
191                Ok(id)
192            })
193        }))
194    }
195
196    pub fn remove_message(&mut self, id: u64, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
197        let response = self.rpc.request(proto::RemoveChannelMessage {
198            channel_id: self.channel.id,
199            message_id: id,
200        });
201        cx.spawn(|this, mut cx| async move {
202            response.await?;
203
204            this.update(&mut cx, |this, cx| {
205                this.message_removed(id, cx);
206                Ok(())
207            })
208        })
209    }
210
211    pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Option<()>>> {
212        if self.loaded_all_messages {
213            return None;
214        }
215
216        let rpc = self.rpc.clone();
217        let user_store = self.user_store.clone();
218        let channel_id = self.channel.id;
219        let before_message_id = self.first_loaded_message_id()?;
220        Some(cx.spawn(|this, mut cx| {
221            async move {
222                let response = rpc
223                    .request(proto::GetChannelMessages {
224                        channel_id,
225                        before_message_id,
226                    })
227                    .await?;
228                let loaded_all_messages = response.done;
229                let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
230                this.update(&mut cx, |this, cx| {
231                    this.loaded_all_messages = loaded_all_messages;
232                    this.insert_messages(messages, cx);
233                });
234                anyhow::Ok(())
235            }
236            .log_err()
237        }))
238    }
239
240    pub fn first_loaded_message_id(&mut self) -> Option<u64> {
241        self.messages.first().and_then(|message| match message.id {
242            ChannelMessageId::Saved(id) => Some(id),
243            ChannelMessageId::Pending(_) => None,
244        })
245    }
246
247    /// Load all of the chat messages since a certain message id.
248    ///
249    /// For now, we always maintain a suffix of the channel's messages.
250    pub async fn load_history_since_message(
251        chat: ModelHandle<Self>,
252        message_id: u64,
253        mut cx: AsyncAppContext,
254    ) -> Option<usize> {
255        loop {
256            let step = chat.update(&mut cx, |chat, cx| {
257                if let Some(first_id) = chat.first_loaded_message_id() {
258                    if first_id <= message_id {
259                        let mut cursor = chat.messages.cursor::<(ChannelMessageId, Count)>();
260                        let message_id = ChannelMessageId::Saved(message_id);
261                        cursor.seek(&message_id, Bias::Left, &());
262                        return ControlFlow::Break(
263                            if cursor
264                                .item()
265                                .map_or(false, |message| message.id == message_id)
266                            {
267                                Some(cursor.start().1 .0)
268                            } else {
269                                None
270                            },
271                        );
272                    }
273                }
274                ControlFlow::Continue(chat.load_more_messages(cx))
275            });
276            match step {
277                ControlFlow::Break(ix) => return ix,
278                ControlFlow::Continue(task) => task?.await?,
279            }
280        }
281    }
282
283    pub fn acknowledge_last_message(&mut self, cx: &mut ModelContext<Self>) {
284        if let ChannelMessageId::Saved(latest_message_id) = self.messages.summary().max_id {
285            if self
286                .last_acknowledged_id
287                .map_or(true, |acknowledged_id| acknowledged_id < latest_message_id)
288            {
289                self.rpc
290                    .send(proto::AckChannelMessage {
291                        channel_id: self.channel.id,
292                        message_id: latest_message_id,
293                    })
294                    .ok();
295                self.last_acknowledged_id = Some(latest_message_id);
296                self.channel_store.update(cx, |store, cx| {
297                    store.acknowledge_message_id(self.channel.id, latest_message_id, cx);
298                });
299            }
300        }
301    }
302
303    pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
304        let user_store = self.user_store.clone();
305        let rpc = self.rpc.clone();
306        let channel_id = self.channel.id;
307        cx.spawn(|this, mut cx| {
308            async move {
309                let response = rpc.request(proto::JoinChannelChat { channel_id }).await?;
310                let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
311                let loaded_all_messages = response.done;
312
313                let pending_messages = this.update(&mut cx, |this, cx| {
314                    if let Some((first_new_message, last_old_message)) =
315                        messages.first().zip(this.messages.last())
316                    {
317                        if first_new_message.id > last_old_message.id {
318                            let old_messages = mem::take(&mut this.messages);
319                            cx.emit(ChannelChatEvent::MessagesUpdated {
320                                old_range: 0..old_messages.summary().count,
321                                new_count: 0,
322                            });
323                            this.loaded_all_messages = loaded_all_messages;
324                        }
325                    }
326
327                    this.insert_messages(messages, cx);
328                    if loaded_all_messages {
329                        this.loaded_all_messages = loaded_all_messages;
330                    }
331
332                    this.pending_messages().cloned().collect::<Vec<_>>()
333                });
334
335                for pending_message in pending_messages {
336                    let request = rpc.request(proto::SendChannelMessage {
337                        channel_id,
338                        body: pending_message.body,
339                        mentions: mentions_to_proto(&pending_message.mentions),
340                        nonce: Some(pending_message.nonce.into()),
341                    });
342                    let response = request.await?;
343                    let message = ChannelMessage::from_proto(
344                        response.message.ok_or_else(|| anyhow!("invalid message"))?,
345                        &user_store,
346                        &mut cx,
347                    )
348                    .await?;
349                    this.update(&mut cx, |this, cx| {
350                        this.insert_messages(SumTree::from_item(message, &()), cx);
351                    });
352                }
353
354                anyhow::Ok(())
355            }
356            .log_err()
357        })
358        .detach();
359    }
360
361    pub fn message_count(&self) -> usize {
362        self.messages.summary().count
363    }
364
365    pub fn messages(&self) -> &SumTree<ChannelMessage> {
366        &self.messages
367    }
368
369    pub fn message(&self, ix: usize) -> &ChannelMessage {
370        let mut cursor = self.messages.cursor::<Count>();
371        cursor.seek(&Count(ix), Bias::Right, &());
372        cursor.item().unwrap()
373    }
374
375    pub fn acknowledge_message(&mut self, id: u64) {
376        if self.acknowledged_message_ids.insert(id) {
377            self.rpc
378                .send(proto::AckChannelMessage {
379                    channel_id: self.channel.id,
380                    message_id: id,
381                })
382                .ok();
383        }
384    }
385
386    pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> {
387        let mut cursor = self.messages.cursor::<Count>();
388        cursor.seek(&Count(range.start), Bias::Right, &());
389        cursor.take(range.len())
390    }
391
392    pub fn pending_messages(&self) -> impl Iterator<Item = &ChannelMessage> {
393        let mut cursor = self.messages.cursor::<ChannelMessageId>();
394        cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &());
395        cursor
396    }
397
398    async fn handle_message_sent(
399        this: ModelHandle<Self>,
400        message: TypedEnvelope<proto::ChannelMessageSent>,
401        _: Arc<Client>,
402        mut cx: AsyncAppContext,
403    ) -> Result<()> {
404        let user_store = this.read_with(&cx, |this, _| this.user_store.clone());
405        let message = message
406            .payload
407            .message
408            .ok_or_else(|| anyhow!("empty message"))?;
409        let message_id = message.id;
410
411        let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?;
412        this.update(&mut cx, |this, cx| {
413            this.insert_messages(SumTree::from_item(message, &()), cx);
414            cx.emit(ChannelChatEvent::NewMessage {
415                channel_id: this.channel.id,
416                message_id,
417            })
418        });
419
420        Ok(())
421    }
422
423    async fn handle_message_removed(
424        this: ModelHandle<Self>,
425        message: TypedEnvelope<proto::RemoveChannelMessage>,
426        _: Arc<Client>,
427        mut cx: AsyncAppContext,
428    ) -> Result<()> {
429        this.update(&mut cx, |this, cx| {
430            this.message_removed(message.payload.message_id, cx)
431        });
432        Ok(())
433    }
434
435    fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
436        if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
437            let nonces = messages
438                .cursor::<()>()
439                .map(|m| m.nonce)
440                .collect::<HashSet<_>>();
441
442            let mut old_cursor = self.messages.cursor::<(ChannelMessageId, Count)>();
443            let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
444            let start_ix = old_cursor.start().1 .0;
445            let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &());
446            let removed_count = removed_messages.summary().count;
447            let new_count = messages.summary().count;
448            let end_ix = start_ix + removed_count;
449
450            new_messages.append(messages, &());
451
452            let mut ranges = Vec::<Range<usize>>::new();
453            if new_messages.last().unwrap().is_pending() {
454                new_messages.append(old_cursor.suffix(&()), &());
455            } else {
456                new_messages.append(
457                    old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()),
458                    &(),
459                );
460
461                while let Some(message) = old_cursor.item() {
462                    let message_ix = old_cursor.start().1 .0;
463                    if nonces.contains(&message.nonce) {
464                        if ranges.last().map_or(false, |r| r.end == message_ix) {
465                            ranges.last_mut().unwrap().end += 1;
466                        } else {
467                            ranges.push(message_ix..message_ix + 1);
468                        }
469                    } else {
470                        new_messages.push(message.clone(), &());
471                    }
472                    old_cursor.next(&());
473                }
474            }
475
476            drop(old_cursor);
477            self.messages = new_messages;
478
479            for range in ranges.into_iter().rev() {
480                cx.emit(ChannelChatEvent::MessagesUpdated {
481                    old_range: range,
482                    new_count: 0,
483                });
484            }
485            cx.emit(ChannelChatEvent::MessagesUpdated {
486                old_range: start_ix..end_ix,
487                new_count,
488            });
489
490            cx.notify();
491        }
492    }
493
494    fn message_removed(&mut self, id: u64, cx: &mut ModelContext<Self>) {
495        let mut cursor = self.messages.cursor::<ChannelMessageId>();
496        let mut messages = cursor.slice(&ChannelMessageId::Saved(id), Bias::Left, &());
497        if let Some(item) = cursor.item() {
498            if item.id == ChannelMessageId::Saved(id) {
499                let ix = messages.summary().count;
500                cursor.next(&());
501                messages.append(cursor.suffix(&()), &());
502                drop(cursor);
503                self.messages = messages;
504                cx.emit(ChannelChatEvent::MessagesUpdated {
505                    old_range: ix..ix + 1,
506                    new_count: 0,
507                });
508            }
509        }
510    }
511}
512
513async fn messages_from_proto(
514    proto_messages: Vec<proto::ChannelMessage>,
515    user_store: &ModelHandle<UserStore>,
516    cx: &mut AsyncAppContext,
517) -> Result<SumTree<ChannelMessage>> {
518    let messages = ChannelMessage::from_proto_vec(proto_messages, user_store, cx).await?;
519    let mut result = SumTree::new();
520    result.extend(messages, &());
521    Ok(result)
522}
523
524impl ChannelMessage {
525    pub async fn from_proto(
526        message: proto::ChannelMessage,
527        user_store: &ModelHandle<UserStore>,
528        cx: &mut AsyncAppContext,
529    ) -> Result<Self> {
530        let sender = user_store
531            .update(cx, |user_store, cx| {
532                user_store.get_user(message.sender_id, cx)
533            })
534            .await?;
535        Ok(ChannelMessage {
536            id: ChannelMessageId::Saved(message.id),
537            body: message.body,
538            mentions: message
539                .mentions
540                .into_iter()
541                .filter_map(|mention| {
542                    let range = mention.range?;
543                    Some((range.start as usize..range.end as usize, mention.user_id))
544                })
545                .collect(),
546            timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?,
547            sender,
548            nonce: message
549                .nonce
550                .ok_or_else(|| anyhow!("nonce is required"))?
551                .into(),
552        })
553    }
554
555    pub fn is_pending(&self) -> bool {
556        matches!(self.id, ChannelMessageId::Pending(_))
557    }
558
559    pub async fn from_proto_vec(
560        proto_messages: Vec<proto::ChannelMessage>,
561        user_store: &ModelHandle<UserStore>,
562        cx: &mut AsyncAppContext,
563    ) -> Result<Vec<Self>> {
564        let unique_user_ids = proto_messages
565            .iter()
566            .map(|m| m.sender_id)
567            .collect::<HashSet<_>>()
568            .into_iter()
569            .collect();
570        user_store
571            .update(cx, |user_store, cx| {
572                user_store.get_users(unique_user_ids, cx)
573            })
574            .await?;
575
576        let mut messages = Vec::with_capacity(proto_messages.len());
577        for message in proto_messages {
578            messages.push(ChannelMessage::from_proto(message, user_store, cx).await?);
579        }
580        Ok(messages)
581    }
582}
583
584pub fn mentions_to_proto(mentions: &[(Range<usize>, UserId)]) -> Vec<proto::ChatMention> {
585    mentions
586        .iter()
587        .map(|(range, user_id)| proto::ChatMention {
588            range: Some(proto::Range {
589                start: range.start as u64,
590                end: range.end as u64,
591            }),
592            user_id: *user_id as u64,
593        })
594        .collect()
595}
596
597impl sum_tree::Item for ChannelMessage {
598    type Summary = ChannelMessageSummary;
599
600    fn summary(&self) -> Self::Summary {
601        ChannelMessageSummary {
602            max_id: self.id,
603            count: 1,
604        }
605    }
606}
607
608impl Default for ChannelMessageId {
609    fn default() -> Self {
610        Self::Saved(0)
611    }
612}
613
614impl sum_tree::Summary for ChannelMessageSummary {
615    type Context = ();
616
617    fn add_summary(&mut self, summary: &Self, _: &()) {
618        self.max_id = summary.max_id;
619        self.count += summary.count;
620    }
621}
622
623impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for ChannelMessageId {
624    fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
625        debug_assert!(summary.max_id > *self);
626        *self = summary.max_id;
627    }
628}
629
630impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count {
631    fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
632        self.0 += summary.count;
633    }
634}
635
636impl<'a> From<&'a str> for MessageParams {
637    fn from(value: &'a str) -> Self {
638        Self {
639            text: value.into(),
640            mentions: Vec::new(),
641        }
642    }
643}