channel_chat.rs

  1use crate::{Channel, ChannelId, ChannelStore};
  2use anyhow::{anyhow, Result};
  3use client::{
  4    proto,
  5    user::{User, UserStore},
  6    Client, Subscription, TypedEnvelope,
  7};
  8use futures::lock::Mutex;
  9use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
 10use rand::prelude::*;
 11use std::{collections::HashSet, mem, ops::Range, sync::Arc};
 12use sum_tree::{Bias, SumTree};
 13use time::OffsetDateTime;
 14use util::{post_inc, ResultExt as _, TryFutureExt};
 15
 16pub struct ChannelChat {
 17    channel: Arc<Channel>,
 18    messages: SumTree<ChannelMessage>,
 19    channel_store: ModelHandle<ChannelStore>,
 20    loaded_all_messages: bool,
 21    last_acknowledged_id: Option<u64>,
 22    next_pending_message_id: usize,
 23    user_store: ModelHandle<UserStore>,
 24    rpc: Arc<Client>,
 25    outgoing_messages_lock: Arc<Mutex<()>>,
 26    rng: StdRng,
 27    _subscription: Subscription,
 28}
 29
 30#[derive(Clone, Debug)]
 31pub struct ChannelMessage {
 32    pub id: ChannelMessageId,
 33    pub body: String,
 34    pub timestamp: OffsetDateTime,
 35    pub sender: Arc<User>,
 36    pub nonce: u128,
 37}
 38
 39#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
 40pub enum ChannelMessageId {
 41    Saved(u64),
 42    Pending(usize),
 43}
 44
 45#[derive(Clone, Debug, Default)]
 46pub struct ChannelMessageSummary {
 47    max_id: ChannelMessageId,
 48    count: usize,
 49}
 50
 51#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
 52struct Count(usize);
 53
 54#[derive(Clone, Debug, PartialEq)]
 55pub enum ChannelChatEvent {
 56    MessagesUpdated {
 57        old_range: Range<usize>,
 58        new_count: usize,
 59    },
 60    NewMessage {
 61        channel_id: ChannelId,
 62        message_id: u64,
 63    },
 64}
 65
 66pub fn init(client: &Arc<Client>) {
 67    client.add_model_message_handler(ChannelChat::handle_message_sent);
 68    client.add_model_message_handler(ChannelChat::handle_message_removed);
 69}
 70
 71impl Entity for ChannelChat {
 72    type Event = ChannelChatEvent;
 73
 74    fn release(&mut self, _: &mut AppContext) {
 75        self.rpc
 76            .send(proto::LeaveChannelChat {
 77                channel_id: self.channel.id,
 78            })
 79            .log_err();
 80    }
 81}
 82
 83impl ChannelChat {
 84    pub async fn new(
 85        channel: Arc<Channel>,
 86        channel_store: ModelHandle<ChannelStore>,
 87        user_store: ModelHandle<UserStore>,
 88        client: Arc<Client>,
 89        mut cx: AsyncAppContext,
 90    ) -> Result<ModelHandle<Self>> {
 91        let channel_id = channel.id;
 92        let subscription = client.subscribe_to_entity(channel_id).unwrap();
 93
 94        let response = client
 95            .request(proto::JoinChannelChat { channel_id })
 96            .await?;
 97        let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
 98        let loaded_all_messages = response.done;
 99
100        Ok(cx.add_model(|cx| {
101            let mut this = Self {
102                channel,
103                user_store,
104                channel_store,
105                rpc: client,
106                outgoing_messages_lock: Default::default(),
107                messages: Default::default(),
108                loaded_all_messages,
109                next_pending_message_id: 0,
110                last_acknowledged_id: None,
111                rng: StdRng::from_entropy(),
112                _subscription: subscription.set_model(&cx.handle(), &mut cx.to_async()),
113            };
114            this.insert_messages(messages, cx);
115            this
116        }))
117    }
118
119    pub fn channel(&self) -> &Arc<Channel> {
120        &self.channel
121    }
122
123    pub fn send_message(
124        &mut self,
125        body: String,
126        cx: &mut ModelContext<Self>,
127    ) -> Result<Task<Result<()>>> {
128        if body.is_empty() {
129            Err(anyhow!("message body can't be empty"))?;
130        }
131
132        let current_user = self
133            .user_store
134            .read(cx)
135            .current_user()
136            .ok_or_else(|| anyhow!("current_user is not present"))?;
137
138        let channel_id = self.channel.id;
139        let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id));
140        let nonce = self.rng.gen();
141        self.insert_messages(
142            SumTree::from_item(
143                ChannelMessage {
144                    id: pending_id,
145                    body: body.clone(),
146                    sender: current_user,
147                    timestamp: OffsetDateTime::now_utc(),
148                    nonce,
149                },
150                &(),
151            ),
152            cx,
153        );
154        let user_store = self.user_store.clone();
155        let rpc = self.rpc.clone();
156        let outgoing_messages_lock = self.outgoing_messages_lock.clone();
157        Ok(cx.spawn(|this, mut cx| async move {
158            let outgoing_message_guard = outgoing_messages_lock.lock().await;
159            let request = rpc.request(proto::SendChannelMessage {
160                channel_id,
161                body,
162                nonce: Some(nonce.into()),
163            });
164            let response = request.await?;
165            drop(outgoing_message_guard);
166            let message = ChannelMessage::from_proto(
167                response.message.ok_or_else(|| anyhow!("invalid message"))?,
168                &user_store,
169                &mut cx,
170            )
171            .await?;
172            this.update(&mut cx, |this, cx| {
173                this.insert_messages(SumTree::from_item(message, &()), cx);
174                Ok(())
175            })
176        }))
177    }
178
179    pub fn remove_message(&mut self, id: u64, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
180        let response = self.rpc.request(proto::RemoveChannelMessage {
181            channel_id: self.channel.id,
182            message_id: id,
183        });
184        cx.spawn(|this, mut cx| async move {
185            response.await?;
186
187            this.update(&mut cx, |this, cx| {
188                this.message_removed(id, cx);
189                Ok(())
190            })
191        })
192    }
193
194    pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> bool {
195        if !self.loaded_all_messages {
196            let rpc = self.rpc.clone();
197            let user_store = self.user_store.clone();
198            let channel_id = self.channel.id;
199            if let Some(before_message_id) =
200                self.messages.first().and_then(|message| match message.id {
201                    ChannelMessageId::Saved(id) => Some(id),
202                    ChannelMessageId::Pending(_) => None,
203                })
204            {
205                cx.spawn(|this, mut cx| {
206                    async move {
207                        let response = rpc
208                            .request(proto::GetChannelMessages {
209                                channel_id,
210                                before_message_id,
211                            })
212                            .await?;
213                        let loaded_all_messages = response.done;
214                        let messages =
215                            messages_from_proto(response.messages, &user_store, &mut cx).await?;
216                        this.update(&mut cx, |this, cx| {
217                            this.loaded_all_messages = loaded_all_messages;
218                            this.insert_messages(messages, cx);
219                        });
220                        anyhow::Ok(())
221                    }
222                    .log_err()
223                })
224                .detach();
225                return true;
226            }
227        }
228        false
229    }
230
231    pub fn acknowledge_last_message(&mut self, cx: &mut ModelContext<Self>) {
232        if let ChannelMessageId::Saved(latest_message_id) = self.messages.summary().max_id {
233            if self
234                .last_acknowledged_id
235                .map_or(true, |acknowledged_id| acknowledged_id < latest_message_id)
236            {
237                self.rpc
238                    .send(proto::AckChannelMessage {
239                        channel_id: self.channel.id,
240                        message_id: latest_message_id,
241                    })
242                    .ok();
243                self.last_acknowledged_id = Some(latest_message_id);
244                self.channel_store.update(cx, |store, cx| {
245                    store.acknowledge_message_id(self.channel.id, latest_message_id, cx);
246                });
247            }
248        }
249    }
250
251    pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
252        let user_store = self.user_store.clone();
253        let rpc = self.rpc.clone();
254        let channel_id = self.channel.id;
255        cx.spawn(|this, mut cx| {
256            async move {
257                let response = rpc.request(proto::JoinChannelChat { channel_id }).await?;
258                let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
259                let loaded_all_messages = response.done;
260
261                let pending_messages = this.update(&mut cx, |this, cx| {
262                    if let Some((first_new_message, last_old_message)) =
263                        messages.first().zip(this.messages.last())
264                    {
265                        if first_new_message.id > last_old_message.id {
266                            let old_messages = mem::take(&mut this.messages);
267                            cx.emit(ChannelChatEvent::MessagesUpdated {
268                                old_range: 0..old_messages.summary().count,
269                                new_count: 0,
270                            });
271                            this.loaded_all_messages = loaded_all_messages;
272                        }
273                    }
274
275                    this.insert_messages(messages, cx);
276                    if loaded_all_messages {
277                        this.loaded_all_messages = loaded_all_messages;
278                    }
279
280                    this.pending_messages().cloned().collect::<Vec<_>>()
281                });
282
283                for pending_message in pending_messages {
284                    let request = rpc.request(proto::SendChannelMessage {
285                        channel_id,
286                        body: pending_message.body,
287                        nonce: Some(pending_message.nonce.into()),
288                    });
289                    let response = request.await?;
290                    let message = ChannelMessage::from_proto(
291                        response.message.ok_or_else(|| anyhow!("invalid message"))?,
292                        &user_store,
293                        &mut cx,
294                    )
295                    .await?;
296                    this.update(&mut cx, |this, cx| {
297                        this.insert_messages(SumTree::from_item(message, &()), cx);
298                    });
299                }
300
301                anyhow::Ok(())
302            }
303            .log_err()
304        })
305        .detach();
306    }
307
308    pub fn message_count(&self) -> usize {
309        self.messages.summary().count
310    }
311
312    pub fn messages(&self) -> &SumTree<ChannelMessage> {
313        &self.messages
314    }
315
316    pub fn message(&self, ix: usize) -> &ChannelMessage {
317        let mut cursor = self.messages.cursor::<Count>();
318        cursor.seek(&Count(ix), Bias::Right, &());
319        cursor.item().unwrap()
320    }
321
322    pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> {
323        let mut cursor = self.messages.cursor::<Count>();
324        cursor.seek(&Count(range.start), Bias::Right, &());
325        cursor.take(range.len())
326    }
327
328    pub fn pending_messages(&self) -> impl Iterator<Item = &ChannelMessage> {
329        let mut cursor = self.messages.cursor::<ChannelMessageId>();
330        cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &());
331        cursor
332    }
333
334    async fn handle_message_sent(
335        this: ModelHandle<Self>,
336        message: TypedEnvelope<proto::ChannelMessageSent>,
337        _: Arc<Client>,
338        mut cx: AsyncAppContext,
339    ) -> Result<()> {
340        let user_store = this.read_with(&cx, |this, _| this.user_store.clone());
341        let message = message
342            .payload
343            .message
344            .ok_or_else(|| anyhow!("empty message"))?;
345        let message_id = message.id;
346
347        let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?;
348        this.update(&mut cx, |this, cx| {
349            this.insert_messages(SumTree::from_item(message, &()), cx);
350            cx.emit(ChannelChatEvent::NewMessage {
351                channel_id: this.channel.id,
352                message_id,
353            })
354        });
355
356        Ok(())
357    }
358
359    async fn handle_message_removed(
360        this: ModelHandle<Self>,
361        message: TypedEnvelope<proto::RemoveChannelMessage>,
362        _: Arc<Client>,
363        mut cx: AsyncAppContext,
364    ) -> Result<()> {
365        this.update(&mut cx, |this, cx| {
366            this.message_removed(message.payload.message_id, cx)
367        });
368        Ok(())
369    }
370
371    fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
372        if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
373            let nonces = messages
374                .cursor::<()>()
375                .map(|m| m.nonce)
376                .collect::<HashSet<_>>();
377
378            let mut old_cursor = self.messages.cursor::<(ChannelMessageId, Count)>();
379            let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
380            let start_ix = old_cursor.start().1 .0;
381            let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &());
382            let removed_count = removed_messages.summary().count;
383            let new_count = messages.summary().count;
384            let end_ix = start_ix + removed_count;
385
386            new_messages.append(messages, &());
387
388            let mut ranges = Vec::<Range<usize>>::new();
389            if new_messages.last().unwrap().is_pending() {
390                new_messages.append(old_cursor.suffix(&()), &());
391            } else {
392                new_messages.append(
393                    old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()),
394                    &(),
395                );
396
397                while let Some(message) = old_cursor.item() {
398                    let message_ix = old_cursor.start().1 .0;
399                    if nonces.contains(&message.nonce) {
400                        if ranges.last().map_or(false, |r| r.end == message_ix) {
401                            ranges.last_mut().unwrap().end += 1;
402                        } else {
403                            ranges.push(message_ix..message_ix + 1);
404                        }
405                    } else {
406                        new_messages.push(message.clone(), &());
407                    }
408                    old_cursor.next(&());
409                }
410            }
411
412            drop(old_cursor);
413            self.messages = new_messages;
414
415            for range in ranges.into_iter().rev() {
416                cx.emit(ChannelChatEvent::MessagesUpdated {
417                    old_range: range,
418                    new_count: 0,
419                });
420            }
421            cx.emit(ChannelChatEvent::MessagesUpdated {
422                old_range: start_ix..end_ix,
423                new_count,
424            });
425
426            cx.notify();
427        }
428    }
429
430    fn message_removed(&mut self, id: u64, cx: &mut ModelContext<Self>) {
431        let mut cursor = self.messages.cursor::<ChannelMessageId>();
432        let mut messages = cursor.slice(&ChannelMessageId::Saved(id), Bias::Left, &());
433        if let Some(item) = cursor.item() {
434            if item.id == ChannelMessageId::Saved(id) {
435                let ix = messages.summary().count;
436                cursor.next(&());
437                messages.append(cursor.suffix(&()), &());
438                drop(cursor);
439                self.messages = messages;
440                cx.emit(ChannelChatEvent::MessagesUpdated {
441                    old_range: ix..ix + 1,
442                    new_count: 0,
443                });
444            }
445        }
446    }
447}
448
449async fn messages_from_proto(
450    proto_messages: Vec<proto::ChannelMessage>,
451    user_store: &ModelHandle<UserStore>,
452    cx: &mut AsyncAppContext,
453) -> Result<SumTree<ChannelMessage>> {
454    let unique_user_ids = proto_messages
455        .iter()
456        .map(|m| m.sender_id)
457        .collect::<HashSet<_>>()
458        .into_iter()
459        .collect();
460    user_store
461        .update(cx, |user_store, cx| {
462            user_store.get_users(unique_user_ids, cx)
463        })
464        .await?;
465
466    let mut messages = Vec::with_capacity(proto_messages.len());
467    for message in proto_messages {
468        messages.push(ChannelMessage::from_proto(message, user_store, cx).await?);
469    }
470    let mut result = SumTree::new();
471    result.extend(messages, &());
472    Ok(result)
473}
474
475impl ChannelMessage {
476    pub async fn from_proto(
477        message: proto::ChannelMessage,
478        user_store: &ModelHandle<UserStore>,
479        cx: &mut AsyncAppContext,
480    ) -> Result<Self> {
481        let sender = user_store
482            .update(cx, |user_store, cx| {
483                user_store.get_user(message.sender_id, cx)
484            })
485            .await?;
486        Ok(ChannelMessage {
487            id: ChannelMessageId::Saved(message.id),
488            body: message.body,
489            timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?,
490            sender,
491            nonce: message
492                .nonce
493                .ok_or_else(|| anyhow!("nonce is required"))?
494                .into(),
495        })
496    }
497
498    pub fn is_pending(&self) -> bool {
499        matches!(self.id, ChannelMessageId::Pending(_))
500    }
501}
502
503impl sum_tree::Item for ChannelMessage {
504    type Summary = ChannelMessageSummary;
505
506    fn summary(&self) -> Self::Summary {
507        ChannelMessageSummary {
508            max_id: self.id,
509            count: 1,
510        }
511    }
512}
513
514impl Default for ChannelMessageId {
515    fn default() -> Self {
516        Self::Saved(0)
517    }
518}
519
520impl sum_tree::Summary for ChannelMessageSummary {
521    type Context = ();
522
523    fn add_summary(&mut self, summary: &Self, _: &()) {
524        self.max_id = summary.max_id;
525        self.count += summary.count;
526    }
527}
528
529impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for ChannelMessageId {
530    fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
531        debug_assert!(summary.max_id > *self);
532        *self = summary.max_id;
533    }
534}
535
536impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count {
537    fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
538        self.0 += summary.count;
539    }
540}