channel_chat.rs

  1use crate::{Channel, ChannelId, ChannelStore};
  2use anyhow::{anyhow, Result};
  3use client::{
  4    proto,
  5    user::{User, UserStore},
  6    Client, Subscription, TypedEnvelope, UserId,
  7};
  8use futures::lock::Mutex;
  9use gpui::{AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Task};
 10use rand::prelude::*;
 11use std::{
 12    collections::HashSet,
 13    mem,
 14    ops::{ControlFlow, Range},
 15    sync::Arc,
 16};
 17use sum_tree::{Bias, SumTree};
 18use time::OffsetDateTime;
 19use util::{post_inc, ResultExt as _, TryFutureExt};
 20
/// Client-side state for the chat of a single channel.
///
/// Holds the messages loaded so far together with the handles used to talk to the server and
/// to resolve users and channels.
pub struct ChannelChat {
    /// Id of the channel this chat belongs to.
    pub channel_id: ChannelId,
    /// Messages currently loaded, both saved ones and locally pending ones.
    messages: SumTree<ChannelMessage>,
    /// Ids for which `acknowledge_message` has already sent an acknowledgement.
    acknowledged_message_ids: HashSet<u64>,
    /// Store used to look up the channel and to record acknowledged message ids.
    channel_store: Model<ChannelStore>,
    /// Set from the `done` flag of the server's responses when joining or loading messages.
    loaded_all_messages: bool,
    /// Most recent message id acknowledged through `acknowledge_last_message`.
    last_acknowledged_id: Option<u64>,
    /// Counter used to mint `ChannelMessageId::Pending` ids in `send_message`.
    next_pending_message_id: usize,
    /// Store used to resolve message senders and the current user.
    user_store: Model<UserStore>,
    /// Client used for every request and notification sent by this chat.
    rpc: Arc<Client>,
    /// Held by `send_message` while its `SendChannelMessage` request is in flight.
    outgoing_messages_lock: Arc<Mutex<()>>,
    /// Source of the nonces attached to outgoing messages.
    rng: StdRng,
    /// Entity subscription obtained in `new`; kept alive for as long as the chat exists.
    _subscription: Subscription,
}
 35
/// Input to `ChannelChat::send_message`: the text to send and the users mentioned in it.
#[derive(Debug, PartialEq, Eq)]
pub struct MessageParams {
    /// Body of the message.
    pub text: String,
    /// Mentioned users, each paired with a range.
    // NOTE(review): the ranges are presumably offsets into `text` — confirm the unit with callers.
    pub mentions: Vec<(Range<usize>, UserId)>,
}
 41
/// A single chat message, either confirmed by the server or still pending locally.
#[derive(Clone, Debug)]
pub struct ChannelMessage {
    /// `Saved` for messages built from a server payload, `Pending` for optimistic local ones.
    pub id: ChannelMessageId,
    /// Text of the message.
    pub body: String,
    /// Server-provided timestamp for saved messages; local send time for pending ones.
    pub timestamp: OffsetDateTime,
    /// User who sent the message.
    pub sender: Arc<User>,
    /// Random value generated when sending; used to match a pending message with its saved copy.
    pub nonce: u128,
    /// Mentioned users, each paired with a range.
    pub mentions: Vec<(Range<usize>, UserId)>,
}
 51
/// Identifier of a message in a `ChannelChat`.
///
/// The derived ordering follows declaration order, so every `Saved` id compares less than
/// every `Pending` id.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ChannelMessageId {
    /// Id taken from a message payload received from the server.
    Saved(u64),
    /// Locally assigned id for a message that has been sent but not yet confirmed.
    Pending(usize),
}
 57
/// Summary stored by the message `SumTree`.
#[derive(Clone, Debug, Default)]
pub struct ChannelMessageSummary {
    /// Id of the last message that was folded into this summary.
    max_id: ChannelMessageId,
    /// Number of messages covered by this summary.
    count: usize,
}
 63
/// Tree dimension that counts messages; used to seek to, and report, message indices.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);
 66
