channel.rs

  1use crate::{
  2    rpc::{self, Client},
  3    user::{User, UserStore},
  4    util::{post_inc, TryFutureExt},
  5};
  6use anyhow::{anyhow, Context, Result};
  7use gpui::{
  8    sum_tree::{self, Bias, SumTree},
  9    Entity, ModelContext, ModelHandle, MutableAppContext, Task, WeakModelHandle,
 10};
 11use postage::prelude::Stream;
 12use std::{
 13    collections::{HashMap, HashSet},
 14    mem,
 15    ops::Range,
 16    sync::Arc,
 17};
 18use time::OffsetDateTime;
 19use zrpc::{
 20    proto::{self, ChannelMessageSent},
 21    TypedEnvelope,
 22};
 23
/// Tracks the channels that are available to join and hands out shared
/// [`Channel`] models for the ones that have been opened.
pub struct ChannelList {
    // `None` until a `GetChannels` response has been stored, and again after
    // the client reports `SignedOut`.
    available_channels: Option<Vec<ChannelDetails>>,
    // Opened channels keyed by channel id, held weakly.
    channels: HashMap<u64, WeakModelHandle<Channel>>,
    rpc: Arc<Client>,
    user_store: Arc<UserStore>,
    // The connection-status watcher spawned in `new`.
    _task: Task<Option<()>>,
}
 31
 32#[derive(Clone, Debug, PartialEq)]
 33pub struct ChannelDetails {
 34    pub id: u64,
 35    pub name: String,
 36}
 37
/// A joined channel together with the messages loaded for it so far.
pub struct Channel {
    details: ChannelDetails,
    // Loaded messages, stored in a sum tree summarized by max id and count.
    messages: SumTree<ChannelMessage>,
    // Set from the `done` flag of join / get-messages responses.
    loaded_all_messages: bool,
    // Counter used to mint `ChannelMessageId::Pending` ids in `send_message`.
    next_pending_message_id: usize,
    user_store: Arc<UserStore>,
    rpc: Arc<Client>,
    // Subscription that routes `ChannelMessageSent` to `handle_message_sent`.
    _subscription: rpc::Subscription,
}
 47
/// A single chat message, either saved (built from a protobuf message) or
/// pending (inserted locally by `Channel::send_message`).
#[derive(Clone, Debug)]
pub struct ChannelMessage {
    /// Saved or pending identifier; also the key messages are sought by.
    pub id: ChannelMessageId,
    /// Message text.
    pub body: String,
    /// Time associated with the message.
    pub timestamp: OffsetDateTime,
    /// The user who sent the message.
    pub sender: Arc<User>,
}
 55
/// Identifier of a message within a channel.
///
/// The derived `Ord` compares variants in declaration order first, so every
/// `Saved` id sorts before every `Pending` id.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum ChannelMessageId {
    /// Id taken from a protobuf message.
    Saved(u64),
    /// Locally assigned id for a message created by `Channel::send_message`.
    Pending(usize),
}
 61
/// Sum-tree summary for a run of [`ChannelMessage`]s.
#[derive(Clone, Debug, Default)]
pub struct ChannelMessageSummary {
    // Id carried over from the most recently added summary.
    max_id: ChannelMessageId,
    // Number of messages covered by this summary.
    count: usize,
}
 67
/// Sum-tree dimension that accumulates the number of messages, used to seek
/// to a message by its index.
#[derive(Copy, Clone, Debug, Default)]
struct Count(usize);
 70
/// Event type of the `ChannelList` entity; it has no variants, so no value
/// of this type can be constructed.
pub enum ChannelListEvent {}
 72
 73#[derive(Clone, Debug, PartialEq)]
 74pub enum ChannelEvent {
 75    MessagesUpdated {
 76        old_range: Range<usize>,
 77        new_count: usize,
 78    },
 79}
 80
