use crate::{
    rpc::{self, Client},
    user::{User, UserStore},
    util::{post_inc, TryFutureExt},
};
use anyhow::{anyhow, Context, Result};
use gpui::{
    sum_tree::{self, Bias, SumTree},
    Entity, ModelContext, ModelHandle, MutableAppContext, Task, WeakModelHandle,
};
use postage::prelude::Stream;
use std::{
    collections::{HashMap, HashSet},
    convert::TryFrom,
    mem,
    ops::Range,
    sync::Arc,
};
use time::OffsetDateTime;
use zrpc::{
    proto::{self, ChannelMessageSent},
    TypedEnvelope,
};
23
/// Tracks the channels available to the signed-in user and the `Channel` models
/// that have been opened, keeping both in sync with the RPC client's connection
/// status.
pub struct ChannelList {
    /// Channels reported by the server. `None` until a `GetChannels` response has
    /// been received, and reset to `None` on sign-out.
    available_channels: Option<Vec<ChannelDetails>>,
    /// Opened channels keyed by channel id. Held weakly so the list never keeps a
    /// channel alive; entries whose model is gone are pruned on reconnect.
    channels: HashMap<u64, WeakModelHandle<Channel>>,
    rpc: Arc<Client>,
    user_store: Arc<UserStore>,
    /// Background task that watches the connection status. Stored so it lives as
    /// long as the list (presumably dropping a gpui `Task` cancels it — confirm).
    _task: Task<Option<()>>,
}
31
/// Identity of a channel as reported by the server.
#[derive(Clone, Debug, PartialEq)]
pub struct ChannelDetails {
    /// Server-side channel id.
    pub id: u64,
    /// Human-readable channel name.
    pub name: String,
}
37
/// A single chat channel plus the portion of its message history loaded so far.
pub struct Channel {
    details: ChannelDetails,
    /// Loaded messages ordered by `ChannelMessageId`; pending (not yet acknowledged)
    /// messages therefore sort after every saved one.
    messages: SumTree<ChannelMessage>,
    /// Set from the `done` flag of join/history responses; once true,
    /// `load_more_messages` issues no further requests.
    loaded_all_messages: bool,
    /// Counter used to mint `ChannelMessageId::Pending` ids for locally sent messages.
    next_pending_message_id: usize,
    user_store: Arc<UserStore>,
    rpc: Arc<Client>,
    /// Subscription that routes `ChannelMessageSent` broadcasts for this channel to
    /// the model; kept for the channel's lifetime.
    _subscription: rpc::Subscription,
}
47
/// A message in a channel, either saved on the server or still pending locally.
#[derive(Clone, Debug)]
pub struct ChannelMessage {
    pub id: ChannelMessageId,
    pub body: String,
    /// For saved messages, built from the server's Unix timestamp; for pending
    /// messages, the local UTC time at which sending began.
    pub timestamp: OffsetDateTime,
    pub sender: Arc<User>,
}
55
/// Identifies a message within a channel.
///
/// The variant order is significant: the derived `Ord` places every `Saved` id
/// before every `Pending` id, so messages awaiting acknowledgement always sort to
/// the end of the message list.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum ChannelMessageId {
    /// Id taken from the server's copy of the message.
    Saved(u64),
    /// Client-local id for a message whose send has not completed yet.
    Pending(usize),
}
61
/// Sum-tree summary of a contiguous run of messages.
#[derive(Clone, Debug, Default)]
pub struct ChannelMessageSummary {
    /// Id of the last message in the run, which is also the largest given the
    /// tree is kept in id order.
    max_id: ChannelMessageId,
    /// Number of messages in the run.
    count: usize,
}
67
/// Sum-tree dimension counting messages; used to seek to a message by index.
#[derive(Copy, Clone, Debug, Default)]
struct Count(usize);
70
/// Events emitted by `ChannelList`. There are currently no variants, so none are emitted.
pub enum ChannelListEvent {}
72
/// Events emitted by `Channel`.
#[derive(Clone, Debug, PartialEq)]
pub enum ChannelEvent {
    /// The messages previously at indices `old_range` were replaced by
    /// `new_count` messages starting at `old_range.start`. A pure insertion has an
    /// empty `old_range`; a pure removal has `new_count == 0`.
    MessagesUpdated {
        old_range: Range<usize>,
        new_count: usize,
    },
}
80
/// `ChannelList` is a gpui model whose event type currently has no variants.
impl Entity for ChannelList {
    type Event = ChannelListEvent;
}
84
85impl ChannelList {
86 pub fn new(
87 user_store: Arc<UserStore>,
88 rpc: Arc<rpc::Client>,
89 cx: &mut ModelContext<Self>,
90 ) -> Self {
91 let _task = cx.spawn_weak(|this, mut cx| {
92 let rpc = rpc.clone();
93 async move {
94 let mut status = rpc.status();
95 while let Some((status, this)) = status.recv().await.zip(this.upgrade(&cx)) {
96 match status {
97 rpc::Status::Connected { .. } => {
98 let response = rpc
99 .request(proto::GetChannels {})
100 .await
101 .context("failed to fetch available channels")?;
102 this.update(&mut cx, |this, cx| {
103 this.available_channels =
104 Some(response.channels.into_iter().map(Into::into).collect());
105
106 let mut to_remove = Vec::new();
107 for (channel_id, channel) in &this.channels {
108 if let Some(channel) = channel.upgrade(cx) {
109 channel.update(cx, |channel, cx| channel.rejoin(cx))
110 } else {
111 to_remove.push(*channel_id);
112 }
113 }
114
115 for channel_id in to_remove {
116 this.channels.remove(&channel_id);
117 }
118 cx.notify();
119 });
120 }
121 rpc::Status::SignedOut { .. } => {
122 this.update(&mut cx, |this, cx| {
123 this.available_channels = None;
124 this.channels.clear();
125 cx.notify();
126 });
127 }
128 _ => {}
129 }
130 }
131 Ok(())
132 }
133 .log_err()
134 });
135
136 Self {
137 available_channels: None,
138 channels: Default::default(),
139 user_store,
140 rpc,
141 _task,
142 }
143 }
144
145 pub fn available_channels(&self) -> Option<&[ChannelDetails]> {
146 self.available_channels.as_ref().map(Vec::as_slice)
147 }
148
149 pub fn get_channel(
150 &mut self,
151 id: u64,
152 cx: &mut MutableAppContext,
153 ) -> Option<ModelHandle<Channel>> {
154 if let Some(channel) = self.channels.get(&id).and_then(|c| c.upgrade(cx)) {
155 return Some(channel);
156 }
157
158 let channels = self.available_channels.as_ref()?;
159 let details = channels.iter().find(|details| details.id == id)?.clone();
160 let channel =
161 cx.add_model(|cx| Channel::new(details, self.user_store.clone(), self.rpc.clone(), cx));
162 self.channels.insert(id, channel.downgrade());
163 Some(channel)
164 }
165}
166
167impl Entity for Channel {
168 type Event = ChannelEvent;
169
170 fn release(&mut self, cx: &mut MutableAppContext) {
171 let rpc = self.rpc.clone();
172 let channel_id = self.details.id;
173 cx.foreground()
174 .spawn(async move {
175 if let Err(error) = rpc.send(proto::LeaveChannel { channel_id }).await {
176 log::error!("error leaving channel: {}", error);
177 };
178 })
179 .detach()
180 }
181}
182
183impl Channel {
184 pub fn new(
185 details: ChannelDetails,
186 user_store: Arc<UserStore>,
187 rpc: Arc<Client>,
188 cx: &mut ModelContext<Self>,
189 ) -> Self {
190 let _subscription = rpc.subscribe_from_model(details.id, cx, Self::handle_message_sent);
191
192 {
193 let user_store = user_store.clone();
194 let rpc = rpc.clone();
195 let channel_id = details.id;
196 cx.spawn(|channel, mut cx| {
197 async move {
198 let response = rpc.request(proto::JoinChannel { channel_id }).await?;
199 let messages = messages_from_proto(response.messages, &user_store).await?;
200 let loaded_all_messages = response.done;
201
202 channel.update(&mut cx, |channel, cx| {
203 channel.insert_messages(messages, cx);
204 channel.loaded_all_messages = loaded_all_messages;
205 });
206
207 Ok(())
208 }
209 .log_err()
210 })
211 .detach();
212 }
213
214 Self {
215 details,
216 user_store,
217 rpc,
218 messages: Default::default(),
219 loaded_all_messages: false,
220 next_pending_message_id: 0,
221 _subscription,
222 }
223 }
224
225 pub fn name(&self) -> &str {
226 &self.details.name
227 }
228
229 pub fn send_message(
230 &mut self,
231 body: String,
232 cx: &mut ModelContext<Self>,
233 ) -> Result<Task<Result<()>>> {
234 if body.is_empty() {
235 Err(anyhow!("message body can't be empty"))?;
236 }
237
238 let current_user = self
239 .user_store
240 .current_user()
241 .ok_or_else(|| anyhow!("current_user is not present"))?;
242
243 let channel_id = self.details.id;
244 let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id));
245 self.insert_messages(
246 SumTree::from_item(
247 ChannelMessage {
248 id: pending_id,
249 body: body.clone(),
250 sender: current_user,
251 timestamp: OffsetDateTime::now_utc(),
252 },
253 &(),
254 ),
255 cx,
256 );
257 let user_store = self.user_store.clone();
258 let rpc = self.rpc.clone();
259 Ok(cx.spawn(|this, mut cx| async move {
260 let request = rpc.request(proto::SendChannelMessage { channel_id, body });
261 let response = request.await?;
262 let message = ChannelMessage::from_proto(
263 response.message.ok_or_else(|| anyhow!("invalid message"))?,
264 &user_store,
265 )
266 .await?;
267 this.update(&mut cx, |this, cx| {
268 this.remove_message(pending_id, cx);
269 this.insert_messages(SumTree::from_item(message, &()), cx);
270 Ok(())
271 })
272 }))
273 }
274
275 pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> bool {
276 if !self.loaded_all_messages {
277 let rpc = self.rpc.clone();
278 let user_store = self.user_store.clone();
279 let channel_id = self.details.id;
280 if let Some(before_message_id) =
281 self.messages.first().and_then(|message| match message.id {
282 ChannelMessageId::Saved(id) => Some(id),
283 ChannelMessageId::Pending(_) => None,
284 })
285 {
286 cx.spawn(|this, mut cx| {
287 async move {
288 let response = rpc
289 .request(proto::GetChannelMessages {
290 channel_id,
291 before_message_id,
292 })
293 .await?;
294 let loaded_all_messages = response.done;
295 let messages = messages_from_proto(response.messages, &user_store).await?;
296 this.update(&mut cx, |this, cx| {
297 this.loaded_all_messages = loaded_all_messages;
298 this.insert_messages(messages, cx);
299 });
300 Ok(())
301 }
302 .log_err()
303 })
304 .detach();
305 return true;
306 }
307 }
308 false
309 }
310
311 pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
312 let user_store = self.user_store.clone();
313 let rpc = self.rpc.clone();
314 let channel_id = self.details.id;
315 cx.spawn(|channel, mut cx| {
316 async move {
317 let response = rpc.request(proto::JoinChannel { channel_id }).await?;
318 let messages = messages_from_proto(response.messages, &user_store).await?;
319 let loaded_all_messages = response.done;
320
321 channel.update(&mut cx, |channel, cx| {
322 if let Some((first_new_message, last_old_message)) =
323 messages.first().zip(channel.messages.last())
324 {
325 if first_new_message.id > last_old_message.id {
326 let old_messages = mem::take(&mut channel.messages);
327 cx.emit(ChannelEvent::MessagesUpdated {
328 old_range: 0..old_messages.summary().count,
329 new_count: 0,
330 });
331 channel.loaded_all_messages = loaded_all_messages;
332 }
333 }
334
335 channel.insert_messages(messages, cx);
336 if loaded_all_messages {
337 channel.loaded_all_messages = loaded_all_messages;
338 }
339 });
340
341 Ok(())
342 }
343 .log_err()
344 })
345 .detach();
346 }
347
348 pub fn message_count(&self) -> usize {
349 self.messages.summary().count
350 }
351
352 pub fn messages(&self) -> &SumTree<ChannelMessage> {
353 &self.messages
354 }
355
356 pub fn message(&self, ix: usize) -> &ChannelMessage {
357 let mut cursor = self.messages.cursor::<Count, ()>();
358 cursor.seek(&Count(ix), Bias::Right, &());
359 cursor.item().unwrap()
360 }
361
362 pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> {
363 let mut cursor = self.messages.cursor::<Count, ()>();
364 cursor.seek(&Count(range.start), Bias::Right, &());
365 cursor.take(range.len())
366 }
367
368 fn handle_message_sent(
369 &mut self,
370 message: TypedEnvelope<ChannelMessageSent>,
371 _: Arc<rpc::Client>,
372 cx: &mut ModelContext<Self>,
373 ) -> Result<()> {
374 let user_store = self.user_store.clone();
375 let message = message
376 .payload
377 .message
378 .ok_or_else(|| anyhow!("empty message"))?;
379
380 cx.spawn(|this, mut cx| {
381 async move {
382 let message = ChannelMessage::from_proto(message, &user_store).await?;
383 this.update(&mut cx, |this, cx| {
384 this.insert_messages(SumTree::from_item(message, &()), cx)
385 });
386 Ok(())
387 }
388 .log_err()
389 })
390 .detach();
391 Ok(())
392 }
393
394 fn remove_message(&mut self, message_id: ChannelMessageId, cx: &mut ModelContext<Self>) {
395 let mut old_cursor = self.messages.cursor::<ChannelMessageId, Count>();
396 let mut new_messages = old_cursor.slice(&message_id, Bias::Left, &());
397 let start_ix = old_cursor.sum_start().0;
398 let removed_messages = old_cursor.slice(&message_id, Bias::Right, &());
399 let removed_count = removed_messages.summary().count;
400 new_messages.push_tree(old_cursor.suffix(&()), &());
401
402 drop(old_cursor);
403 self.messages = new_messages;
404
405 if removed_count > 0 {
406 let end_ix = start_ix + removed_count;
407 cx.emit(ChannelEvent::MessagesUpdated {
408 old_range: start_ix..end_ix,
409 new_count: 0,
410 });
411 cx.notify();
412 }
413 }
414
415 fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
416 if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
417 let mut old_cursor = self.messages.cursor::<ChannelMessageId, Count>();
418 let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
419 let start_ix = old_cursor.sum_start().0;
420 let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &());
421 let removed_count = removed_messages.summary().count;
422 let new_count = messages.summary().count;
423 let end_ix = start_ix + removed_count;
424
425 new_messages.push_tree(messages, &());
426 new_messages.push_tree(old_cursor.suffix(&()), &());
427 drop(old_cursor);
428 self.messages = new_messages;
429
430 cx.emit(ChannelEvent::MessagesUpdated {
431 old_range: start_ix..end_ix,
432 new_count,
433 });
434 cx.notify();
435 }
436 }
437}
438
439async fn messages_from_proto(
440 proto_messages: Vec<proto::ChannelMessage>,
441 user_store: &UserStore,
442) -> Result<SumTree<ChannelMessage>> {
443 let unique_user_ids = proto_messages
444 .iter()
445 .map(|m| m.sender_id)
446 .collect::<HashSet<_>>()
447 .into_iter()
448 .collect();
449 user_store.load_users(unique_user_ids).await?;
450
451 let mut messages = Vec::with_capacity(proto_messages.len());
452 for message in proto_messages {
453 messages.push(ChannelMessage::from_proto(message, &user_store).await?);
454 }
455 let mut result = SumTree::new();
456 result.extend(messages, &());
457 Ok(result)
458}
459
460impl From<proto::Channel> for ChannelDetails {
461 fn from(message: proto::Channel) -> Self {
462 Self {
463 id: message.id,
464 name: message.name,
465 }
466 }
467}
468
469impl ChannelMessage {
470 pub async fn from_proto(
471 message: proto::ChannelMessage,
472 user_store: &UserStore,
473 ) -> Result<Self> {
474 let sender = user_store.fetch_user(message.sender_id).await?;
475 Ok(ChannelMessage {
476 id: ChannelMessageId::Saved(message.id),
477 body: message.body,
478 timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?,
479 sender,
480 })
481 }
482
483 pub fn is_pending(&self) -> bool {
484 matches!(self.id, ChannelMessageId::Pending(_))
485 }
486}
487
488impl sum_tree::Item for ChannelMessage {
489 type Summary = ChannelMessageSummary;
490
491 fn summary(&self) -> Self::Summary {
492 ChannelMessageSummary {
493 max_id: self.id,
494 count: 1,
495 }
496 }
497}
498
499impl Default for ChannelMessageId {
500 fn default() -> Self {
501 Self::Saved(0)
502 }
503}
504
505impl sum_tree::Summary for ChannelMessageSummary {
506 type Context = ();
507
508 fn add_summary(&mut self, summary: &Self, _: &()) {
509 self.max_id = summary.max_id;
510 self.count += summary.count;
511 }
512}
513
/// Lets a cursor seek by message id: the accumulated position is the id of the
/// last message passed so far.
impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for ChannelMessageId {
    fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
        // Ids must strictly increase along the tree; in debug builds this catches
        // out-of-order or duplicate ids.
        debug_assert!(summary.max_id > *self);
        *self = summary.max_id;
    }
}
520
521impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count {
522 fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
523 self.0 += summary.count;
524 }
525}
526
527impl<'a> sum_tree::SeekDimension<'a, ChannelMessageSummary> for Count {
528 fn cmp(&self, other: &Self, _: &()) -> std::cmp::Ordering {
529 Ord::cmp(&self.0, &other.0)
530 }
531}
532
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::{FakeHttpClient, FakeServer};
    use gpui::TestAppContext;
    use surf::http::Response;

    /// End-to-end exercise of `ChannelList` and `Channel` against a fake server.
    /// It covers fetching the available channels, joining a channel, receiving a
    /// broadcast message, and paging in older history. Each step answers the
    /// client's requests in the order the client issues them.
    #[gpui::test]
    async fn test_channel_messages(mut cx: TestAppContext) {
        let user_id = 5;
        let mut client = Client::new();
        // Every HTTP request made through this client gets a 404 (presumably only
        // avatar downloads go over HTTP here).
        let http_client = FakeHttpClient::new(|_| async move { Ok(Response::new(404)) });
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
        let user_store = UserStore::new(client.clone(), http_client, cx.background().as_ref());

        let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx));
        // No channels are known until the server answers `GetChannels`.
        channel_list.read_with(&cx, |list, _| assert_eq!(list.available_channels(), None));

        // The first request is a `GetUsers` for the signed-in user (id 5),
        // presumably issued by `UserStore` once connected.
        let get_users = server.receive::<proto::GetUsers>().await.unwrap();
        assert_eq!(get_users.payload.user_ids, vec![5]);
        server
            .respond(
                get_users.receipt(),
                proto::GetUsersResponse {
                    users: vec![proto::User {
                        id: 5,
                        github_login: "nathansobo".into(),
                        avatar_url: "http://avatar.com/nathansobo".into(),
                    }],
                },
            )
            .await;

        // Get the available channels.
        let get_channels = server.receive::<proto::GetChannels>().await.unwrap();
        server
            .respond(
                get_channels.receipt(),
                proto::GetChannelsResponse {
                    channels: vec![proto::Channel {
                        id: 5,
                        name: "the-channel".to_string(),
                    }],
                },
            )
            .await;
        channel_list.next_notification(&cx).await;
        channel_list.read_with(&cx, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: 5,
                    name: "the-channel".into(),
                }]
            )
        });

        // Join a channel and populate its existing messages.
        let channel = channel_list
            .update(&mut cx, |list, cx| {
                let channel_id = list.available_channels().unwrap()[0].id;
                list.get_channel(channel_id, cx)
            })
            .unwrap();
        channel.read_with(&cx, |channel, _| assert!(channel.messages().is_empty()));
        let join_channel = server.receive::<proto::JoinChannel>().await.unwrap();
        server
            .respond(
                join_channel.receipt(),
                proto::JoinChannelResponse {
                    messages: vec![
                        proto::ChannelMessage {
                            id: 10,
                            body: "a".into(),
                            timestamp: 1000,
                            sender_id: 5,
                        },
                        proto::ChannelMessage {
                            id: 11,
                            body: "b".into(),
                            timestamp: 1001,
                            sender_id: 6,
                        },
                    ],
                    // More history exists, so `load_more_messages` will still page.
                    done: false,
                },
            )
            .await;

        // Client requests the senders of the received messages; user 5 is already
        // known, so only user 6 is requested.
        let mut get_users = server.receive::<proto::GetUsers>().await.unwrap();
        get_users.payload.user_ids.sort();
        assert_eq!(get_users.payload.user_ids, vec![6]);
        server
            .respond(
                get_users.receipt(),
                proto::GetUsersResponse {
                    users: vec![proto::User {
                        id: 6,
                        github_login: "maxbrunsfeld".into(),
                        avatar_url: "http://avatar.com/maxbrunsfeld".into(),
                    }],
                },
            )
            .await;

        // The joined history is inserted at the start of the empty list.
        assert_eq!(
            channel.next_event(&cx).await,
            ChannelEvent::MessagesUpdated {
                old_range: 0..0,
                new_count: 2,
            }
        );
        channel.read_with(&cx, |channel, _| {
            assert_eq!(
                channel
                    .messages_in_range(0..2)
                    .map(|message| (message.sender.github_login.clone(), message.body.clone()))
                    .collect::<Vec<_>>(),
                &[
                    ("nathansobo".into(), "a".into()),
                    ("maxbrunsfeld".into(), "b".into())
                ]
            );
        });

        // Receive a new message.
        server
            .send(proto::ChannelMessageSent {
                channel_id: channel.read_with(&cx, |channel, _| channel.details.id),
                message: Some(proto::ChannelMessage {
                    id: 12,
                    body: "c".into(),
                    timestamp: 1002,
                    sender_id: 7,
                }),
            })
            .await;

        // Client requests the sender of the broadcast message since it hasn't seen
        // that user yet.
        let get_users = server.receive::<proto::GetUsers>().await.unwrap();
        assert_eq!(get_users.payload.user_ids, vec![7]);
        server
            .respond(
                get_users.receipt(),
                proto::GetUsersResponse {
                    users: vec![proto::User {
                        id: 7,
                        github_login: "as-cii".into(),
                        avatar_url: "http://avatar.com/as-cii".into(),
                    }],
                },
            )
            .await;

        // The newest message is appended after the two existing ones.
        assert_eq!(
            channel.next_event(&cx).await,
            ChannelEvent::MessagesUpdated {
                old_range: 2..2,
                new_count: 1,
            }
        );
        channel.read_with(&cx, |channel, _| {
            assert_eq!(
                channel
                    .messages_in_range(2..3)
                    .map(|message| (message.sender.github_login.clone(), message.body.clone()))
                    .collect::<Vec<_>>(),
                &[("as-cii".into(), "c".into())]
            )
        });

        // Scroll up to view older messages; paging starts before the oldest loaded id (10).
        channel.update(&mut cx, |channel, cx| {
            assert!(channel.load_more_messages(cx));
        });
        let get_messages = server.receive::<proto::GetChannelMessages>().await.unwrap();
        assert_eq!(get_messages.payload.channel_id, 5);
        assert_eq!(get_messages.payload.before_message_id, 10);
        server
            .respond(
                get_messages.receipt(),
                proto::GetChannelMessagesResponse {
                    done: true,
                    messages: vec![
                        proto::ChannelMessage {
                            id: 8,
                            body: "y".into(),
                            timestamp: 998,
                            sender_id: 5,
                        },
                        proto::ChannelMessage {
                            id: 9,
                            body: "z".into(),
                            timestamp: 999,
                            sender_id: 6,
                        },
                    ],
                },
            )
            .await;

        // Older messages are prepended at index 0.
        assert_eq!(
            channel.next_event(&cx).await,
            ChannelEvent::MessagesUpdated {
                old_range: 0..0,
                new_count: 2,
            }
        );
        channel.read_with(&cx, |channel, _| {
            assert_eq!(
                channel
                    .messages_in_range(0..2)
                    .map(|message| (message.sender.github_login.clone(), message.body.clone()))
                    .collect::<Vec<_>>(),
                &[
                    ("nathansobo".into(), "y".into()),
                    ("maxbrunsfeld".into(), "z".into())
                ]
            );
        });
    }
}