/// Events emitted by `ChannelChat`.
#[derive(Clone, Debug, PartialEq)]
pub enum ChannelChatEvent {
    /// The messages at indices `old_range` were replaced by `new_count` messages.
    ///
    /// A `new_count` of zero means the range was removed.
    MessagesUpdated {
        old_range: Range<usize>,
        new_count: usize,
    },
    /// A message sent notification from the server was inserted into the chat.
    NewMessage {
        channel_id: ChannelId,
        message_id: u64,
    },
}
 78
// Allows `ChannelChat` models to emit `ChannelChatEvent`s via `cx.emit`.
impl EventEmitter<ChannelChatEvent> for ChannelChat {}
/// Registers the `ChannelChat` handlers for message-sent and message-removed
/// notifications on `client`.
pub fn init(client: &Arc<Client>) {
    client.add_model_message_handler(ChannelChat::handle_message_sent);
    client.add_model_message_handler(ChannelChat::handle_message_removed);
}
 84
 85impl ChannelChat {
 86    pub async fn new(
 87        channel: Arc<Channel>,
 88        channel_store: Model<ChannelStore>,
 89        user_store: Model<UserStore>,
 90        client: Arc<Client>,
 91        mut cx: AsyncAppContext,
 92    ) -> Result<Model<Self>> {
 93        let channel_id = channel.id;
 94        let subscription = client.subscribe_to_entity(channel_id).unwrap();
 95
 96        let response = client
 97            .request(proto::JoinChannelChat { channel_id })
 98            .await?;
 99        let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
100        let loaded_all_messages = response.done;
101
102        Ok(cx.new_model(|cx| {
103            cx.on_release(Self::release).detach();
104            let mut this = Self {
105                channel_id: channel.id,
106                user_store,
107                channel_store,
108                rpc: client,
109                outgoing_messages_lock: Default::default(),
110                messages: Default::default(),
111                acknowledged_message_ids: Default::default(),
112                loaded_all_messages,
113                next_pending_message_id: 0,
114                last_acknowledged_id: None,
115                rng: StdRng::from_entropy(),
116                _subscription: subscription.set_model(&cx.handle(), &mut cx.to_async()),
117            };
118            this.insert_messages(messages, cx);
119            this
120        })?)
121    }
122
123    fn release(&mut self, _: &mut AppContext) {
124        self.rpc
125            .send(proto::LeaveChannelChat {
126                channel_id: self.channel_id,
127            })
128            .log_err();
129    }
130
131    pub fn channel(&self, cx: &AppContext) -> Option<Arc<Channel>> {
132        self.channel_store
133            .read(cx)
134            .channel_for_id(self.channel_id)
135            .cloned()
136    }
137
    /// Returns the client this chat uses to communicate with the server.
    pub fn client(&self) -> &Arc<Client> {
        &self.rpc
    }
141
142    pub fn send_message(
143        &mut self,
144        message: MessageParams,
145        cx: &mut ModelContext<Self>,
146    ) -> Result<Task<Result<u64>>> {
147        if message.text.trim().is_empty() {
148            Err(anyhow!("message body can't be empty"))?;
149        }
150
151        let current_user = self
152            .user_store
153            .read(cx)
154            .current_user()
155            .ok_or_else(|| anyhow!("current_user is not present"))?;
156
157        let channel_id = self.channel_id;
158        let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id));
159        let nonce = self.rng.gen();
160        self.insert_messages(
161            SumTree::from_item(
162                ChannelMessage {
163                    id: pending_id,
164                    body: message.text.clone(),
165                    sender: current_user,
166                    timestamp: OffsetDateTime::now_utc(),
167                    mentions: message.mentions.clone(),
168                    nonce,
169                },
170                &(),
171            ),
172            cx,
173        );
174        let user_store = self.user_store.clone();
175        let rpc = self.rpc.clone();
176        let outgoing_messages_lock = self.outgoing_messages_lock.clone();
177
178        // todo - handle messages that fail to send (e.g. >1024 chars)
179        Ok(cx.spawn(move |this, mut cx| async move {
180            let outgoing_message_guard = outgoing_messages_lock.lock().await;
181            let request = rpc.request(proto::SendChannelMessage {
182                channel_id,
183                body: message.text,
184                nonce: Some(nonce.into()),
185                mentions: mentions_to_proto(&message.mentions),
186            });
187            let response = request.await?;
188            drop(outgoing_message_guard);
189            let response = response.message.ok_or_else(|| anyhow!("invalid message"))?;
190            let id = response.id;
191            let message = ChannelMessage::from_proto(response, &user_store, &mut cx).await?;
192            this.update(&mut cx, |this, cx| {
193                this.insert_messages(SumTree::from_item(message, &()), cx);
194            })?;
195            Ok(id)
196        }))
197    }
198
199    pub fn remove_message(&mut self, id: u64, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
200        let response = self.rpc.request(proto::RemoveChannelMessage {
201            channel_id: self.channel_id,
202            message_id: id,
203        });
204        cx.spawn(move |this, mut cx| async move {
205            response.await?;
206            this.update(&mut cx, |this, cx| {
207                this.message_removed(id, cx);
208            })?;
209            Ok(())
210        })
211    }
212
213    pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Option<()>>> {
214        if self.loaded_all_messages {
215            return None;
216        }
217
218        let rpc = self.rpc.clone();
219        let user_store = self.user_store.clone();
220        let channel_id = self.channel_id;
221        let before_message_id = self.first_loaded_message_id()?;
222        Some(cx.spawn(move |this, mut cx| {
223            async move {
224                let response = rpc
225                    .request(proto::GetChannelMessages {
226                        channel_id,
227                        before_message_id,
228                    })
229                    .await?;
230                let loaded_all_messages = response.done;
231                let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
232                this.update(&mut cx, |this, cx| {
233                    this.loaded_all_messages = loaded_all_messages;
234                    this.insert_messages(messages, cx);
235                })?;
236                anyhow::Ok(())
237            }
238            .log_err()
239        }))
240    }
241
242    pub fn first_loaded_message_id(&mut self) -> Option<u64> {
243        self.messages.first().and_then(|message| match message.id {
244            ChannelMessageId::Saved(id) => Some(id),
245            ChannelMessageId::Pending(_) => None,
246        })
247    }
248
249    /// Load all of the chat messages since a certain message id.
250    ///
251    /// For now, we always maintain a suffix of the channel's messages.
252    pub async fn load_history_since_message(
253        chat: Model<Self>,
254        message_id: u64,
255        mut cx: AsyncAppContext,
256    ) -> Option<usize> {
257        loop {
258            let step = chat
259                .update(&mut cx, |chat, cx| {
260                    if let Some(first_id) = chat.first_loaded_message_id() {
261                        if first_id <= message_id {
262                            let mut cursor = chat.messages.cursor::<(ChannelMessageId, Count)>();
263                            let message_id = ChannelMessageId::Saved(message_id);
264                            cursor.seek(&message_id, Bias::Left, &());
265                            return ControlFlow::Break(
266                                if cursor
267                                    .item()
268                                    .map_or(false, |message| message.id == message_id)
269                                {
270                                    Some(cursor.start().1 .0)
271                                } else {
272                                    None
273                                },
274                            );
275                        }
276                    }
277                    ControlFlow::Continue(chat.load_more_messages(cx))
278                })
279                .log_err()?;
280            match step {
281                ControlFlow::Break(ix) => return ix,
282                ControlFlow::Continue(task) => task?.await?,
283            }
284        }
285    }
286
287    pub fn acknowledge_last_message(&mut self, cx: &mut ModelContext<Self>) {
288        if let ChannelMessageId::Saved(latest_message_id) = self.messages.summary().max_id {
289            if self
290                .last_acknowledged_id
291                .map_or(true, |acknowledged_id| acknowledged_id < latest_message_id)
292            {
293                self.rpc
294                    .send(proto::AckChannelMessage {
295                        channel_id: self.channel_id,
296                        message_id: latest_message_id,
297                    })
298                    .ok();
299                self.last_acknowledged_id = Some(latest_message_id);
300                self.channel_store.update(cx, |store, cx| {
301                    store.acknowledge_message_id(self.channel_id, latest_message_id, cx);
302                });
303            }
304        }
305    }
306
307    pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
308        let user_store = self.user_store.clone();
309        let rpc = self.rpc.clone();
310        let channel_id = self.channel_id;
311        cx.spawn(move |this, mut cx| {
312            async move {
313                let response = rpc.request(proto::JoinChannelChat { channel_id }).await?;
314                let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
315                let loaded_all_messages = response.done;
316
317                let pending_messages = this.update(&mut cx, |this, cx| {
318                    if let Some((first_new_message, last_old_message)) =
319                        messages.first().zip(this.messages.last())
320                    {
321                        if first_new_message.id > last_old_message.id {
322                            let old_messages = mem::take(&mut this.messages);
323                            cx.emit(ChannelChatEvent::MessagesUpdated {
324                                old_range: 0..old_messages.summary().count,
325                                new_count: 0,
326                            });
327                            this.loaded_all_messages = loaded_all_messages;
328                        }
329                    }
330
331                    this.insert_messages(messages, cx);
332                    if loaded_all_messages {
333                        this.loaded_all_messages = loaded_all_messages;
334                    }
335
336                    this.pending_messages().cloned().collect::<Vec<_>>()
337                })?;
338
339                for pending_message in pending_messages {
340                    let request = rpc.request(proto::SendChannelMessage {
341                        channel_id,
342                        body: pending_message.body,
343                        mentions: mentions_to_proto(&pending_message.mentions),
344                        nonce: Some(pending_message.nonce.into()),
345                    });
346                    let response = request.await?;
347                    let message = ChannelMessage::from_proto(
348                        response.message.ok_or_else(|| anyhow!("invalid message"))?,
349                        &user_store,
350                        &mut cx,
351                    )
352                    .await?;
353                    this.update(&mut cx, |this, cx| {
354                        this.insert_messages(SumTree::from_item(message, &()), cx);
355                    })?;
356                }
357
358                anyhow::Ok(())
359            }
360            .log_err()
361        })
362        .detach();
363    }
364
    /// Returns the number of messages currently loaded, pending ones included.
    pub fn message_count(&self) -> usize {
        self.messages.summary().count
    }