/// Registers `ChannelList` as an entity whose event type is
/// `ChannelListEvent`.
impl Entity for ChannelList {
    type Event = ChannelListEvent;
}
 84
 85impl ChannelList {
 86    pub fn new(
 87        user_store: Arc<UserStore>,
 88        rpc: Arc<rpc::Client>,
 89        cx: &mut ModelContext<Self>,
 90    ) -> Self {
 91        let _task = cx.spawn_weak(|this, mut cx| {
 92            let rpc = rpc.clone();
 93            async move {
 94                let mut status = rpc.status();
 95                while let Some((status, this)) = status.recv().await.zip(this.upgrade(&cx)) {
 96                    match status {
 97                        rpc::Status::Connected { .. } => {
 98                            let response = rpc
 99                                .request(proto::GetChannels {})
100                                .await
101                                .context("failed to fetch available channels")?;
102                            this.update(&mut cx, |this, cx| {
103                                this.available_channels =
104                                    Some(response.channels.into_iter().map(Into::into).collect());
105
106                                let mut to_remove = Vec::new();
107                                for (channel_id, channel) in &this.channels {
108                                    if let Some(channel) = channel.upgrade(cx) {
109                                        channel.update(cx, |channel, cx| channel.rejoin(cx))
110                                    } else {
111                                        to_remove.push(*channel_id);
112                                    }
113                                }
114
115                                for channel_id in to_remove {
116                                    this.channels.remove(&channel_id);
117                                }
118                                cx.notify();
119                            });
120                        }
121                        rpc::Status::SignedOut { .. } => {
122                            this.update(&mut cx, |this, cx| {
123                                this.available_channels = None;
124                                this.channels.clear();
125                                cx.notify();
126                            });
127                        }
128                        _ => {}
129                    }
130                }
131                Ok(())
132            }
133            .log_err()
134        });
135
136        Self {
137            available_channels: None,
138            channels: Default::default(),
139            user_store,
140            rpc,
141            _task,
142        }
143    }
144
145    pub fn available_channels(&self) -> Option<&[ChannelDetails]> {
146        self.available_channels.as_ref().map(Vec::as_slice)
147    }
148
149    pub fn get_channel(
150        &mut self,
151        id: u64,
152        cx: &mut MutableAppContext,
153    ) -> Option<ModelHandle<Channel>> {
154        if let Some(channel) = self.channels.get(&id).and_then(|c| c.upgrade(cx)) {
155            return Some(channel);
156        }
157
158        let channels = self.available_channels.as_ref()?;
159        let details = channels.iter().find(|details| details.id == id)?.clone();
160        let channel =
161            cx.add_model(|cx| Channel::new(details, self.user_store.clone(), self.rpc.clone(), cx));
162        self.channels.insert(id, channel.downgrade());
163        Some(channel)
164    }
165}
166
167impl Entity for Channel {
168    type Event = ChannelEvent;
169
170    fn release(&mut self, cx: &mut MutableAppContext) {
171        let rpc = self.rpc.clone();
172        let channel_id = self.details.id;
173        cx.foreground()
174            .spawn(async move {
175                if let Err(error) = rpc.send(proto::LeaveChannel { channel_id }).await {
176                    log::error!("error leaving channel: {}", error);
177                };
178            })
179            .detach()
180    }
181}
182
183impl Channel {
184    pub fn new(
185        details: ChannelDetails,
186        user_store: Arc<UserStore>,
187        rpc: Arc<Client>,
188        cx: &mut ModelContext<Self>,
189    ) -> Self {
190        let _subscription = rpc.subscribe_from_model(details.id, cx, Self::handle_message_sent);
191
192        {
193            let user_store = user_store.clone();
194            let rpc = rpc.clone();
195            let channel_id = details.id;
196            cx.spawn(|channel, mut cx| {
197                async move {
198                    let response = rpc.request(proto::JoinChannel { channel_id }).await?;
199                    let messages = messages_from_proto(response.messages, &user_store).await?;
200                    let loaded_all_messages = response.done;
201
202                    channel.update(&mut cx, |channel, cx| {
203                        channel.insert_messages(messages, cx);
204                        channel.loaded_all_messages = loaded_all_messages;
205                    });
206
207                    Ok(())
208                }
209                .log_err()
210            })
211            .detach();
212        }
213
214        Self {
215            details,
216            user_store,
217            rpc,
218            messages: Default::default(),
219            loaded_all_messages: false,
220            next_pending_message_id: 0,
221            _subscription,
222        }
223    }
224
225    pub fn name(&self) -> &str {
226        &self.details.name
227    }
228
229    pub fn send_message(
230        &mut self,
231        body: String,
232        cx: &mut ModelContext<Self>,
233    ) -> Result<Task<Result<()>>> {
234        if body.is_empty() {
235            Err(anyhow!("message body can't be empty"))?;
236        }
237
238        let current_user = self
239            .user_store
240            .current_user()
241            .ok_or_else(|| anyhow!("current_user is not present"))?;
242
243        let channel_id = self.details.id;
244        let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id));
245        self.insert_messages(
246            SumTree::from_item(
247                ChannelMessage {
248                    id: pending_id,
249                    body: body.clone(),
250                    sender: current_user,
251                    timestamp: OffsetDateTime::now_utc(),
252                },
253                &(),
254            ),
255            cx,
256        );
257        let user_store = self.user_store.clone();
258        let rpc = self.rpc.clone();
259        Ok(cx.spawn(|this, mut cx| async move {
260            let request = rpc.request(proto::SendChannelMessage { channel_id, body });
261            let response = request.await?;
262            let message = ChannelMessage::from_proto(
263                response.message.ok_or_else(|| anyhow!("invalid message"))?,
264                &user_store,
265            )
266            .await?;
267            this.update(&mut cx, |this, cx| {
268                this.remove_message(pending_id, cx);
269                this.insert_messages(SumTree::from_item(message, &()), cx);
270                Ok(())
271            })
272        }))
273    }
274
275    pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> bool {
276        if !self.loaded_all_messages {
277            let rpc = self.rpc.clone();
278            let user_store = self.user_store.clone();
279            let channel_id = self.details.id;
280            if let Some(before_message_id) =
281                self.messages.first().and_then(|message| match message.id {
282                    ChannelMessageId::Saved(id) => Some(id),
283                    ChannelMessageId::Pending(_) => None,
284                })
285            {
286                cx.spawn(|this, mut cx| {
287                    async move {
288                        let response = rpc
289                            .request(proto::GetChannelMessages {
290                                channel_id,
291                                before_message_id,
292                            })
293                            .await?;
294                        let loaded_all_messages = response.done;
295                        let messages = messages_from_proto(response.messages, &user_store).await?;
296                        this.update(&mut cx, |this, cx| {
297                            this.loaded_all_messages = loaded_all_messages;
298                            this.insert_messages(messages, cx);
299                        });
300                        Ok(())
301                    }
302                    .log_err()
303                })
304                .detach();
305                return true;
306            }
307        }
308        false
309    }
310
311    pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
312        let user_store = self.user_store.clone();
313        let rpc = self.rpc.clone();
314        let channel_id = self.details.id;
315        cx.spawn(|channel, mut cx| {
316            async move {
317                let response = rpc.request(proto::JoinChannel { channel_id }).await?;
318                let messages = messages_from_proto(response.messages, &user_store).await?;
319                let loaded_all_messages = response.done;
320
321                channel.update(&mut cx, |channel, cx| {
322                    if let Some((first_new_message, last_old_message)) =
323                        messages.first().zip(channel.messages.last())
324                    {
325                        if first_new_message.id > last_old_message.id {
326                            let old_messages = mem::take(&mut channel.messages);
327                            cx.emit(ChannelEvent::MessagesUpdated {
328                                old_range: 0..old_messages.summary().count,
329                                new_count: 0,
330                            });
331                            channel.loaded_all_messages = loaded_all_messages;
332                        }
333                    }
334
335                    channel.insert_messages(messages, cx);
336                    if loaded_all_messages {
337                        channel.loaded_all_messages = loaded_all_messages;
338                    }
339                });
340
341                Ok(())
342            }
343            .log_err()
344        })
345        .detach();
346    }
347
    /// Returns the number of messages currently held, read from the message
    /// tree's summary.
    pub fn message_count(&self) -> usize {
        self.messages.summary().count
    }
