channel_chat.rs

  1use crate::{Channel, ChannelId, ChannelStore};
  2use anyhow::{anyhow, Result};
  3use client::{
  4    proto,
  5    user::{User, UserStore},
  6    Client, Subscription, TypedEnvelope, UserId,
  7};
  8use futures::lock::Mutex;
  9use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
 10use rand::prelude::*;
 11use std::{
 12    collections::HashSet,
 13    mem,
 14    ops::{ControlFlow, Range},
 15    sync::Arc,
 16};
 17use sum_tree::{Bias, SumTree};
 18use time::OffsetDateTime;
 19use util::{post_inc, ResultExt as _, TryFutureExt};
 20
 21pub struct ChannelChat {
 22    pub channel_id: ChannelId,
 23    messages: SumTree<ChannelMessage>,
 24    acknowledged_message_ids: HashSet<u64>,
 25    channel_store: ModelHandle<ChannelStore>,
 26    loaded_all_messages: bool,
 27    last_acknowledged_id: Option<u64>,
 28    next_pending_message_id: usize,
 29    user_store: ModelHandle<UserStore>,
 30    rpc: Arc<Client>,
 31    outgoing_messages_lock: Arc<Mutex<()>>,
 32    rng: StdRng,
 33    _subscription: Subscription,
 34}
 35
 36#[derive(Debug, PartialEq, Eq)]
 37pub struct MessageParams {
 38    pub text: String,
 39    pub mentions: Vec<(Range<usize>, UserId)>,
 40}
 41
 42#[derive(Clone, Debug)]
 43pub struct ChannelMessage {
 44    pub id: ChannelMessageId,
 45    pub body: String,
 46    pub timestamp: OffsetDateTime,
 47    pub sender: Arc<User>,
 48    pub nonce: u128,
 49    pub mentions: Vec<(Range<usize>, UserId)>,
 50}
 51
 52#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
 53pub enum ChannelMessageId {
 54    Saved(u64),
 55    Pending(usize),
 56}
 57
 58#[derive(Clone, Debug, Default)]
 59pub struct ChannelMessageSummary {
 60    max_id: ChannelMessageId,
 61    count: usize,
 62}
 63
 64#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
 65struct Count(usize);
 66
 67#[derive(Clone, Debug, PartialEq)]
 68pub enum ChannelChatEvent {
 69    MessagesUpdated {
 70        old_range: Range<usize>,
 71        new_count: usize,
 72    },
 73    NewMessage {
 74        channel_id: ChannelId,
 75        message_id: u64,
 76    },
 77}
 78
 79pub fn init(client: &Arc<Client>) {
 80    client.add_model_message_handler(ChannelChat::handle_message_sent);
 81    client.add_model_message_handler(ChannelChat::handle_message_removed);
 82}
 83
 84impl Entity for ChannelChat {
 85    type Event = ChannelChatEvent;
 86
 87    fn release(&mut self, _: &mut AppContext) {
 88        self.rpc
 89            .send(proto::LeaveChannelChat {
 90                channel_id: self.channel_id,
 91            })
 92            .log_err();
 93    }
 94}
 95
 96impl ChannelChat {
 97    pub async fn new(
 98        channel: Arc<Channel>,
 99        channel_store: ModelHandle<ChannelStore>,
100        user_store: ModelHandle<UserStore>,
101        client: Arc<Client>,
102        mut cx: AsyncAppContext,
103    ) -> Result<ModelHandle<Self>> {
104        let channel_id = channel.id;
105        let subscription = client.subscribe_to_entity(channel_id).unwrap();
106
107        let response = client
108            .request(proto::JoinChannelChat { channel_id })
109            .await?;
110        let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
111        let loaded_all_messages = response.done;
112
113        Ok(cx.add_model(|cx| {
114            let mut this = Self {
115                channel_id: channel.id,
116                user_store,
117                channel_store,
118                rpc: client,
119                outgoing_messages_lock: Default::default(),
120                messages: Default::default(),
121                acknowledged_message_ids: Default::default(),
122                loaded_all_messages,
123                next_pending_message_id: 0,
124                last_acknowledged_id: None,
125                rng: StdRng::from_entropy(),
126                _subscription: subscription.set_model(&cx.handle(), &mut cx.to_async()),
127            };
128            this.insert_messages(messages, cx);
129            this
130        }))
131    }
132
133    pub fn channel(&self, cx: &AppContext) -> Option<Arc<Channel>> {
134        self.channel_store
135            .read(cx)
136            .channel_for_id(self.channel_id)
137            .cloned()
138    }
139
140    pub fn client(&self) -> &Arc<Client> {
141        &self.rpc
142    }
143
144    pub fn send_message(
145        &mut self,
146        message: MessageParams,
147        cx: &mut ModelContext<Self>,
148    ) -> Result<Task<Result<u64>>> {
149        if message.text.is_empty() {
150            Err(anyhow!("message body can't be empty"))?;
151        }
152
153        let current_user = self
154            .user_store
155            .read(cx)
156            .current_user()
157            .ok_or_else(|| anyhow!("current_user is not present"))?;
158
159        let channel_id = self.channel_id;
160        let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id));
161        let nonce = self.rng.gen();
162        self.insert_messages(
163            SumTree::from_item(
164                ChannelMessage {
165                    id: pending_id,
166                    body: message.text.clone(),
167                    sender: current_user,
168                    timestamp: OffsetDateTime::now_utc(),
169                    mentions: message.mentions.clone(),
170                    nonce,
171                },
172                &(),
173            ),
174            cx,
175        );
176        let user_store = self.user_store.clone();
177        let rpc = self.rpc.clone();
178        let outgoing_messages_lock = self.outgoing_messages_lock.clone();
179        Ok(cx.spawn(|this, mut cx| async move {
180            let outgoing_message_guard = outgoing_messages_lock.lock().await;
181            let request = rpc.request(proto::SendChannelMessage {
182                channel_id,
183                body: message.text,
184                nonce: Some(nonce.into()),
185                mentions: mentions_to_proto(&message.mentions),
186            });
187            let response = request.await?;
188            drop(outgoing_message_guard);
189            let response = response.message.ok_or_else(|| anyhow!("invalid message"))?;
190            let id = response.id;
191            let message = ChannelMessage::from_proto(response, &user_store, &mut cx).await?;
192            this.update(&mut cx, |this, cx| {
193                this.insert_messages(SumTree::from_item(message, &()), cx);
194                Ok(id)
195            })
196        }))
197    }
198
199    pub fn remove_message(&mut self, id: u64, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
200        let response = self.rpc.request(proto::RemoveChannelMessage {
201            channel_id: self.channel_id,
202            message_id: id,
203        });
204        cx.spawn(|this, mut cx| async move {
205            response.await?;
206
207            this.update(&mut cx, |this, cx| {
208                this.message_removed(id, cx);
209                Ok(())
210            })
211        })
212    }
213
214    pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Option<()>>> {
215        if self.loaded_all_messages {
216            return None;
217        }
218
219        let rpc = self.rpc.clone();
220        let user_store = self.user_store.clone();
221        let channel_id = self.channel_id;
222        let before_message_id = self.first_loaded_message_id()?;
223        Some(cx.spawn(|this, mut cx| {
224            async move {
225                let response = rpc
226                    .request(proto::GetChannelMessages {
227                        channel_id,
228                        before_message_id,
229                    })
230                    .await?;
231                let loaded_all_messages = response.done;
232                let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
233                this.update(&mut cx, |this, cx| {
234                    this.loaded_all_messages = loaded_all_messages;
235                    this.insert_messages(messages, cx);
236                });
237                anyhow::Ok(())
238            }
239            .log_err()
240        }))
241    }
242
243    pub fn first_loaded_message_id(&mut self) -> Option<u64> {
244        self.messages.first().and_then(|message| match message.id {
245            ChannelMessageId::Saved(id) => Some(id),
246            ChannelMessageId::Pending(_) => None,
247        })
248    }
249
250    /// Load all of the chat messages since a certain message id.
251    ///
252    /// For now, we always maintain a suffix of the channel's messages.
253    pub async fn load_history_since_message(
254        chat: ModelHandle<Self>,
255        message_id: u64,
256        mut cx: AsyncAppContext,
257    ) -> Option<usize> {
258        loop {
259            let step = chat.update(&mut cx, |chat, cx| {
260                if let Some(first_id) = chat.first_loaded_message_id() {
261                    if first_id <= message_id {
262                        let mut cursor = chat.messages.cursor::<(ChannelMessageId, Count)>();
263                        let message_id = ChannelMessageId::Saved(message_id);
264                        cursor.seek(&message_id, Bias::Left, &());
265                        return ControlFlow::Break(
266                            if cursor
267                                .item()
268                                .map_or(false, |message| message.id == message_id)
269                            {
270                                Some(cursor.start().1 .0)
271                            } else {
272                                None
273                            },
274                        );
275                    }
276                }
277                ControlFlow::Continue(chat.load_more_messages(cx))
278            });
279            match step {
280                ControlFlow::Break(ix) => return ix,
281                ControlFlow::Continue(task) => task?.await?,
282            }
283        }
284    }
285
286    pub fn acknowledge_last_message(&mut self, cx: &mut ModelContext<Self>) {
287        if let ChannelMessageId::Saved(latest_message_id) = self.messages.summary().max_id {
288            if self
289                .last_acknowledged_id
290                .map_or(true, |acknowledged_id| acknowledged_id < latest_message_id)
291            {
292                self.rpc
293                    .send(proto::AckChannelMessage {
294                        channel_id: self.channel_id,
295                        message_id: latest_message_id,
296                    })
297                    .ok();
298                self.last_acknowledged_id = Some(latest_message_id);
299                self.channel_store.update(cx, |store, cx| {
300                    store.acknowledge_message_id(self.channel_id, latest_message_id, cx);
301                });
302            }
303        }
304    }
305
306    pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
307        let user_store = self.user_store.clone();
308        let rpc = self.rpc.clone();
309        let channel_id = self.channel_id;
310        cx.spawn(|this, mut cx| {
311            async move {
312                let response = rpc.request(proto::JoinChannelChat { channel_id }).await?;
313                let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
314                let loaded_all_messages = response.done;
315
316                let pending_messages = this.update(&mut cx, |this, cx| {
317                    if let Some((first_new_message, last_old_message)) =
318                        messages.first().zip(this.messages.last())
319                    {
320                        if first_new_message.id > last_old_message.id {
321                            let old_messages = mem::take(&mut this.messages);
322                            cx.emit(ChannelChatEvent::MessagesUpdated {
323                                old_range: 0..old_messages.summary().count,
324                                new_count: 0,
325                            });
326                            this.loaded_all_messages = loaded_all_messages;
327                        }
328                    }
329
330                    this.insert_messages(messages, cx);
331                    if loaded_all_messages {
332                        this.loaded_all_messages = loaded_all_messages;
333                    }
334
335                    this.pending_messages().cloned().collect::<Vec<_>>()
336                });
337
338                for pending_message in pending_messages {
339                    let request = rpc.request(proto::SendChannelMessage {
340                        channel_id,
341                        body: pending_message.body,
342                        mentions: mentions_to_proto(&pending_message.mentions),
343                        nonce: Some(pending_message.nonce.into()),
344                    });
345                    let response = request.await?;
346                    let message = ChannelMessage::from_proto(
347                        response.message.ok_or_else(|| anyhow!("invalid message"))?,
348                        &user_store,
349                        &mut cx,
350                    )
351                    .await?;
352                    this.update(&mut cx, |this, cx| {
353                        this.insert_messages(SumTree::from_item(message, &()), cx);
354                    });
355                }
356
357                anyhow::Ok(())
358            }
359            .log_err()
360        })
361        .detach();
362    }
363
364    pub fn message_count(&self) -> usize {
365        self.messages.summary().count
366    }
367
368    pub fn messages(&self) -> &SumTree<ChannelMessage> {
369        &self.messages
370    }
371
372    pub fn message(&self, ix: usize) -> &ChannelMessage {
373        let mut cursor = self.messages.cursor::<Count>();
374        cursor.seek(&Count(ix), Bias::Right, &());
375        cursor.item().unwrap()
376    }
377
378    pub fn acknowledge_message(&mut self, id: u64) {
379        if self.acknowledged_message_ids.insert(id) {
380            self.rpc
381                .send(proto::AckChannelMessage {
382                    channel_id: self.channel_id,
383                    message_id: id,
384                })
385                .ok();
386        }
387    }
388
389    pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> {
390        let mut cursor = self.messages.cursor::<Count>();
391        cursor.seek(&Count(range.start), Bias::Right, &());
392        cursor.take(range.len())
393    }
394
395    pub fn pending_messages(&self) -> impl Iterator<Item = &ChannelMessage> {
396        let mut cursor = self.messages.cursor::<ChannelMessageId>();
397        cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &());
398        cursor
399    }
400
401    async fn handle_message_sent(
402        this: ModelHandle<Self>,
403        message: TypedEnvelope<proto::ChannelMessageSent>,
404        _: Arc<Client>,
405        mut cx: AsyncAppContext,
406    ) -> Result<()> {
407        let user_store = this.read_with(&cx, |this, _| this.user_store.clone());
408        let message = message
409            .payload
410            .message
411            .ok_or_else(|| anyhow!("empty message"))?;
412        let message_id = message.id;
413
414        let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?;
415        this.update(&mut cx, |this, cx| {
416            this.insert_messages(SumTree::from_item(message, &()), cx);
417            cx.emit(ChannelChatEvent::NewMessage {
418                channel_id: this.channel_id,
419                message_id,
420            })
421        });
422
423        Ok(())
424    }
425
426    async fn handle_message_removed(
427        this: ModelHandle<Self>,
428        message: TypedEnvelope<proto::RemoveChannelMessage>,
429        _: Arc<Client>,
430        mut cx: AsyncAppContext,
431    ) -> Result<()> {
432        this.update(&mut cx, |this, cx| {
433            this.message_removed(message.payload.message_id, cx)
434        });
435        Ok(())
436    }
437
438    fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
439        if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
440            let nonces = messages
441                .cursor::<()>()
442                .map(|m| m.nonce)
443                .collect::<HashSet<_>>();
444
445            let mut old_cursor = self.messages.cursor::<(ChannelMessageId, Count)>();
446            let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
447            let start_ix = old_cursor.start().1 .0;
448            let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &());
449            let removed_count = removed_messages.summary().count;
450            let new_count = messages.summary().count;
451            let end_ix = start_ix + removed_count;
452
453            new_messages.append(messages, &());
454
455            let mut ranges = Vec::<Range<usize>>::new();
456            if new_messages.last().unwrap().is_pending() {
457                new_messages.append(old_cursor.suffix(&()), &());
458            } else {
459                new_messages.append(
460                    old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()),
461                    &(),
462                );
463
464                while let Some(message) = old_cursor.item() {
465                    let message_ix = old_cursor.start().1 .0;
466                    if nonces.contains(&message.nonce) {
467                        if ranges.last().map_or(false, |r| r.end == message_ix) {
468                            ranges.last_mut().unwrap().end += 1;
469                        } else {
470                            ranges.push(message_ix..message_ix + 1);
471                        }
472                    } else {
473                        new_messages.push(message.clone(), &());
474                    }
475                    old_cursor.next(&());
476                }
477            }
478
479            drop(old_cursor);
480            self.messages = new_messages;
481
482            for range in ranges.into_iter().rev() {
483                cx.emit(ChannelChatEvent::MessagesUpdated {
484                    old_range: range,
485                    new_count: 0,
486                });
487            }
488            cx.emit(ChannelChatEvent::MessagesUpdated {
489                old_range: start_ix..end_ix,
490                new_count,
491            });
492
493            cx.notify();
494        }
495    }
496
497    fn message_removed(&mut self, id: u64, cx: &mut ModelContext<Self>) {
498        let mut cursor = self.messages.cursor::<ChannelMessageId>();
499        let mut messages = cursor.slice(&ChannelMessageId::Saved(id), Bias::Left, &());
500        if let Some(item) = cursor.item() {
501            if item.id == ChannelMessageId::Saved(id) {
502                let ix = messages.summary().count;
503                cursor.next(&());
504                messages.append(cursor.suffix(&()), &());
505                drop(cursor);
506                self.messages = messages;
507                cx.emit(ChannelChatEvent::MessagesUpdated {
508                    old_range: ix..ix + 1,
509                    new_count: 0,
510                });
511            }
512        }
513    }
514}
515
516async fn messages_from_proto(
517    proto_messages: Vec<proto::ChannelMessage>,
518    user_store: &ModelHandle<UserStore>,
519    cx: &mut AsyncAppContext,
520) -> Result<SumTree<ChannelMessage>> {
521    let messages = ChannelMessage::from_proto_vec(proto_messages, user_store, cx).await?;
522    let mut result = SumTree::new();
523    result.extend(messages, &());
524    Ok(result)
525}
526
527impl ChannelMessage {
528    pub async fn from_proto(
529        message: proto::ChannelMessage,
530        user_store: &ModelHandle<UserStore>,
531        cx: &mut AsyncAppContext,
532    ) -> Result<Self> {
533        let sender = user_store
534            .update(cx, |user_store, cx| {
535                user_store.get_user(message.sender_id, cx)
536            })
537            .await?;
538        Ok(ChannelMessage {
539            id: ChannelMessageId::Saved(message.id),
540            body: message.body,
541            mentions: message
542                .mentions
543                .into_iter()
544                .filter_map(|mention| {
545                    let range = mention.range?;
546                    Some((range.start as usize..range.end as usize, mention.user_id))
547                })
548                .collect(),
549            timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?,
550            sender,
551            nonce: message
552                .nonce
553                .ok_or_else(|| anyhow!("nonce is required"))?
554                .into(),
555        })
556    }
557
558    pub fn is_pending(&self) -> bool {
559        matches!(self.id, ChannelMessageId::Pending(_))
560    }
561
562    pub async fn from_proto_vec(
563        proto_messages: Vec<proto::ChannelMessage>,
564        user_store: &ModelHandle<UserStore>,
565        cx: &mut AsyncAppContext,
566    ) -> Result<Vec<Self>> {
567        let unique_user_ids = proto_messages
568            .iter()
569            .map(|m| m.sender_id)
570            .collect::<HashSet<_>>()
571            .into_iter()
572            .collect();
573        user_store
574            .update(cx, |user_store, cx| {
575                user_store.get_users(unique_user_ids, cx)
576            })
577            .await?;
578
579        let mut messages = Vec::with_capacity(proto_messages.len());
580        for message in proto_messages {
581            messages.push(ChannelMessage::from_proto(message, user_store, cx).await?);
582        }
583        Ok(messages)
584    }
585}
586
587pub fn mentions_to_proto(mentions: &[(Range<usize>, UserId)]) -> Vec<proto::ChatMention> {
588    mentions
589        .iter()
590        .map(|(range, user_id)| proto::ChatMention {
591            range: Some(proto::Range {
592                start: range.start as u64,
593                end: range.end as u64,
594            }),
595            user_id: *user_id as u64,
596        })
597        .collect()
598}
599
600impl sum_tree::Item for ChannelMessage {
601    type Summary = ChannelMessageSummary;
602
603    fn summary(&self) -> Self::Summary {
604        ChannelMessageSummary {
605            max_id: self.id,
606            count: 1,
607        }
608    }
609}
610
611impl Default for ChannelMessageId {
612    fn default() -> Self {
613        Self::Saved(0)
614    }
615}
616
617impl sum_tree::Summary for ChannelMessageSummary {
618    type Context = ();
619
620    fn add_summary(&mut self, summary: &Self, _: &()) {
621        self.max_id = summary.max_id;
622        self.count += summary.count;
623    }
624}
625
626impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for ChannelMessageId {
627    fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
628        debug_assert!(summary.max_id > *self);
629        *self = summary.max_id;
630    }
631}
632
633impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count {
634    fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
635        self.0 += summary.count;
636    }
637}
638
639impl<'a> From<&'a str> for MessageParams {
640    fn from(value: &'a str) -> Self {
641        Self {
642            text: value.into(),
643            mentions: Vec::new(),
644        }
645    }
646}