1use super::{
2 proto,
3 user::{User, UserStore},
4 Client, Status, Subscription, TypedEnvelope,
5};
6use anyhow::{anyhow, Context, Result};
7use gpui::{
8 AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task, WeakModelHandle,
9};
10use postage::prelude::Stream;
11use rand::prelude::*;
12use std::{
13 collections::{HashMap, HashSet},
14 mem,
15 ops::Range,
16 sync::Arc,
17};
18use sum_tree::{Bias, SumTree};
19use time::OffsetDateTime;
20use util::{post_inc, ResultExt as _, TryFutureExt};
21
22pub struct ChannelList {
23 available_channels: Option<Vec<ChannelDetails>>,
24 channels: HashMap<u64, WeakModelHandle<Channel>>,
25 client: Arc<Client>,
26 user_store: ModelHandle<UserStore>,
27 _task: Task<Option<()>>,
28}
29
/// Identifying information for a single channel.
#[derive(Clone, Debug, PartialEq)]
pub struct ChannelDetails {
    /// Channel id; `ChannelList::get_channel` looks channels up by this value.
    pub id: u64,
    /// Channel name.
    pub name: String,
}
35
36pub struct Channel {
37 details: ChannelDetails,
38 messages: SumTree<ChannelMessage>,
39 loaded_all_messages: bool,
40 next_pending_message_id: usize,
41 user_store: ModelHandle<UserStore>,
42 rpc: Arc<Client>,
43 rng: StdRng,
44 _subscription: Subscription,
45}
46
47#[derive(Clone, Debug)]
48pub struct ChannelMessage {
49 pub id: ChannelMessageId,
50 pub body: String,
51 pub timestamp: OffsetDateTime,
52 pub sender: Arc<User>,
53 pub nonce: u128,
54}
55
/// Identifier of a message inside a channel.
///
/// The derived ordering follows declaration order, so every `Saved` id sorts
/// before every `Pending` id.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum ChannelMessageId {
    /// Id of a message that was received from the server.
    Saved(u64),
    /// Locally assigned id of a message that has been sent but not yet
    /// replaced by its saved copy.
    Pending(usize),
}
61
62#[derive(Clone, Debug, Default)]
63pub struct ChannelMessageSummary {
64 max_id: ChannelMessageId,
65 count: usize,
66}
67
/// Tree dimension that counts messages, used to seek by index.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);
70
/// Event type of [`ChannelList`]; it has no variants, so no event can be emitted.
pub enum ChannelListEvent {}
72
/// Events emitted by a [`Channel`].
#[derive(Clone, Debug, PartialEq)]
pub enum ChannelEvent {
    /// The message list changed: the messages at `old_range` were replaced
    /// by `new_count` messages.
    MessagesUpdated {
        /// Index range of the messages that were replaced or removed.
        old_range: Range<usize>,
        /// Number of messages now occupying that position.
        new_count: usize,
    },
}
80
81impl Entity for ChannelList {
82 type Event = ChannelListEvent;
83}
84
85impl ChannelList {
86 pub fn new(
87 user_store: ModelHandle<UserStore>,
88 rpc: Arc<Client>,
89 cx: &mut ModelContext<Self>,
90 ) -> Self {
91 let _task = cx.spawn_weak(|this, mut cx| {
92 let rpc = rpc.clone();
93 async move {
94 let mut status = rpc.status();
95 while let Some((status, this)) = status.recv().await.zip(this.upgrade(&cx)) {
96 match status {
97 Status::Connected { .. } => {
98 let response = rpc
99 .request(proto::GetChannels {})
100 .await
101 .context("failed to fetch available channels")?;
102 this.update(&mut cx, |this, cx| {
103 this.available_channels =
104 Some(response.channels.into_iter().map(Into::into).collect());
105
106 let mut to_remove = Vec::new();
107 for (channel_id, channel) in &this.channels {
108 if let Some(channel) = channel.upgrade(cx) {
109 channel.update(cx, |channel, cx| channel.rejoin(cx))
110 } else {
111 to_remove.push(*channel_id);
112 }
113 }
114
115 for channel_id in to_remove {
116 this.channels.remove(&channel_id);
117 }
118 cx.notify();
119 });
120 }
121 Status::SignedOut { .. } => {
122 this.update(&mut cx, |this, cx| {
123 this.available_channels = None;
124 this.channels.clear();
125 cx.notify();
126 });
127 }
128 _ => {}
129 }
130 }
131 Ok(())
132 }
133 .log_err()
134 });
135
136 Self {
137 available_channels: None,
138 channels: Default::default(),
139 user_store,
140 client: rpc,
141 _task,
142 }
143 }
144
145 pub fn available_channels(&self) -> Option<&[ChannelDetails]> {
146 self.available_channels.as_ref().map(Vec::as_slice)
147 }
148
149 pub fn get_channel(
150 &mut self,
151 id: u64,
152 cx: &mut MutableAppContext,
153 ) -> Option<ModelHandle<Channel>> {
154 if let Some(channel) = self.channels.get(&id).and_then(|c| c.upgrade(cx)) {
155 return Some(channel);
156 }
157
158 let channels = self.available_channels.as_ref()?;
159 let details = channels.iter().find(|details| details.id == id)?.clone();
160 let channel = cx.add_model(|cx| {
161 Channel::new(details, self.user_store.clone(), self.client.clone(), cx)
162 });
163 self.channels.insert(id, channel.downgrade());
164 Some(channel)
165 }
166}
167
168impl Entity for Channel {
169 type Event = ChannelEvent;
170
171 fn release(&mut self, _: &mut MutableAppContext) {
172 self.rpc
173 .send(proto::LeaveChannel {
174 channel_id: self.details.id,
175 })
176 .log_err();
177 }
178}
179
180impl Channel {
181 pub fn new(
182 details: ChannelDetails,
183 user_store: ModelHandle<UserStore>,
184 rpc: Arc<Client>,
185 cx: &mut ModelContext<Self>,
186 ) -> Self {
187 let _subscription =
188 rpc.add_entity_message_handler(details.id, cx, Self::handle_message_sent);
189
190 {
191 let user_store = user_store.clone();
192 let rpc = rpc.clone();
193 let channel_id = details.id;
194 cx.spawn(|channel, mut cx| {
195 async move {
196 let response = rpc.request(proto::JoinChannel { channel_id }).await?;
197 let messages =
198 messages_from_proto(response.messages, &user_store, &mut cx).await?;
199 let loaded_all_messages = response.done;
200
201 channel.update(&mut cx, |channel, cx| {
202 channel.insert_messages(messages, cx);
203 channel.loaded_all_messages = loaded_all_messages;
204 });
205
206 Ok(())
207 }
208 .log_err()
209 })
210 .detach();
211 }
212
213 Self {
214 details,
215 user_store,
216 rpc,
217 messages: Default::default(),
218 loaded_all_messages: false,
219 next_pending_message_id: 0,
220 rng: StdRng::from_entropy(),
221 _subscription,
222 }
223 }
224
225 pub fn name(&self) -> &str {
226 &self.details.name
227 }
228
229 pub fn send_message(
230 &mut self,
231 body: String,
232 cx: &mut ModelContext<Self>,
233 ) -> Result<Task<Result<()>>> {
234 if body.is_empty() {
235 Err(anyhow!("message body can't be empty"))?;
236 }
237
238 let current_user = self
239 .user_store
240 .read(cx)
241 .current_user()
242 .ok_or_else(|| anyhow!("current_user is not present"))?;
243
244 let channel_id = self.details.id;
245 let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id));
246 let nonce = self.rng.gen();
247 self.insert_messages(
248 SumTree::from_item(
249 ChannelMessage {
250 id: pending_id,
251 body: body.clone(),
252 sender: current_user,
253 timestamp: OffsetDateTime::now_utc(),
254 nonce,
255 },
256 &(),
257 ),
258 cx,
259 );
260 let user_store = self.user_store.clone();
261 let rpc = self.rpc.clone();
262 Ok(cx.spawn(|this, mut cx| async move {
263 let request = rpc.request(proto::SendChannelMessage {
264 channel_id,
265 body,
266 nonce: Some(nonce.into()),
267 });
268 let response = request.await?;
269 let message = ChannelMessage::from_proto(
270 response.message.ok_or_else(|| anyhow!("invalid message"))?,
271 &user_store,
272 &mut cx,
273 )
274 .await?;
275 this.update(&mut cx, |this, cx| {
276 this.insert_messages(SumTree::from_item(message, &()), cx);
277 Ok(())
278 })
279 }))
280 }
281
282 pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> bool {
283 if !self.loaded_all_messages {
284 let rpc = self.rpc.clone();
285 let user_store = self.user_store.clone();
286 let channel_id = self.details.id;
287 if let Some(before_message_id) =
288 self.messages.first().and_then(|message| match message.id {
289 ChannelMessageId::Saved(id) => Some(id),
290 ChannelMessageId::Pending(_) => None,
291 })
292 {
293 cx.spawn(|this, mut cx| {
294 async move {
295 let response = rpc
296 .request(proto::GetChannelMessages {
297 channel_id,
298 before_message_id,
299 })
300 .await?;
301 let loaded_all_messages = response.done;
302 let messages =
303 messages_from_proto(response.messages, &user_store, &mut cx).await?;
304 this.update(&mut cx, |this, cx| {
305 this.loaded_all_messages = loaded_all_messages;
306 this.insert_messages(messages, cx);
307 });
308 Ok(())
309 }
310 .log_err()
311 })
312 .detach();
313 return true;
314 }
315 }
316 false
317 }
318
319 pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
320 let user_store = self.user_store.clone();
321 let rpc = self.rpc.clone();
322 let channel_id = self.details.id;
323 cx.spawn(|this, mut cx| {
324 async move {
325 let response = rpc.request(proto::JoinChannel { channel_id }).await?;
326 let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
327 let loaded_all_messages = response.done;
328
329 let pending_messages = this.update(&mut cx, |this, cx| {
330 if let Some((first_new_message, last_old_message)) =
331 messages.first().zip(this.messages.last())
332 {
333 if first_new_message.id > last_old_message.id {
334 let old_messages = mem::take(&mut this.messages);
335 cx.emit(ChannelEvent::MessagesUpdated {
336 old_range: 0..old_messages.summary().count,
337 new_count: 0,
338 });
339 this.loaded_all_messages = loaded_all_messages;
340 }
341 }
342
343 this.insert_messages(messages, cx);
344 if loaded_all_messages {
345 this.loaded_all_messages = loaded_all_messages;
346 }
347
348 this.pending_messages().cloned().collect::<Vec<_>>()
349 });
350
351 for pending_message in pending_messages {
352 let request = rpc.request(proto::SendChannelMessage {
353 channel_id,
354 body: pending_message.body,
355 nonce: Some(pending_message.nonce.into()),
356 });
357 let response = request.await?;
358 let message = ChannelMessage::from_proto(
359 response.message.ok_or_else(|| anyhow!("invalid message"))?,
360 &user_store,
361 &mut cx,
362 )
363 .await?;
364 this.update(&mut cx, |this, cx| {
365 this.insert_messages(SumTree::from_item(message, &()), cx);
366 });
367 }
368
369 Ok(())
370 }
371 .log_err()
372 })
373 .detach();
374 }
375
376 pub fn message_count(&self) -> usize {
377 self.messages.summary().count
378 }
379
380 pub fn messages(&self) -> &SumTree<ChannelMessage> {
381 &self.messages
382 }
383
384 pub fn message(&self, ix: usize) -> &ChannelMessage {
385 let mut cursor = self.messages.cursor::<Count>();
386 cursor.seek(&Count(ix), Bias::Right, &());
387 cursor.item().unwrap()
388 }
389
390 pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> {
391 let mut cursor = self.messages.cursor::<Count>();
392 cursor.seek(&Count(range.start), Bias::Right, &());
393 cursor.take(range.len())
394 }
395
396 pub fn pending_messages(&self) -> impl Iterator<Item = &ChannelMessage> {
397 let mut cursor = self.messages.cursor::<ChannelMessageId>();
398 cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &());
399 cursor
400 }
401
402 async fn handle_message_sent(
403 this: ModelHandle<Self>,
404 message: TypedEnvelope<proto::ChannelMessageSent>,
405 _: Arc<Client>,
406 mut cx: AsyncAppContext,
407 ) -> Result<()> {
408 let user_store = this.read_with(&cx, |this, _| this.user_store.clone());
409 let message = message
410 .payload
411 .message
412 .ok_or_else(|| anyhow!("empty message"))?;
413
414 let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?;
415 this.update(&mut cx, |this, cx| {
416 this.insert_messages(SumTree::from_item(message, &()), cx)
417 });
418
419 Ok(())
420 }
421
422 fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
423 if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
424 let nonces = messages
425 .cursor::<()>()
426 .map(|m| m.nonce)
427 .collect::<HashSet<_>>();
428
429 let mut old_cursor = self.messages.cursor::<(ChannelMessageId, Count)>();
430 let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
431 let start_ix = old_cursor.start().1 .0;
432 let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &());
433 let removed_count = removed_messages.summary().count;
434 let new_count = messages.summary().count;
435 let end_ix = start_ix + removed_count;
436
437 new_messages.push_tree(messages, &());
438
439 let mut ranges = Vec::<Range<usize>>::new();
440 if new_messages.last().unwrap().is_pending() {
441 new_messages.push_tree(old_cursor.suffix(&()), &());
442 } else {
443 new_messages.push_tree(
444 old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()),
445 &(),
446 );
447
448 while let Some(message) = old_cursor.item() {
449 let message_ix = old_cursor.start().1 .0;
450 if nonces.contains(&message.nonce) {
451 if ranges.last().map_or(false, |r| r.end == message_ix) {
452 ranges.last_mut().unwrap().end += 1;
453 } else {
454 ranges.push(message_ix..message_ix + 1);
455 }
456 } else {
457 new_messages.push(message.clone(), &());
458 }
459 old_cursor.next(&());
460 }
461 }
462
463 drop(old_cursor);
464 self.messages = new_messages;
465
466 for range in ranges.into_iter().rev() {
467 cx.emit(ChannelEvent::MessagesUpdated {
468 old_range: range,
469 new_count: 0,
470 });
471 }
472 cx.emit(ChannelEvent::MessagesUpdated {
473 old_range: start_ix..end_ix,
474 new_count,
475 });
476 cx.notify();
477 }
478 }
479}
480
481async fn messages_from_proto(
482 proto_messages: Vec<proto::ChannelMessage>,
483 user_store: &ModelHandle<UserStore>,
484 cx: &mut AsyncAppContext,
485) -> Result<SumTree<ChannelMessage>> {
486 let unique_user_ids = proto_messages
487 .iter()
488 .map(|m| m.sender_id)
489 .collect::<HashSet<_>>()
490 .into_iter()
491 .collect();
492 user_store
493 .update(cx, |user_store, cx| {
494 user_store.load_users(unique_user_ids, cx)
495 })
496 .await?;
497
498 let mut messages = Vec::with_capacity(proto_messages.len());
499 for message in proto_messages {
500 messages.push(ChannelMessage::from_proto(message, user_store, cx).await?);
501 }
502 let mut result = SumTree::new();
503 result.extend(messages, &());
504 Ok(result)
505}
506
507impl From<proto::Channel> for ChannelDetails {
508 fn from(message: proto::Channel) -> Self {
509 Self {
510 id: message.id,
511 name: message.name,
512 }
513 }
514}
515
516impl ChannelMessage {
517 pub async fn from_proto(
518 message: proto::ChannelMessage,
519 user_store: &ModelHandle<UserStore>,
520 cx: &mut AsyncAppContext,
521 ) -> Result<Self> {
522 let sender = user_store
523 .update(cx, |user_store, cx| {
524 user_store.fetch_user(message.sender_id, cx)
525 })
526 .await?;
527 Ok(ChannelMessage {
528 id: ChannelMessageId::Saved(message.id),
529 body: message.body,
530 timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?,
531 sender,
532 nonce: message
533 .nonce
534 .ok_or_else(|| anyhow!("nonce is required"))?
535 .into(),
536 })
537 }
538
539 pub fn is_pending(&self) -> bool {
540 matches!(self.id, ChannelMessageId::Pending(_))
541 }
542}
543
544impl sum_tree::Item for ChannelMessage {
545 type Summary = ChannelMessageSummary;
546
547 fn summary(&self) -> Self::Summary {
548 ChannelMessageSummary {
549 max_id: self.id,
550 count: 1,
551 }
552 }
553}
554
555impl Default for ChannelMessageId {
556 fn default() -> Self {
557 Self::Saved(0)
558 }
559}
560
561impl sum_tree::Summary for ChannelMessageSummary {
562 type Context = ();
563
564 fn add_summary(&mut self, summary: &Self, _: &()) {
565 self.max_id = summary.max_id;
566 self.count += summary.count;
567 }
568}
569
570impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for ChannelMessageId {
571 fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
572 debug_assert!(summary.max_id > *self);
573 *self = summary.max_id;
574 }
575}
576
577impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count {
578 fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
579 self.0 += summary.count;
580 }
581}
582
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::{FakeHttpClient, FakeServer};
    use gpui::TestAppContext;
    use surf::http::Response;

    /// Drives a channel through joining, receiving a live message and paging
    /// in older history against a fake server. Checks the emitted events and
    /// the resulting message order at every step.
    #[gpui::test]
    async fn test_channel_messages(mut cx: TestAppContext) {
        let user_id = 5;
        let http_client = FakeHttpClient::new(|_| async move { Ok(Response::new(404)) });
        let mut client = Client::new(http_client.clone());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
        let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http_client, cx));

        let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx));
        channel_list.read_with(&cx, |list, _| assert_eq!(list.available_channels(), None));

        // Get the available channels.
        let get_channels = server.receive::<proto::GetChannels>().await.unwrap();
        let channels_response = proto::GetChannelsResponse {
            channels: vec![proto::Channel {
                id: 5,
                name: "the-channel".to_string(),
            }],
        };
        server
            .respond(get_channels.receipt(), channels_response)
            .await;
        channel_list.next_notification(&cx).await;
        channel_list.read_with(&cx, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: 5,
                    name: "the-channel".into(),
                }]
            )
        });

        // The client asks for user 5, and the server answers.
        let get_users = server.receive::<proto::GetUsers>().await.unwrap();
        assert_eq!(get_users.payload.user_ids, vec![5]);
        let users_response = proto::GetUsersResponse {
            users: vec![proto::User {
                id: 5,
                github_login: "nathansobo".into(),
                avatar_url: "http://avatar.com/nathansobo".into(),
            }],
        };
        server.respond(get_users.receipt(), users_response).await;

        // Join a channel and populate its existing messages.
        let channel = channel_list
            .update(&mut cx, |list, cx| {
                let channel_id = list.available_channels().unwrap()[0].id;
                list.get_channel(channel_id, cx)
            })
            .unwrap();
        channel.read_with(&cx, |channel, _| assert!(channel.messages().is_empty()));
        let join_channel = server.receive::<proto::JoinChannel>().await.unwrap();
        let join_response = proto::JoinChannelResponse {
            messages: vec![
                proto::ChannelMessage {
                    id: 10,
                    body: "a".into(),
                    timestamp: 1000,
                    sender_id: 5,
                    nonce: Some(1.into()),
                },
                proto::ChannelMessage {
                    id: 11,
                    body: "b".into(),
                    timestamp: 1001,
                    sender_id: 6,
                    nonce: Some(2.into()),
                },
            ],
            done: false,
        };
        server.respond(join_channel.receipt(), join_response).await;

        // Client requests all users for the received messages
        let mut get_users = server.receive::<proto::GetUsers>().await.unwrap();
        get_users.payload.user_ids.sort();
        assert_eq!(get_users.payload.user_ids, vec![6]);
        let users_response = proto::GetUsersResponse {
            users: vec![proto::User {
                id: 6,
                github_login: "maxbrunsfeld".into(),
                avatar_url: "http://avatar.com/maxbrunsfeld".into(),
            }],
        };
        server.respond(get_users.receipt(), users_response).await;

        assert_eq!(
            channel.next_event(&cx).await,
            ChannelEvent::MessagesUpdated {
                old_range: 0..0,
                new_count: 2,
            }
        );
        channel.read_with(&cx, |channel, _| {
            assert_eq!(
                channel
                    .messages_in_range(0..2)
                    .map(|message| (message.sender.github_login.clone(), message.body.clone()))
                    .collect::<Vec<_>>(),
                &[
                    ("nathansobo".into(), "a".into()),
                    ("maxbrunsfeld".into(), "b".into())
                ]
            );
        });

        // Receive a new message.
        let channel_id = channel.read_with(&cx, |channel, _| channel.details.id);
        server.send(proto::ChannelMessageSent {
            channel_id,
            message: Some(proto::ChannelMessage {
                id: 12,
                body: "c".into(),
                timestamp: 1002,
                sender_id: 7,
                nonce: Some(3.into()),
            }),
        });

        // Client requests user for message since they haven't seen them yet
        let get_users = server.receive::<proto::GetUsers>().await.unwrap();
        assert_eq!(get_users.payload.user_ids, vec![7]);
        let users_response = proto::GetUsersResponse {
            users: vec![proto::User {
                id: 7,
                github_login: "as-cii".into(),
                avatar_url: "http://avatar.com/as-cii".into(),
            }],
        };
        server.respond(get_users.receipt(), users_response).await;

        assert_eq!(
            channel.next_event(&cx).await,
            ChannelEvent::MessagesUpdated {
                old_range: 2..2,
                new_count: 1,
            }
        );
        channel.read_with(&cx, |channel, _| {
            assert_eq!(
                channel
                    .messages_in_range(2..3)
                    .map(|message| (message.sender.github_login.clone(), message.body.clone()))
                    .collect::<Vec<_>>(),
                &[("as-cii".into(), "c".into())]
            )
        });

        // Scroll up to view older messages.
        channel.update(&mut cx, |channel, cx| {
            assert!(channel.load_more_messages(cx));
        });
        let get_messages = server.receive::<proto::GetChannelMessages>().await.unwrap();
        assert_eq!(get_messages.payload.channel_id, 5);
        assert_eq!(get_messages.payload.before_message_id, 10);
        let older_messages = proto::GetChannelMessagesResponse {
            done: true,
            messages: vec![
                proto::ChannelMessage {
                    id: 8,
                    body: "y".into(),
                    timestamp: 998,
                    sender_id: 5,
                    nonce: Some(4.into()),
                },
                proto::ChannelMessage {
                    id: 9,
                    body: "z".into(),
                    timestamp: 999,
                    sender_id: 6,
                    nonce: Some(5.into()),
                },
            ],
        };
        server
            .respond(get_messages.receipt(), older_messages)
            .await;

        // The older messages are inserted at the front of the list.
        assert_eq!(
            channel.next_event(&cx).await,
            ChannelEvent::MessagesUpdated {
                old_range: 0..0,
                new_count: 2,
            }
        );
        channel.read_with(&cx, |channel, _| {
            assert_eq!(
                channel
                    .messages_in_range(0..2)
                    .map(|message| (message.sender.github_login.clone(), message.body.clone()))
                    .collect::<Vec<_>>(),
                &[
                    ("nathansobo".into(), "y".into()),
                    ("maxbrunsfeld".into(), "z".into())
                ]
            );
        });
    }
}