351
    /// Returns the tree of messages currently held for this channel.
    pub fn messages(&self) -> &SumTree<ChannelMessage> {
        &self.messages
    }
355
    /// Returns the message at index `ix`, found by seeking the message count
    /// dimension.
    ///
    /// # Panics
    ///
    /// Panics when the cursor has no item after seeking — presumably when
    /// `ix` is not less than `message_count()`; confirm against the cursor's
    /// seek semantics.
    pub fn message(&self, ix: usize) -> &ChannelMessage {
        let mut cursor = self.messages.cursor::<Count, ()>();
        cursor.seek(&Count(ix), Bias::Right, &());
        cursor.item().unwrap()
    }
361
    /// Returns an iterator that seeks to index `range.start` and yields at
    /// most `range.len()` messages from there.
    pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> {
        let mut cursor = self.messages.cursor::<Count, ()>();
        cursor.seek(&Count(range.start), Bias::Right, &());
        cursor.take(range.len())
    }
367
368    fn handle_message_sent(
369        &mut self,
370        message: TypedEnvelope<ChannelMessageSent>,
371        _: Arc<rpc::Client>,
372        cx: &mut ModelContext<Self>,
373    ) -> Result<()> {
374        let user_store = self.user_store.clone();
375        let message = message
376            .payload
377            .message
378            .ok_or_else(|| anyhow!("empty message"))?;
379
380        cx.spawn(|this, mut cx| {
381            async move {
382                let message = ChannelMessage::from_proto(message, &user_store).await?;
383                this.update(&mut cx, |this, cx| {
384                    this.insert_messages(SumTree::from_item(message, &()), cx)
385                });
386                Ok(())
387            }
388            .log_err()
389        })
390        .detach();
391        Ok(())
392    }
393
394    fn remove_message(&mut self, message_id: ChannelMessageId, cx: &mut ModelContext<Self>) {
395        let mut old_cursor = self.messages.cursor::<ChannelMessageId, Count>();
396        let mut new_messages = old_cursor.slice(&message_id, Bias::Left, &());
397        let start_ix = old_cursor.sum_start().0;
398        let removed_messages = old_cursor.slice(&message_id, Bias::Right, &());
399        let removed_count = removed_messages.summary().count;
400        new_messages.push_tree(old_cursor.suffix(&()), &());
401
402        drop(old_cursor);
403        self.messages = new_messages;
404
405        if removed_count > 0 {
406            let end_ix = start_ix + removed_count;
407            cx.emit(ChannelEvent::MessagesUpdated {
408                old_range: start_ix..end_ix,
409                new_count: 0,
410            });
411            cx.notify();
412        }
413    }
414
415    fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
416        if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
417            let mut old_cursor = self.messages.cursor::<ChannelMessageId, Count>();
418            let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
419            let start_ix = old_cursor.sum_start().0;
420            let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &());
421            let removed_count = removed_messages.summary().count;
422            let new_count = messages.summary().count;
423            let end_ix = start_ix + removed_count;
424
425            new_messages.push_tree(messages, &());
426            new_messages.push_tree(old_cursor.suffix(&()), &());
427            drop(old_cursor);
428            self.messages = new_messages;
429
430            cx.emit(ChannelEvent::MessagesUpdated {
431                old_range: start_ix..end_ix,
432                new_count,
433            });
434            cx.notify();
435        }
436    }
437}
438
439async fn messages_from_proto(
440    proto_messages: Vec<proto::ChannelMessage>,
441    user_store: &UserStore,
442) -> Result<SumTree<ChannelMessage>> {
443    let unique_user_ids = proto_messages
444        .iter()
445        .map(|m| m.sender_id)
446        .collect::<HashSet<_>>()
447        .into_iter()
448        .collect();
449    user_store.load_users(unique_user_ids).await?;
450
451    let mut messages = Vec::with_capacity(proto_messages.len());
452    for message in proto_messages {
453        messages.push(ChannelMessage::from_proto(message, &user_store).await?);
454    }
455    let mut result = SumTree::new();
456    result.extend(messages, &());
457    Ok(result)
458}
459
460impl From<proto::Channel> for ChannelDetails {
461    fn from(message: proto::Channel) -> Self {
462        Self {
463            id: message.id,
464            name: message.name,
465        }
466    }
467}
468
469impl ChannelMessage {
470    pub async fn from_proto(
471        message: proto::ChannelMessage,
472        user_store: &UserStore,
473    ) -> Result<Self> {
474        let sender = user_store.fetch_user(message.sender_id).await?;
475        Ok(ChannelMessage {
476            id: ChannelMessageId::Saved(message.id),
477            body: message.body,
478            timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?,
479            sender,
480        })
481    }
482
483    pub fn is_pending(&self) -> bool {
484        matches!(self.id, ChannelMessageId::Pending(_))
485    }
486}
487
/// Makes `ChannelMessage` storable in a `SumTree`.
impl sum_tree::Item for ChannelMessage {
    type Summary = ChannelMessageSummary;

