1use super::{
2 proto,
3 user::{User, UserStore},
4 Client, Status, Subscription, TypedEnvelope,
5};
6use anyhow::{anyhow, Context, Result};
7use gpui::{
8 AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task, WeakModelHandle,
9};
10use postage::prelude::Stream;
11use rand::prelude::*;
12use std::{
13 collections::{HashMap, HashSet},
14 mem,
15 ops::Range,
16 sync::Arc,
17};
18use sum_tree::{Bias, SumTree};
19use time::OffsetDateTime;
20use util::{post_inc, ResultExt as _, TryFutureExt};
21
/// Model tracking the channels the server advertises and the channels this
/// client has opened.
pub struct ChannelList {
    /// Channels from the most recent `GetChannels` response. `None` until the
    /// first response arrives, and reset to `None` on sign-out.
    available_channels: Option<Vec<ChannelDetails>>,
    /// Weak handles to opened channels, keyed by channel id. Entries whose
    /// models were dropped are pruned on the next reconnect.
    channels: HashMap<u64, WeakModelHandle<Channel>>,
    client: Arc<Client>,
    user_store: ModelHandle<UserStore>,
    /// Holds the connection-status task spawned in `ChannelList::new`.
    _task: Task<Option<()>>,
}
29
/// Identity of a channel as reported by the server.
#[derive(Clone, Debug, PartialEq)]
pub struct ChannelDetails {
    pub id: u64,
    pub name: String,
}
35
/// A joined channel together with its locally known message history.
pub struct Channel {
    details: ChannelDetails,
    /// Messages ordered by `ChannelMessageId`: every saved message (by server
    /// id) precedes every pending message (in local send order).
    messages: SumTree<ChannelMessage>,
    /// Mirrors the server's `done` flag: no older messages remain to fetch.
    loaded_all_messages: bool,
    /// Next value for `ChannelMessageId::Pending`, incremented per send.
    next_pending_message_id: usize,
    user_store: ModelHandle<UserStore>,
    rpc: Arc<Client>,
    /// Source of the random nonces attached to messages sent from this client.
    rng: StdRng,
    /// Keeps `ChannelMessageSent` routing to this channel alive.
    _subscription: Subscription,
}
46
/// A chat message, either saved by the server or still pending locally.
#[derive(Clone, Debug)]
pub struct ChannelMessage {
    /// Server id once saved; a local `Pending` id until acknowledged.
    pub id: ChannelMessageId,
    pub body: String,
    pub timestamp: OffsetDateTime,
    pub sender: Arc<User>,
    /// Nonce used to match a pending message with the saved copy returned by
    /// the server (generated randomly in `Channel::send_message`).
    pub nonce: u128,
}
55
/// Identifier of a message within a channel.
///
/// Variant order is significant: the derived `Ord` places every `Saved` id
/// before every `Pending` id, which keeps pending messages at the end of the
/// channel's message tree.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum ChannelMessageId {
    /// Id assigned by the server.
    Saved(u64),
    /// Local id for a message that has not been acknowledged yet.
    Pending(usize),
}
61
/// Sum-tree summary of a run of messages.
#[derive(Clone, Debug, Default)]
pub struct ChannelMessageSummary {
    /// Id of the last (largest) message in the run.
    max_id: ChannelMessageId,
    /// Number of messages in the run.
    count: usize,
}
67
/// Sum-tree dimension counting messages, used for index-based seeks.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);
70
/// Events emitted by `ChannelList`; currently there are none.
pub enum ChannelListEvent {}
72
/// Events emitted by a `Channel`.
#[derive(Clone, Debug, PartialEq)]
pub enum ChannelEvent {
    /// The messages previously at indices `old_range` were replaced by
    /// `new_count` messages starting at `old_range.start`.
    MessagesUpdated {
        old_range: Range<usize>,
        new_count: usize,
    },
}
80
/// `ChannelList` is a gpui model; its event type has no variants.
impl Entity for ChannelList {
    type Event = ChannelListEvent;
}
84
85impl ChannelList {
86 pub fn new(
87 user_store: ModelHandle<UserStore>,
88 rpc: Arc<Client>,
89 cx: &mut ModelContext<Self>,
90 ) -> Self {
91 let _task = cx.spawn_weak(|this, mut cx| {
92 let rpc = rpc.clone();
93 async move {
94 let mut status = rpc.status();
95 while let Some((status, this)) = status.recv().await.zip(this.upgrade(&cx)) {
96 match status {
97 Status::Connected { .. } => {
98 let response = rpc
99 .request(proto::GetChannels {})
100 .await
101 .context("failed to fetch available channels")?;
102 this.update(&mut cx, |this, cx| {
103 this.available_channels =
104 Some(response.channels.into_iter().map(Into::into).collect());
105
106 let mut to_remove = Vec::new();
107 for (channel_id, channel) in &this.channels {
108 if let Some(channel) = channel.upgrade(cx) {
109 channel.update(cx, |channel, cx| channel.rejoin(cx))
110 } else {
111 to_remove.push(*channel_id);
112 }
113 }
114
115 for channel_id in to_remove {
116 this.channels.remove(&channel_id);
117 }
118 cx.notify();
119 });
120 }
121 Status::SignedOut { .. } => {
122 this.update(&mut cx, |this, cx| {
123 this.available_channels = None;
124 this.channels.clear();
125 cx.notify();
126 });
127 }
128 _ => {}
129 }
130 }
131 Ok(())
132 }
133 .log_err()
134 });
135
136 Self {
137 available_channels: None,
138 channels: Default::default(),
139 user_store,
140 client: rpc,
141 _task,
142 }
143 }
144
145 pub fn available_channels(&self) -> Option<&[ChannelDetails]> {
146 self.available_channels.as_ref().map(Vec::as_slice)
147 }
148
149 pub fn get_channel(
150 &mut self,
151 id: u64,
152 cx: &mut MutableAppContext,
153 ) -> Option<ModelHandle<Channel>> {
154 if let Some(channel) = self.channels.get(&id).and_then(|c| c.upgrade(cx)) {
155 return Some(channel);
156 }
157
158 let channels = self.available_channels.as_ref()?;
159 let details = channels.iter().find(|details| details.id == id)?.clone();
160 let channel = cx.add_model(|cx| {
161 Channel::new(details, self.user_store.clone(), self.client.clone(), cx)
162 });
163 self.channels.insert(id, channel.downgrade());
164 Some(channel)
165 }
166}
167
168impl Entity for Channel {
169 type Event = ChannelEvent;
170
171 fn release(&mut self, _: &mut MutableAppContext) {
172 self.rpc
173 .send(proto::LeaveChannel {
174 channel_id: self.details.id,
175 })
176 .log_err();
177 }
178}
179
180impl Channel {
181 pub fn new(
182 details: ChannelDetails,
183 user_store: ModelHandle<UserStore>,
184 rpc: Arc<Client>,
185 cx: &mut ModelContext<Self>,
186 ) -> Self {
187 let _subscription = rpc.subscribe_to_entity(details.id, cx, Self::handle_message_sent);
188
189 {
190 let user_store = user_store.clone();
191 let rpc = rpc.clone();
192 let channel_id = details.id;
193 cx.spawn(|channel, mut cx| {
194 async move {
195 let response = rpc.request(proto::JoinChannel { channel_id }).await?;
196 let messages =
197 messages_from_proto(response.messages, &user_store, &mut cx).await?;
198 let loaded_all_messages = response.done;
199
200 channel.update(&mut cx, |channel, cx| {
201 channel.insert_messages(messages, cx);
202 channel.loaded_all_messages = loaded_all_messages;
203 });
204
205 Ok(())
206 }
207 .log_err()
208 })
209 .detach();
210 }
211
212 Self {
213 details,
214 user_store,
215 rpc,
216 messages: Default::default(),
217 loaded_all_messages: false,
218 next_pending_message_id: 0,
219 rng: StdRng::from_entropy(),
220 _subscription,
221 }
222 }
223
224 pub fn name(&self) -> &str {
225 &self.details.name
226 }
227
228 pub fn send_message(
229 &mut self,
230 body: String,
231 cx: &mut ModelContext<Self>,
232 ) -> Result<Task<Result<()>>> {
233 if body.is_empty() {
234 Err(anyhow!("message body can't be empty"))?;
235 }
236
237 let current_user = self
238 .user_store
239 .read(cx)
240 .current_user()
241 .ok_or_else(|| anyhow!("current_user is not present"))?;
242
243 let channel_id = self.details.id;
244 let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id));
245 let nonce = self.rng.gen();
246 self.insert_messages(
247 SumTree::from_item(
248 ChannelMessage {
249 id: pending_id,
250 body: body.clone(),
251 sender: current_user,
252 timestamp: OffsetDateTime::now_utc(),
253 nonce,
254 },
255 &(),
256 ),
257 cx,
258 );
259 let user_store = self.user_store.clone();
260 let rpc = self.rpc.clone();
261 Ok(cx.spawn(|this, mut cx| async move {
262 let request = rpc.request(proto::SendChannelMessage {
263 channel_id,
264 body,
265 nonce: Some(nonce.into()),
266 });
267 let response = request.await?;
268 let message = ChannelMessage::from_proto(
269 response.message.ok_or_else(|| anyhow!("invalid message"))?,
270 &user_store,
271 &mut cx,
272 )
273 .await?;
274 this.update(&mut cx, |this, cx| {
275 this.insert_messages(SumTree::from_item(message, &()), cx);
276 Ok(())
277 })
278 }))
279 }
280
281 pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> bool {
282 if !self.loaded_all_messages {
283 let rpc = self.rpc.clone();
284 let user_store = self.user_store.clone();
285 let channel_id = self.details.id;
286 if let Some(before_message_id) =
287 self.messages.first().and_then(|message| match message.id {
288 ChannelMessageId::Saved(id) => Some(id),
289 ChannelMessageId::Pending(_) => None,
290 })
291 {
292 cx.spawn(|this, mut cx| {
293 async move {
294 let response = rpc
295 .request(proto::GetChannelMessages {
296 channel_id,
297 before_message_id,
298 })
299 .await?;
300 let loaded_all_messages = response.done;
301 let messages =
302 messages_from_proto(response.messages, &user_store, &mut cx).await?;
303 this.update(&mut cx, |this, cx| {
304 this.loaded_all_messages = loaded_all_messages;
305 this.insert_messages(messages, cx);
306 });
307 Ok(())
308 }
309 .log_err()
310 })
311 .detach();
312 return true;
313 }
314 }
315 false
316 }
317
318 pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
319 let user_store = self.user_store.clone();
320 let rpc = self.rpc.clone();
321 let channel_id = self.details.id;
322 cx.spawn(|this, mut cx| {
323 async move {
324 let response = rpc.request(proto::JoinChannel { channel_id }).await?;
325 let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
326 let loaded_all_messages = response.done;
327
328 let pending_messages = this.update(&mut cx, |this, cx| {
329 if let Some((first_new_message, last_old_message)) =
330 messages.first().zip(this.messages.last())
331 {
332 if first_new_message.id > last_old_message.id {
333 let old_messages = mem::take(&mut this.messages);
334 cx.emit(ChannelEvent::MessagesUpdated {
335 old_range: 0..old_messages.summary().count,
336 new_count: 0,
337 });
338 this.loaded_all_messages = loaded_all_messages;
339 }
340 }
341
342 this.insert_messages(messages, cx);
343 if loaded_all_messages {
344 this.loaded_all_messages = loaded_all_messages;
345 }
346
347 this.pending_messages().cloned().collect::<Vec<_>>()
348 });
349
350 for pending_message in pending_messages {
351 let request = rpc.request(proto::SendChannelMessage {
352 channel_id,
353 body: pending_message.body,
354 nonce: Some(pending_message.nonce.into()),
355 });
356 let response = request.await?;
357 let message = ChannelMessage::from_proto(
358 response.message.ok_or_else(|| anyhow!("invalid message"))?,
359 &user_store,
360 &mut cx,
361 )
362 .await?;
363 this.update(&mut cx, |this, cx| {
364 this.insert_messages(SumTree::from_item(message, &()), cx);
365 });
366 }
367
368 Ok(())
369 }
370 .log_err()
371 })
372 .detach();
373 }
374
375 pub fn message_count(&self) -> usize {
376 self.messages.summary().count
377 }
378
379 pub fn messages(&self) -> &SumTree<ChannelMessage> {
380 &self.messages
381 }
382
383 pub fn message(&self, ix: usize) -> &ChannelMessage {
384 let mut cursor = self.messages.cursor::<Count>();
385 cursor.seek(&Count(ix), Bias::Right, &());
386 cursor.item().unwrap()
387 }
388
389 pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> {
390 let mut cursor = self.messages.cursor::<Count>();
391 cursor.seek(&Count(range.start), Bias::Right, &());
392 cursor.take(range.len())
393 }
394
395 pub fn pending_messages(&self) -> impl Iterator<Item = &ChannelMessage> {
396 let mut cursor = self.messages.cursor::<ChannelMessageId>();
397 cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &());
398 cursor
399 }
400
401 fn handle_message_sent(
402 &mut self,
403 message: TypedEnvelope<proto::ChannelMessageSent>,
404 _: Arc<Client>,
405 cx: &mut ModelContext<Self>,
406 ) -> Result<()> {
407 let user_store = self.user_store.clone();
408 let message = message
409 .payload
410 .message
411 .ok_or_else(|| anyhow!("empty message"))?;
412
413 cx.spawn(|this, mut cx| {
414 async move {
415 let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?;
416 this.update(&mut cx, |this, cx| {
417 this.insert_messages(SumTree::from_item(message, &()), cx)
418 });
419 Ok(())
420 }
421 .log_err()
422 })
423 .detach();
424 Ok(())
425 }
426
427 fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
428 if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
429 let nonces = messages
430 .cursor::<()>()
431 .map(|m| m.nonce)
432 .collect::<HashSet<_>>();
433
434 let mut old_cursor = self.messages.cursor::<(ChannelMessageId, Count)>();
435 let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
436 let start_ix = old_cursor.start().1 .0;
437 let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &());
438 let removed_count = removed_messages.summary().count;
439 let new_count = messages.summary().count;
440 let end_ix = start_ix + removed_count;
441
442 new_messages.push_tree(messages, &());
443
444 let mut ranges = Vec::<Range<usize>>::new();
445 if new_messages.last().unwrap().is_pending() {
446 new_messages.push_tree(old_cursor.suffix(&()), &());
447 } else {
448 new_messages.push_tree(
449 old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()),
450 &(),
451 );
452
453 while let Some(message) = old_cursor.item() {
454 let message_ix = old_cursor.start().1 .0;
455 if nonces.contains(&message.nonce) {
456 if ranges.last().map_or(false, |r| r.end == message_ix) {
457 ranges.last_mut().unwrap().end += 1;
458 } else {
459 ranges.push(message_ix..message_ix + 1);
460 }
461 } else {
462 new_messages.push(message.clone(), &());
463 }
464 old_cursor.next(&());
465 }
466 }
467
468 drop(old_cursor);
469 self.messages = new_messages;
470
471 for range in ranges.into_iter().rev() {
472 cx.emit(ChannelEvent::MessagesUpdated {
473 old_range: range,
474 new_count: 0,
475 });
476 }
477 cx.emit(ChannelEvent::MessagesUpdated {
478 old_range: start_ix..end_ix,
479 new_count,
480 });
481 cx.notify();
482 }
483 }
484}
485
486async fn messages_from_proto(
487 proto_messages: Vec<proto::ChannelMessage>,
488 user_store: &ModelHandle<UserStore>,
489 cx: &mut AsyncAppContext,
490) -> Result<SumTree<ChannelMessage>> {
491 let unique_user_ids = proto_messages
492 .iter()
493 .map(|m| m.sender_id)
494 .collect::<HashSet<_>>()
495 .into_iter()
496 .collect();
497 user_store
498 .update(cx, |user_store, cx| {
499 user_store.load_users(unique_user_ids, cx)
500 })
501 .await?;
502
503 let mut messages = Vec::with_capacity(proto_messages.len());
504 for message in proto_messages {
505 messages.push(ChannelMessage::from_proto(message, user_store, cx).await?);
506 }
507 let mut result = SumTree::new();
508 result.extend(messages, &());
509 Ok(result)
510}
511
512impl From<proto::Channel> for ChannelDetails {
513 fn from(message: proto::Channel) -> Self {
514 Self {
515 id: message.id,
516 name: message.name,
517 }
518 }
519}
520
521impl ChannelMessage {
522 pub async fn from_proto(
523 message: proto::ChannelMessage,
524 user_store: &ModelHandle<UserStore>,
525 cx: &mut AsyncAppContext,
526 ) -> Result<Self> {
527 let sender = user_store
528 .update(cx, |user_store, cx| {
529 user_store.fetch_user(message.sender_id, cx)
530 })
531 .await?;
532 Ok(ChannelMessage {
533 id: ChannelMessageId::Saved(message.id),
534 body: message.body,
535 timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?,
536 sender,
537 nonce: message
538 .nonce
539 .ok_or_else(|| anyhow!("nonce is required"))?
540 .into(),
541 })
542 }
543
544 pub fn is_pending(&self) -> bool {
545 matches!(self.id, ChannelMessageId::Pending(_))
546 }
547}
548
549impl sum_tree::Item for ChannelMessage {
550 type Summary = ChannelMessageSummary;
551
552 fn summary(&self) -> Self::Summary {
553 ChannelMessageSummary {
554 max_id: self.id,
555 count: 1,
556 }
557 }
558}
559
560impl Default for ChannelMessageId {
561 fn default() -> Self {
562 Self::Saved(0)
563 }
564}
565
566impl sum_tree::Summary for ChannelMessageSummary {
567 type Context = ();
568
569 fn add_summary(&mut self, summary: &Self, _: &()) {
570 self.max_id = summary.max_id;
571 self.count += summary.count;
572 }
573}
574
575impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for ChannelMessageId {
576 fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
577 debug_assert!(summary.max_id > *self);
578 *self = summary.max_id;
579 }
580}
581
582impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count {
583 fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
584 self.0 += summary.count;
585 }
586}
587
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::{FakeHttpClient, FakeServer};
    use gpui::TestAppContext;
    use surf::http::Response;

    // Drives `ChannelList` and `Channel` against a fake server through these
    // steps: listing channels, joining one, receiving a pushed message, and
    // paging back through older history. Every step awaits a specific request
    // from the client, so the order of the steps matters.
    #[gpui::test]
    async fn test_channel_messages(mut cx: TestAppContext) {
        let user_id = 5;
        // Every HTTP request is answered with 404.
        let http_client = FakeHttpClient::new(|_| async move { Ok(Response::new(404)) });
        let mut client = Client::new(http_client.clone());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
        let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http_client, cx));

        let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx));
        channel_list.read_with(&cx, |list, _| assert_eq!(list.available_channels(), None));

        // Get the available channels.
        let get_channels = server.receive::<proto::GetChannels>().await.unwrap();
        server
            .respond(
                get_channels.receipt(),
                proto::GetChannelsResponse {
                    channels: vec![proto::Channel {
                        id: 5,
                        name: "the-channel".to_string(),
                    }],
                },
            )
            .await;
        channel_list.next_notification(&cx).await;
        channel_list.read_with(&cx, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: 5,
                    name: "the-channel".into(),
                }]
            )
        });

        // A lookup of user 5 (the signed-in `user_id`); presumably issued by
        // `UserStore` on connect — TODO confirm.
        let get_users = server.receive::<proto::GetUsers>().await.unwrap();
        assert_eq!(get_users.payload.user_ids, vec![5]);
        server
            .respond(
                get_users.receipt(),
                proto::GetUsersResponse {
                    users: vec![proto::User {
                        id: 5,
                        github_login: "nathansobo".into(),
                        avatar_url: "http://avatar.com/nathansobo".into(),
                    }],
                },
            )
            .await;

        // Join a channel and populate its existing messages.
        let channel = channel_list
            .update(&mut cx, |list, cx| {
                let channel_id = list.available_channels().unwrap()[0].id;
                list.get_channel(channel_id, cx)
            })
            .unwrap();
        channel.read_with(&cx, |channel, _| assert!(channel.messages().is_empty()));
        let join_channel = server.receive::<proto::JoinChannel>().await.unwrap();
        server
            .respond(
                join_channel.receipt(),
                proto::JoinChannelResponse {
                    messages: vec![
                        proto::ChannelMessage {
                            id: 10,
                            body: "a".into(),
                            timestamp: 1000,
                            sender_id: 5,
                            nonce: Some(1.into()),
                        },
                        proto::ChannelMessage {
                            id: 11,
                            body: "b".into(),
                            timestamp: 1001,
                            sender_id: 6,
                            nonce: Some(2.into()),
                        },
                    ],
                    done: false,
                },
            )
            .await;

        // Client requests all users for the received messages
        // (user 5 is already known, so only 6 is requested).
        let mut get_users = server.receive::<proto::GetUsers>().await.unwrap();
        get_users.payload.user_ids.sort();
        assert_eq!(get_users.payload.user_ids, vec![6]);
        server
            .respond(
                get_users.receipt(),
                proto::GetUsersResponse {
                    users: vec![proto::User {
                        id: 6,
                        github_login: "maxbrunsfeld".into(),
                        avatar_url: "http://avatar.com/maxbrunsfeld".into(),
                    }],
                },
            )
            .await;

        assert_eq!(
            channel.next_event(&cx).await,
            ChannelEvent::MessagesUpdated {
                old_range: 0..0,
                new_count: 2,
            }
        );
        channel.read_with(&cx, |channel, _| {
            assert_eq!(
                channel
                    .messages_in_range(0..2)
                    .map(|message| (message.sender.github_login.clone(), message.body.clone()))
                    .collect::<Vec<_>>(),
                &[
                    ("nathansobo".into(), "a".into()),
                    ("maxbrunsfeld".into(), "b".into())
                ]
            );
        });

        // Receive a new message.
        server.send(proto::ChannelMessageSent {
            channel_id: channel.read_with(&cx, |channel, _| channel.details.id),
            message: Some(proto::ChannelMessage {
                id: 12,
                body: "c".into(),
                timestamp: 1002,
                sender_id: 7,
                nonce: Some(3.into()),
            }),
        });

        // Client requests user for message since they haven't seen them yet
        let get_users = server.receive::<proto::GetUsers>().await.unwrap();
        assert_eq!(get_users.payload.user_ids, vec![7]);
        server
            .respond(
                get_users.receipt(),
                proto::GetUsersResponse {
                    users: vec![proto::User {
                        id: 7,
                        github_login: "as-cii".into(),
                        avatar_url: "http://avatar.com/as-cii".into(),
                    }],
                },
            )
            .await;

        // The pushed message is appended after the two joined messages.
        assert_eq!(
            channel.next_event(&cx).await,
            ChannelEvent::MessagesUpdated {
                old_range: 2..2,
                new_count: 1,
            }
        );
        channel.read_with(&cx, |channel, _| {
            assert_eq!(
                channel
                    .messages_in_range(2..3)
                    .map(|message| (message.sender.github_login.clone(), message.body.clone()))
                    .collect::<Vec<_>>(),
                &[("as-cii".into(), "c".into())]
            )
        });

        // Scroll up to view older messages.
        // The request pages back from the oldest saved message (id 10).
        channel.update(&mut cx, |channel, cx| {
            assert!(channel.load_more_messages(cx));
        });
        let get_messages = server.receive::<proto::GetChannelMessages>().await.unwrap();
        assert_eq!(get_messages.payload.channel_id, 5);
        assert_eq!(get_messages.payload.before_message_id, 10);
        server
            .respond(
                get_messages.receipt(),
                proto::GetChannelMessagesResponse {
                    done: true,
                    messages: vec![
                        proto::ChannelMessage {
                            id: 8,
                            body: "y".into(),
                            timestamp: 998,
                            sender_id: 5,
                            nonce: Some(4.into()),
                        },
                        proto::ChannelMessage {
                            id: 9,
                            body: "z".into(),
                            timestamp: 999,
                            sender_id: 6,
                            nonce: Some(5.into()),
                        },
                    ],
                },
            )
            .await;

        // Older messages are inserted at the front of the history.
        assert_eq!(
            channel.next_event(&cx).await,
            ChannelEvent::MessagesUpdated {
                old_range: 0..0,
                new_count: 2,
            }
        );
        channel.read_with(&cx, |channel, _| {
            assert_eq!(
                channel
                    .messages_in_range(0..2)
                    .map(|message| (message.sender.github_login.clone(), message.body.clone()))
                    .collect::<Vec<_>>(),
                &[
                    ("nathansobo".into(), "y".into()),
                    ("maxbrunsfeld".into(), "z".into())
                ]
            );
        });
    }
}