368
    /// Returns the tree holding every loaded message.
    pub fn messages(&self) -> &SumTree<ChannelMessage> {
        &self.messages
    }
372
373    pub fn message(&self, ix: usize) -> &ChannelMessage {
374        let mut cursor = self.messages.cursor::<Count>();
375        cursor.seek(&Count(ix), Bias::Right, &());
376        cursor.item().unwrap()
377    }
378
379    pub fn acknowledge_message(&mut self, id: u64) {
380        if self.acknowledged_message_ids.insert(id) {
381            self.rpc
382                .send(proto::AckChannelMessage {
383                    channel_id: self.channel_id,
384                    message_id: id,
385                })
386                .ok();
387        }
388    }
389
390    pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> {
391        let mut cursor = self.messages.cursor::<Count>();
392        cursor.seek(&Count(range.start), Bias::Right, &());
393        cursor.take(range.len())
394    }
395
396    pub fn pending_messages(&self) -> impl Iterator<Item = &ChannelMessage> {
397        let mut cursor = self.messages.cursor::<ChannelMessageId>();
398        cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &());
399        cursor
400    }
401
402    async fn handle_message_sent(
403        this: Model<Self>,
404        message: TypedEnvelope<proto::ChannelMessageSent>,
405        _: Arc<Client>,
406        mut cx: AsyncAppContext,
407    ) -> Result<()> {
408        let user_store = this.update(&mut cx, |this, _| this.user_store.clone())?;
409        let message = message
410            .payload
411            .message
412            .ok_or_else(|| anyhow!("empty message"))?;
413        let message_id = message.id;
414
415        let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?;
416        this.update(&mut cx, |this, cx| {
417            this.insert_messages(SumTree::from_item(message, &()), cx);
418            cx.emit(ChannelChatEvent::NewMessage {
419                channel_id: this.channel_id,
420                message_id,
421            })
422        })?;
423
424        Ok(())
425    }
426
427    async fn handle_message_removed(
428        this: Model<Self>,
429        message: TypedEnvelope<proto::RemoveChannelMessage>,
430        _: Arc<Client>,
431        mut cx: AsyncAppContext,
432    ) -> Result<()> {
433        this.update(&mut cx, |this, cx| {
434            this.message_removed(message.payload.message_id, cx)
435        })?;
436        Ok(())
437    }
438
    /// Merges `messages` into the loaded messages and emits the matching
    /// `ChannelChatEvent::MessagesUpdated` events.
    ///
    /// Existing messages whose ids fall within the id span of `messages` are replaced. When
    /// the inserted batch ends with a saved message, pending messages whose nonce appears in
    /// the batch are dropped as well. Does nothing when `messages` is empty.
    fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
        if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
            // Nonces of the incoming messages, used below to find pending duplicates.
            let nonces = messages
                .cursor::<()>()
                .map(|m| m.nonce)
                .collect::<HashSet<_>>();

            let mut old_cursor = self.messages.cursor::<(ChannelMessageId, Count)>();
            // Keep the existing messages that precede the first incoming id.
            let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
            // Index in the old list at which the incoming messages start.
            let start_ix = old_cursor.start().1 .0;
            // Existing messages up to the last incoming id are superseded by the new batch.
            let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &());
            let removed_count = removed_messages.summary().count;
            let new_count = messages.summary().count;
            let end_ix = start_ix + removed_count;

            new_messages.append(messages, &());

            // Old-list index ranges of pending messages dropped as duplicates.
            let mut ranges = Vec::<Range<usize>>::new();
            if new_messages.last().unwrap().is_pending() {
                // The batch ended with a pending message: keep the rest of the old list as is.
                new_messages.append(old_cursor.suffix(&()), &());
            } else {
                // Keep the remaining saved messages...
                new_messages.append(
                    old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()),
                    &(),
                );

                // ...then walk the pending ones, dropping those whose nonce arrived in the
                // batch and keeping the others.
                while let Some(message) = old_cursor.item() {
                    let message_ix = old_cursor.start().1 .0;
                    if nonces.contains(&message.nonce) {
                        // Extend the previous range when this index is adjacent to it.
                        if ranges.last().map_or(false, |r| r.end == message_ix) {
                            ranges.last_mut().unwrap().end += 1;
                        } else {
                            ranges.push(message_ix..message_ix + 1);
                        }
                    } else {
                        new_messages.push(message.clone(), &());
                    }
                    old_cursor.next(&());
                }
            }

            // Release the borrow of `self.messages` before replacing it.
            drop(old_cursor);
            self.messages = new_messages;

            // Report removals from the highest index down, then the main replacement.
            for range in ranges.into_iter().rev() {
                cx.emit(ChannelChatEvent::MessagesUpdated {
                    old_range: range,
                    new_count: 0,
                });
            }
            cx.emit(ChannelChatEvent::MessagesUpdated {
                old_range: start_ix..end_ix,
                new_count,
            });

            cx.notify();
        }
    }