    /// A single message summarizes to its own id and a count of one.
    fn summary(&self) -> Self::Summary {
        ChannelMessageSummary {
            max_id: self.id,
            count: 1,
        }
    }
}
498
/// The default id is `Saved(0)`; with the derived ordering no other id
/// compares less than it.
impl Default for ChannelMessageId {
    fn default() -> Self {
        Self::Saved(0)
    }
}
504
/// Combines summaries of adjacent runs of messages.
impl sum_tree::Summary for ChannelMessageSummary {
    type Context = ();

    /// Takes over the added summary's `max_id` and adds its count.
    ///
    /// NOTE(review): `max_id` is overwritten rather than compared, which
    /// presumably relies on summaries being added in ascending id order —
    /// confirm.
    fn add_summary(&mut self, summary: &Self, _: &()) {
        self.max_id = summary.max_id;
        self.count += summary.count;
    }
}
513
/// Lets a cursor track the id reached so far.
impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for ChannelMessageId {
    /// Advances to the summary's `max_id`; debug builds assert that this is
    /// strictly greater than the id reached so far.
    fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
        debug_assert!(summary.max_id > *self);
        *self = summary.max_id;
    }
}
520
/// Lets a cursor track how many messages it has passed.
impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count {
    /// Adds the summary's message count to the running total.
    fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
        self.0 += summary.count;
    }
}
526
/// Allows seeking a cursor to a message index.
impl<'a> sum_tree::SeekDimension<'a, ChannelMessageSummary> for Count {
    /// Compares two counts by their numeric value.
    fn cmp(&self, other: &Self, _: &()) -> std::cmp::Ordering {
        Ord::cmp(&self.0, &other.0)
    }
}
532
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::{FakeHttpClient, FakeServer};
    use gpui::TestAppContext;
    use surf::http::Response;

    /// Drives a `ChannelList` and one `Channel` against a fake server:
    /// fetching the available channels, joining a channel, receiving a new
    /// message, and loading older messages, checking the emitted
    /// `MessagesUpdated` events and the resulting message contents.
    #[gpui::test]
    async fn test_channel_messages(mut cx: TestAppContext) {
        let user_id = 5;
        let mut client = Client::new();
        let http_client = FakeHttpClient::new(|_| async move { Ok(Response::new(404)) });
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
        let user_store = UserStore::new(client.clone(), http_client, cx.background().as_ref());

        let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx));
        channel_list.read_with(&cx, |list, _| assert_eq!(list.available_channels(), None));

        // The first request seen by the server asks for user 5, the id the
        // fake server was created for.
        let get_users = server.receive::<proto::GetUsers>().await.unwrap();
        assert_eq!(get_users.payload.user_ids, vec![5]);
        server
            .respond(
                get_users.receipt(),
                proto::GetUsersResponse {
                    users: vec![proto::User {
                        id: 5,
                        github_login: "nathansobo".into(),
                        avatar_url: "http://avatar.com/nathansobo".into(),
                    }],
                },
            )
            .await;

        // Get the available channels.
        let get_channels = server.receive::<proto::GetChannels>().await.unwrap();
        server
            .respond(
                get_channels.receipt(),
                proto::GetChannelsResponse {
                    channels: vec![proto::Channel {
                        id: 5,
                        name: "the-channel".to_string(),
                    }],
                },
            )
            .await;
        channel_list.next_notification(&cx).await;
        channel_list.read_with(&cx, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: 5,
                    name: "the-channel".into(),
                }]
            )
        });

        // Join a channel and populate its existing messages.
        let channel = channel_list
            .update(&mut cx, |list, cx| {
                let channel_id = list.available_channels().unwrap()[0].id;
                list.get_channel(channel_id, cx)
            })
            .unwrap();
        channel.read_with(&cx, |channel, _| assert!(channel.messages().is_empty()));
        let join_channel = server.receive::<proto::JoinChannel>().await.unwrap();
        server
            .respond(
                join_channel.receipt(),
                proto::JoinChannelResponse {
                    messages: vec![
                        proto::ChannelMessage {
                            id: 10,
                            body: "a".into(),
                            timestamp: 1000,
                            sender_id: 5,
                        },
                        proto::ChannelMessage {
                            id: 11,
                            body: "b".into(),
                            timestamp: 1001,
                            sender_id: 6,
                        },
                    ],
                    done: false,
                },
            )
            .await;

        // Client requests all users for the received messages
        // (only user 6 is requested here; user 5 was already returned by the
        // first `GetUsers` exchange above).
        let mut get_users = server.receive::<proto::GetUsers>().await.unwrap();
        get_users.payload.user_ids.sort();
        assert_eq!(get_users.payload.user_ids, vec![6]);
        server
            .respond(
                get_users.receipt(),
                proto::GetUsersResponse {
                    users: vec![proto::User {
                        id: 6,
                        github_login: "maxbrunsfeld".into(),
                        avatar_url: "http://avatar.com/maxbrunsfeld".into(),
                    }],
                },
            )
            .await;

        assert_eq!(
            channel.next_event(&cx).await,
            ChannelEvent::MessagesUpdated {
                old_range: 0..0,
                new_count: 2,
            }
        );
        channel.read_with(&cx, |channel, _| {
            assert_eq!(
                channel
                    .messages_in_range(0..2)
                    .map(|message| (message.sender.github_login.clone(), message.body.clone()))
                    .collect::<Vec<_>>(),
                &[
                    ("nathansobo".into(), "a".into()),
                    ("maxbrunsfeld".into(), "b".into())
                ]
            );
        });

        // Receive a new message.
        server
            .send(proto::ChannelMessageSent {
                channel_id: channel.read_with(&cx, |channel, _| channel.details.id),
                message: Some(proto::ChannelMessage {
                    id: 12,
                    body: "c".into(),
                    timestamp: 1002,
                    sender_id: 7,
                }),
            })
            .await;

        // Client requests user for message since they haven't seen them yet
        let get_users = server.receive::<proto::GetUsers>().await.unwrap();
        assert_eq!(get_users.payload.user_ids, vec![7]);
        server
            .respond(
                get_users.receipt(),
                proto::GetUsersResponse {
                    users: vec![proto::User {
                        id: 7,
                        github_login: "as-cii".into(),
                        avatar_url: "http://avatar.com/as-cii".into(),
                    }],
                },
            )
            .await;

        assert_eq!(
            channel.next_event(&cx).await,
            ChannelEvent::MessagesUpdated {
                old_range: 2..2,
                new_count: 1,
            }
        );
        channel.read_with(&cx, |channel, _| {
            assert_eq!(
                channel
                    .messages_in_range(2..3)
                    .map(|message| (message.sender.github_login.clone(), message.body.clone()))
                    .collect::<Vec<_>>(),
                &[("as-cii".into(), "c".into())]
            )
        });

        // Scroll up to view older messages.
        channel.update(&mut cx, |channel, cx| {
            assert!(channel.load_more_messages(cx));
        });
        let get_messages = server.receive::<proto::GetChannelMessages>().await.unwrap();
        assert_eq!(get_messages.payload.channel_id, 5);
        // The oldest message held so far has id 10.
        assert_eq!(get_messages.payload.before_message_id, 10);
        server
            .respond(
                get_messages.receipt(),
                proto::GetChannelMessagesResponse {
                    done: true,
                    messages: vec![
                        proto::ChannelMessage {
                            id: 8,
                            body: "y".into(),
                            timestamp: 998,
                            sender_id: 5,
                        },
                        proto::ChannelMessage {
                            id: 9,
                            body: "z".into(),
                            timestamp: 999,
                            sender_id: 6,
                        },
                    ],
                },
            )
            .await;

        // The older messages are inserted in front of the existing ones.
        assert_eq!(
            channel.next_event(&cx).await,
            ChannelEvent::MessagesUpdated {
                old_range: 0..0,
                new_count: 2,
            }
        );
        channel.read_with(&cx, |channel, _| {
            assert_eq!(
                channel
                    .messages_in_range(0..2)
                    .map(|message| (message.sender.github_login.clone(), message.body.clone()))
                    .collect::<Vec<_>>(),
                &[
                    ("nathansobo".into(), "y".into()),
                    ("maxbrunsfeld".into(), "z".into())
                ]
            );
        });
    }
}