channel_chat.rs

  1use crate::Channel;
  2use anyhow::{anyhow, Result};
  3use client::{
  4    proto,
  5    user::{User, UserStore},
  6    Client, Subscription, TypedEnvelope,
  7};
  8use futures::lock::Mutex;
  9use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
 10use rand::prelude::*;
 11use std::{collections::HashSet, mem, ops::Range, sync::Arc};
 12use sum_tree::{Bias, SumTree};
 13use time::OffsetDateTime;
 14use util::{post_inc, ResultExt as _, TryFutureExt};
 15
 16pub struct ChannelChat {
 17    channel: Arc<Channel>,
 18    messages: SumTree<ChannelMessage>,
 19    loaded_all_messages: bool,
 20    next_pending_message_id: usize,
 21    user_store: ModelHandle<UserStore>,
 22    rpc: Arc<Client>,
 23    outgoing_messages_lock: Arc<Mutex<()>>,
 24    rng: StdRng,
 25    _subscription: Subscription,
 26}
 27
 28#[derive(Clone, Debug)]
 29pub struct ChannelMessage {
 30    pub id: ChannelMessageId,
 31    pub body: String,
 32    pub timestamp: OffsetDateTime,
 33    pub sender: Arc<User>,
 34    pub nonce: u128,
 35}
 36
 37#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
 38pub enum ChannelMessageId {
 39    Saved(u64),
 40    Pending(usize),
 41}
 42
 43#[derive(Clone, Debug, Default)]
 44pub struct ChannelMessageSummary {
 45    max_id: ChannelMessageId,
 46    count: usize,
 47}
 48
 49#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
 50struct Count(usize);
 51
 52#[derive(Clone, Debug, PartialEq)]
 53pub enum ChannelChatEvent {
 54    MessagesUpdated {
 55        old_range: Range<usize>,
 56        new_count: usize,
 57    },
 58}
 59
 60pub fn init(client: &Arc<Client>) {
 61    client.add_model_message_handler(ChannelChat::handle_message_sent);
 62    client.add_model_message_handler(ChannelChat::handle_message_removed);
 63}
 64
 65impl Entity for ChannelChat {
 66    type Event = ChannelChatEvent;
 67
 68    fn release(&mut self, _: &mut AppContext) {
 69        self.rpc
 70            .send(proto::LeaveChannelChat {
 71                channel_id: self.channel.id,
 72            })
 73            .log_err();
 74    }
 75}
 76
 77impl ChannelChat {
 78    pub async fn new(
 79        channel: Arc<Channel>,
 80        user_store: ModelHandle<UserStore>,
 81        client: Arc<Client>,
 82        mut cx: AsyncAppContext,
 83    ) -> Result<ModelHandle<Self>> {
 84        let channel_id = channel.id;
 85        let subscription = client.subscribe_to_entity(channel_id).unwrap();
 86
 87        let response = client
 88            .request(proto::JoinChannelChat { channel_id })
 89            .await?;
 90        let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
 91        let loaded_all_messages = response.done;
 92
 93        Ok(cx.add_model(|cx| {
 94            let mut this = Self {
 95                channel,
 96                user_store,
 97                rpc: client,
 98                outgoing_messages_lock: Default::default(),
 99                messages: Default::default(),
100                loaded_all_messages,
101                next_pending_message_id: 0,
102                rng: StdRng::from_entropy(),
103                _subscription: subscription.set_model(&cx.handle(), &mut cx.to_async()),
104            };
105            this.insert_messages(messages, cx);
106            this
107        }))
108    }
109
110    pub fn channel(&self) -> &Arc<Channel> {
111        &self.channel
112    }
113
114    pub fn send_message(
115        &mut self,
116        body: String,
117        cx: &mut ModelContext<Self>,
118    ) -> Result<Task<Result<()>>> {
119        if body.is_empty() {
120            Err(anyhow!("message body can't be empty"))?;
121        }
122
123        let current_user = self
124            .user_store
125            .read(cx)
126            .current_user()
127            .ok_or_else(|| anyhow!("current_user is not present"))?;
128
129        let channel_id = self.channel.id;
130        let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id));
131        let nonce = self.rng.gen();
132        self.insert_messages(
133            SumTree::from_item(
134                ChannelMessage {
135                    id: pending_id,
136                    body: body.clone(),
137                    sender: current_user,
138                    timestamp: OffsetDateTime::now_utc(),
139                    nonce,
140                },
141                &(),
142            ),
143            cx,
144        );
145        let user_store = self.user_store.clone();
146        let rpc = self.rpc.clone();
147        let outgoing_messages_lock = self.outgoing_messages_lock.clone();
148        Ok(cx.spawn(|this, mut cx| async move {
149            let outgoing_message_guard = outgoing_messages_lock.lock().await;
150            let request = rpc.request(proto::SendChannelMessage {
151                channel_id,
152                body,
153                nonce: Some(nonce.into()),
154            });
155            let response = request.await?;
156            drop(outgoing_message_guard);
157            let message = ChannelMessage::from_proto(
158                response.message.ok_or_else(|| anyhow!("invalid message"))?,
159                &user_store,
160                &mut cx,
161            )
162            .await?;
163            this.update(&mut cx, |this, cx| {
164                this.insert_messages(SumTree::from_item(message, &()), cx);
165                Ok(())
166            })
167        }))
168    }
169
170    pub fn remove_message(&mut self, id: u64, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
171        let response = self.rpc.request(proto::RemoveChannelMessage {
172            channel_id: self.channel.id,
173            message_id: id,
174        });
175        cx.spawn(|this, mut cx| async move {
176            response.await?;
177
178            this.update(&mut cx, |this, cx| {
179                this.message_removed(id, cx);
180                Ok(())
181            })
182        })
183    }
184
185    pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> bool {
186        if !self.loaded_all_messages {
187            let rpc = self.rpc.clone();
188            let user_store = self.user_store.clone();
189            let channel_id = self.channel.id;
190            if let Some(before_message_id) =
191                self.messages.first().and_then(|message| match message.id {
192                    ChannelMessageId::Saved(id) => Some(id),
193                    ChannelMessageId::Pending(_) => None,
194                })
195            {
196                cx.spawn(|this, mut cx| {
197                    async move {
198                        let response = rpc
199                            .request(proto::GetChannelMessages {
200                                channel_id,
201                                before_message_id,
202                            })
203                            .await?;
204                        let loaded_all_messages = response.done;
205                        let messages =
206                            messages_from_proto(response.messages, &user_store, &mut cx).await?;
207                        this.update(&mut cx, |this, cx| {
208                            this.loaded_all_messages = loaded_all_messages;
209                            this.insert_messages(messages, cx);
210                        });
211                        anyhow::Ok(())
212                    }
213                    .log_err()
214                })
215                .detach();
216                return true;
217            }
218        }
219        false
220    }
221
222    pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
223        let user_store = self.user_store.clone();
224        let rpc = self.rpc.clone();
225        let channel_id = self.channel.id;
226        cx.spawn(|this, mut cx| {
227            async move {
228                let response = rpc.request(proto::JoinChannelChat { channel_id }).await?;
229                let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
230                let loaded_all_messages = response.done;
231
232                let pending_messages = this.update(&mut cx, |this, cx| {
233                    if let Some((first_new_message, last_old_message)) =
234                        messages.first().zip(this.messages.last())
235                    {
236                        if first_new_message.id > last_old_message.id {
237                            let old_messages = mem::take(&mut this.messages);
238                            cx.emit(ChannelChatEvent::MessagesUpdated {
239                                old_range: 0..old_messages.summary().count,
240                                new_count: 0,
241                            });
242                            this.loaded_all_messages = loaded_all_messages;
243                        }
244                    }
245
246                    this.insert_messages(messages, cx);
247                    if loaded_all_messages {
248                        this.loaded_all_messages = loaded_all_messages;
249                    }
250
251                    this.pending_messages().cloned().collect::<Vec<_>>()
252                });
253
254                for pending_message in pending_messages {
255                    let request = rpc.request(proto::SendChannelMessage {
256                        channel_id,
257                        body: pending_message.body,
258                        nonce: Some(pending_message.nonce.into()),
259                    });
260                    let response = request.await?;
261                    let message = ChannelMessage::from_proto(
262                        response.message.ok_or_else(|| anyhow!("invalid message"))?,
263                        &user_store,
264                        &mut cx,
265                    )
266                    .await?;
267                    this.update(&mut cx, |this, cx| {
268                        this.insert_messages(SumTree::from_item(message, &()), cx);
269                    });
270                }
271
272                anyhow::Ok(())
273            }
274            .log_err()
275        })
276        .detach();
277    }
278
279    pub fn message_count(&self) -> usize {
280        self.messages.summary().count
281    }
282
283    pub fn messages(&self) -> &SumTree<ChannelMessage> {
284        &self.messages
285    }
286
287    pub fn message(&self, ix: usize) -> &ChannelMessage {
288        let mut cursor = self.messages.cursor::<Count>();
289        cursor.seek(&Count(ix), Bias::Right, &());
290        cursor.item().unwrap()
291    }
292
293    pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> {
294        let mut cursor = self.messages.cursor::<Count>();
295        cursor.seek(&Count(range.start), Bias::Right, &());
296        cursor.take(range.len())
297    }
298
299    pub fn pending_messages(&self) -> impl Iterator<Item = &ChannelMessage> {
300        let mut cursor = self.messages.cursor::<ChannelMessageId>();
301        cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &());
302        cursor
303    }
304
305    async fn handle_message_sent(
306        this: ModelHandle<Self>,
307        message: TypedEnvelope<proto::ChannelMessageSent>,
308        _: Arc<Client>,
309        mut cx: AsyncAppContext,
310    ) -> Result<()> {
311        let user_store = this.read_with(&cx, |this, _| this.user_store.clone());
312        let message = message
313            .payload
314            .message
315            .ok_or_else(|| anyhow!("empty message"))?;
316
317        let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?;
318        this.update(&mut cx, |this, cx| {
319            this.insert_messages(SumTree::from_item(message, &()), cx)
320        });
321
322        Ok(())
323    }
324
325    async fn handle_message_removed(
326        this: ModelHandle<Self>,
327        message: TypedEnvelope<proto::RemoveChannelMessage>,
328        _: Arc<Client>,
329        mut cx: AsyncAppContext,
330    ) -> Result<()> {
331        this.update(&mut cx, |this, cx| {
332            this.message_removed(message.payload.message_id, cx)
333        });
334        Ok(())
335    }
336
337    fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
338        if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
339            let nonces = messages
340                .cursor::<()>()
341                .map(|m| m.nonce)
342                .collect::<HashSet<_>>();
343
344            let mut old_cursor = self.messages.cursor::<(ChannelMessageId, Count)>();
345            let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
346            let start_ix = old_cursor.start().1 .0;
347            let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &());
348            let removed_count = removed_messages.summary().count;
349            let new_count = messages.summary().count;
350            let end_ix = start_ix + removed_count;
351
352            new_messages.append(messages, &());
353
354            let mut ranges = Vec::<Range<usize>>::new();
355            if new_messages.last().unwrap().is_pending() {
356                new_messages.append(old_cursor.suffix(&()), &());
357            } else {
358                new_messages.append(
359                    old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()),
360                    &(),
361                );
362
363                while let Some(message) = old_cursor.item() {
364                    let message_ix = old_cursor.start().1 .0;
365                    if nonces.contains(&message.nonce) {
366                        if ranges.last().map_or(false, |r| r.end == message_ix) {
367                            ranges.last_mut().unwrap().end += 1;
368                        } else {
369                            ranges.push(message_ix..message_ix + 1);
370                        }
371                    } else {
372                        new_messages.push(message.clone(), &());
373                    }
374                    old_cursor.next(&());
375                }
376            }
377
378            drop(old_cursor);
379            self.messages = new_messages;
380
381            for range in ranges.into_iter().rev() {
382                cx.emit(ChannelChatEvent::MessagesUpdated {
383                    old_range: range,
384                    new_count: 0,
385                });
386            }
387            cx.emit(ChannelChatEvent::MessagesUpdated {
388                old_range: start_ix..end_ix,
389                new_count,
390            });
391            cx.notify();
392        }
393    }
394
395    fn message_removed(&mut self, id: u64, cx: &mut ModelContext<Self>) {
396        let mut cursor = self.messages.cursor::<ChannelMessageId>();
397        let mut messages = cursor.slice(&ChannelMessageId::Saved(id), Bias::Left, &());
398        if let Some(item) = cursor.item() {
399            if item.id == ChannelMessageId::Saved(id) {
400                let ix = messages.summary().count;
401                cursor.next(&());
402                messages.append(cursor.suffix(&()), &());
403                drop(cursor);
404                self.messages = messages;
405                cx.emit(ChannelChatEvent::MessagesUpdated {
406                    old_range: ix..ix + 1,
407                    new_count: 0,
408                });
409            }
410        }
411    }
412}
413
414async fn messages_from_proto(
415    proto_messages: Vec<proto::ChannelMessage>,
416    user_store: &ModelHandle<UserStore>,
417    cx: &mut AsyncAppContext,
418) -> Result<SumTree<ChannelMessage>> {
419    let unique_user_ids = proto_messages
420        .iter()
421        .map(|m| m.sender_id)
422        .collect::<HashSet<_>>()
423        .into_iter()
424        .collect();
425    user_store
426        .update(cx, |user_store, cx| {
427            user_store.get_users(unique_user_ids, cx)
428        })
429        .await?;
430
431    let mut messages = Vec::with_capacity(proto_messages.len());
432    for message in proto_messages {
433        messages.push(ChannelMessage::from_proto(message, user_store, cx).await?);
434    }
435    let mut result = SumTree::new();
436    result.extend(messages, &());
437    Ok(result)
438}
439
440impl ChannelMessage {
441    pub async fn from_proto(
442        message: proto::ChannelMessage,
443        user_store: &ModelHandle<UserStore>,
444        cx: &mut AsyncAppContext,
445    ) -> Result<Self> {
446        let sender = user_store
447            .update(cx, |user_store, cx| {
448                user_store.get_user(message.sender_id, cx)
449            })
450            .await?;
451        Ok(ChannelMessage {
452            id: ChannelMessageId::Saved(message.id),
453            body: message.body,
454            timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?,
455            sender,
456            nonce: message
457                .nonce
458                .ok_or_else(|| anyhow!("nonce is required"))?
459                .into(),
460        })
461    }
462
463    pub fn is_pending(&self) -> bool {
464        matches!(self.id, ChannelMessageId::Pending(_))
465    }
466}
467
468impl sum_tree::Item for ChannelMessage {
469    type Summary = ChannelMessageSummary;
470
471    fn summary(&self) -> Self::Summary {
472        ChannelMessageSummary {
473            max_id: self.id,
474            count: 1,
475        }
476    }
477}
478
479impl Default for ChannelMessageId {
480    fn default() -> Self {
481        Self::Saved(0)
482    }
483}
484
485impl sum_tree::Summary for ChannelMessageSummary {
486    type Context = ();
487
488    fn add_summary(&mut self, summary: &Self, _: &()) {
489        self.max_id = summary.max_id;
490        self.count += summary.count;
491    }
492}
493
494impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for ChannelMessageId {
495    fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
496        debug_assert!(summary.max_id > *self);
497        *self = summary.max_id;
498    }
499}
500
501impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count {
502    fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
503        self.0 += summary.count;
504    }
505}