497
498    fn message_removed(&mut self, id: u64, cx: &mut ModelContext<Self>) {
499        let mut cursor = self.messages.cursor::<ChannelMessageId>();
500        let mut messages = cursor.slice(&ChannelMessageId::Saved(id), Bias::Left, &());
501        if let Some(item) = cursor.item() {
502            if item.id == ChannelMessageId::Saved(id) {
503                let ix = messages.summary().count;
504                cursor.next(&());
505                messages.append(cursor.suffix(&()), &());
506                drop(cursor);
507                self.messages = messages;
508                cx.emit(ChannelChatEvent::MessagesUpdated {
509                    old_range: ix..ix + 1,
510                    new_count: 0,
511                });
512            }
513        }
514    }
515}
516
517async fn messages_from_proto(
518    proto_messages: Vec<proto::ChannelMessage>,
519    user_store: &Model<UserStore>,
520    cx: &mut AsyncAppContext,
521) -> Result<SumTree<ChannelMessage>> {
522    let messages = ChannelMessage::from_proto_vec(proto_messages, user_store, cx).await?;
523    let mut result = SumTree::new();
524    result.extend(messages, &());
525    Ok(result)
526}
527
528impl ChannelMessage {
529    pub async fn from_proto(
530        message: proto::ChannelMessage,
531        user_store: &Model<UserStore>,
532        cx: &mut AsyncAppContext,
533    ) -> Result<Self> {
534        let sender = user_store
535            .update(cx, |user_store, cx| {
536                user_store.get_user(message.sender_id, cx)
537            })?
538            .await?;
539        Ok(ChannelMessage {
540            id: ChannelMessageId::Saved(message.id),
541            body: message.body,
542            mentions: message
543                .mentions
544                .into_iter()
545                .filter_map(|mention| {
546                    let range = mention.range?;
547                    Some((range.start as usize..range.end as usize, mention.user_id))
548                })
549                .collect(),
550            timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?,
551            sender,
552            nonce: message
553                .nonce
554                .ok_or_else(|| anyhow!("nonce is required"))?
555                .into(),
556        })
557    }
558
559    pub fn is_pending(&self) -> bool {
560        matches!(self.id, ChannelMessageId::Pending(_))
561    }
562
563    pub async fn from_proto_vec(
564        proto_messages: Vec<proto::ChannelMessage>,
565        user_store: &Model<UserStore>,
566        cx: &mut AsyncAppContext,
567    ) -> Result<Vec<Self>> {
568        let unique_user_ids = proto_messages
569            .iter()
570            .map(|m| m.sender_id)
571            .collect::<HashSet<_>>()
572            .into_iter()
573            .collect();
574        user_store
575            .update(cx, |user_store, cx| {
576                user_store.get_users(unique_user_ids, cx)
577            })?
578            .await?;
579
580        let mut messages = Vec::with_capacity(proto_messages.len());
581        for message in proto_messages {
582            messages.push(ChannelMessage::from_proto(message, user_store, cx).await?);
583        }
584        Ok(messages)
585    }
586}
587
588pub fn mentions_to_proto(mentions: &[(Range<usize>, UserId)]) -> Vec<proto::ChatMention> {
589    mentions
590        .iter()
591        .map(|(range, user_id)| proto::ChatMention {
592            range: Some(proto::Range {
593                start: range.start as u64,
594                end: range.end as u64,
595            }),
596            user_id: *user_id as u64,
597        })
598        .collect()
599}
600
601impl sum_tree::Item for ChannelMessage {
602    type Summary = ChannelMessageSummary;
603
604    fn summary(&self) -> Self::Summary {
605        ChannelMessageSummary {
606            max_id: self.id,
607            count: 1,
608        }
609    }
610}
611
612impl Default for ChannelMessageId {
613    fn default() -> Self {
614        Self::Saved(0)
615    }
616}
617
618impl sum_tree::Summary for ChannelMessageSummary {
619    type Context = ();
620
621    fn add_summary(&mut self, summary: &Self, _: &()) {
622        self.max_id = summary.max_id;
623        self.count += summary.count;
624    }
625}
626
// Lets cursors over the message tree seek by message id.
impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for ChannelMessageId {
    fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
        // Debug builds check that ids strictly increase as summaries are accumulated.
        debug_assert!(summary.max_id > *self);
        *self = summary.max_id;
    }
}
633
// Lets cursors over the message tree seek by message index.
impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count {
    fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
        // Accumulate the number of messages covered so far.
        self.0 += summary.count;
    }
}
639
640impl<'a> From<&'a str> for MessageParams {
641    fn from(value: &'a str) -> Self {
642        Self {
643            text: value.into(),
644            mentions: Vec::new(),
645        }
646    }
647}