1use crate::{Channel, ChannelId, ChannelStore};
2use anyhow::{anyhow, Result};
3use client::{
4 proto,
5 user::{User, UserStore},
6 Client, Subscription, TypedEnvelope, UserId,
7};
8use futures::lock::Mutex;
9use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
10use rand::prelude::*;
11use std::{
12 collections::HashSet,
13 mem,
14 ops::{ControlFlow, Range},
15 sync::Arc,
16};
17use sum_tree::{Bias, SumTree};
18use time::OffsetDateTime;
19use util::{post_inc, ResultExt as _, TryFutureExt};
20
21pub struct ChannelChat {
22 pub channel_id: ChannelId,
23 messages: SumTree<ChannelMessage>,
24 acknowledged_message_ids: HashSet<u64>,
25 channel_store: ModelHandle<ChannelStore>,
26 loaded_all_messages: bool,
27 last_acknowledged_id: Option<u64>,
28 next_pending_message_id: usize,
29 user_store: ModelHandle<UserStore>,
30 rpc: Arc<Client>,
31 outgoing_messages_lock: Arc<Mutex<()>>,
32 rng: StdRng,
33 _subscription: Subscription,
34}
35
36#[derive(Debug, PartialEq, Eq)]
37pub struct MessageParams {
38 pub text: String,
39 pub mentions: Vec<(Range<usize>, UserId)>,
40}
41
42#[derive(Clone, Debug)]
43pub struct ChannelMessage {
44 pub id: ChannelMessageId,
45 pub body: String,
46 pub timestamp: OffsetDateTime,
47 pub sender: Arc<User>,
48 pub nonce: u128,
49 pub mentions: Vec<(Range<usize>, UserId)>,
50}
51
/// Identifier of a chat message.
///
/// The variant order is significant: the derived ordering places every
/// `Saved` id before every `Pending` id.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum ChannelMessageId {
    /// Id of a message that the server has stored.
    Saved(u64),
    /// Locally assigned id of a message that has not been confirmed yet.
    Pending(usize),
}
57
58#[derive(Clone, Debug, Default)]
59pub struct ChannelMessageSummary {
60 max_id: ChannelMessageId,
61 count: usize,
62}
63
/// Tree dimension that counts messages, i.e. a position in the message list.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Ord, PartialOrd)]
struct Count(usize);
66
67#[derive(Clone, Debug, PartialEq)]
68pub enum ChannelChatEvent {
69 MessagesUpdated {
70 old_range: Range<usize>,
71 new_count: usize,
72 },
73 NewMessage {
74 channel_id: ChannelId,
75 message_id: u64,
76 },
77}
78
79pub fn init(client: &Arc<Client>) {
80 client.add_model_message_handler(ChannelChat::handle_message_sent);
81 client.add_model_message_handler(ChannelChat::handle_message_removed);
82}
83
84impl Entity for ChannelChat {
85 type Event = ChannelChatEvent;
86
87 fn release(&mut self, _: &mut AppContext) {
88 self.rpc
89 .send(proto::LeaveChannelChat {
90 channel_id: self.channel_id,
91 })
92 .log_err();
93 }
94}
95
96impl ChannelChat {
97 pub async fn new(
98 channel: Arc<Channel>,
99 channel_store: ModelHandle<ChannelStore>,
100 user_store: ModelHandle<UserStore>,
101 client: Arc<Client>,
102 mut cx: AsyncAppContext,
103 ) -> Result<ModelHandle<Self>> {
104 let channel_id = channel.id;
105 let subscription = client.subscribe_to_entity(channel_id).unwrap();
106
107 let response = client
108 .request(proto::JoinChannelChat { channel_id })
109 .await?;
110 let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
111 let loaded_all_messages = response.done;
112
113 Ok(cx.add_model(|cx| {
114 let mut this = Self {
115 channel_id: channel.id,
116 user_store,
117 channel_store,
118 rpc: client,
119 outgoing_messages_lock: Default::default(),
120 messages: Default::default(),
121 acknowledged_message_ids: Default::default(),
122 loaded_all_messages,
123 next_pending_message_id: 0,
124 last_acknowledged_id: None,
125 rng: StdRng::from_entropy(),
126 _subscription: subscription.set_model(&cx.handle(), &mut cx.to_async()),
127 };
128 this.insert_messages(messages, cx);
129 this
130 }))
131 }
132
133 pub fn channel(&self, cx: &AppContext) -> Option<Arc<Channel>> {
134 self.channel_store
135 .read(cx)
136 .channel_for_id(self.channel_id)
137 .cloned()
138 }
139
140 pub fn client(&self) -> &Arc<Client> {
141 &self.rpc
142 }
143
144 pub fn send_message(
145 &mut self,
146 message: MessageParams,
147 cx: &mut ModelContext<Self>,
148 ) -> Result<Task<Result<u64>>> {
149 if message.text.is_empty() {
150 Err(anyhow!("message body can't be empty"))?;
151 }
152
153 let current_user = self
154 .user_store
155 .read(cx)
156 .current_user()
157 .ok_or_else(|| anyhow!("current_user is not present"))?;
158
159 let channel_id = self.channel_id;
160 let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id));
161 let nonce = self.rng.gen();
162 self.insert_messages(
163 SumTree::from_item(
164 ChannelMessage {
165 id: pending_id,
166 body: message.text.clone(),
167 sender: current_user,
168 timestamp: OffsetDateTime::now_utc(),
169 mentions: message.mentions.clone(),
170 nonce,
171 },
172 &(),
173 ),
174 cx,
175 );
176 let user_store = self.user_store.clone();
177 let rpc = self.rpc.clone();
178 let outgoing_messages_lock = self.outgoing_messages_lock.clone();
179 Ok(cx.spawn(|this, mut cx| async move {
180 let outgoing_message_guard = outgoing_messages_lock.lock().await;
181 let request = rpc.request(proto::SendChannelMessage {
182 channel_id,
183 body: message.text,
184 nonce: Some(nonce.into()),
185 mentions: mentions_to_proto(&message.mentions),
186 });
187 let response = request.await?;
188 drop(outgoing_message_guard);
189 let response = response.message.ok_or_else(|| anyhow!("invalid message"))?;
190 let id = response.id;
191 let message = ChannelMessage::from_proto(response, &user_store, &mut cx).await?;
192 this.update(&mut cx, |this, cx| {
193 this.insert_messages(SumTree::from_item(message, &()), cx);
194 Ok(id)
195 })
196 }))
197 }
198
199 pub fn remove_message(&mut self, id: u64, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
200 let response = self.rpc.request(proto::RemoveChannelMessage {
201 channel_id: self.channel_id,
202 message_id: id,
203 });
204 cx.spawn(|this, mut cx| async move {
205 response.await?;
206
207 this.update(&mut cx, |this, cx| {
208 this.message_removed(id, cx);
209 Ok(())
210 })
211 })
212 }
213
214 pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Option<()>>> {
215 if self.loaded_all_messages {
216 return None;
217 }
218
219 let rpc = self.rpc.clone();
220 let user_store = self.user_store.clone();
221 let channel_id = self.channel_id;
222 let before_message_id = self.first_loaded_message_id()?;
223 Some(cx.spawn(|this, mut cx| {
224 async move {
225 let response = rpc
226 .request(proto::GetChannelMessages {
227 channel_id,
228 before_message_id,
229 })
230 .await?;
231 let loaded_all_messages = response.done;
232 let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
233 this.update(&mut cx, |this, cx| {
234 this.loaded_all_messages = loaded_all_messages;
235 this.insert_messages(messages, cx);
236 });
237 anyhow::Ok(())
238 }
239 .log_err()
240 }))
241 }
242
243 pub fn first_loaded_message_id(&mut self) -> Option<u64> {
244 self.messages.first().and_then(|message| match message.id {
245 ChannelMessageId::Saved(id) => Some(id),
246 ChannelMessageId::Pending(_) => None,
247 })
248 }
249
250 /// Load all of the chat messages since a certain message id.
251 ///
252 /// For now, we always maintain a suffix of the channel's messages.
253 pub async fn load_history_since_message(
254 chat: ModelHandle<Self>,
255 message_id: u64,
256 mut cx: AsyncAppContext,
257 ) -> Option<usize> {
258 loop {
259 let step = chat.update(&mut cx, |chat, cx| {
260 if let Some(first_id) = chat.first_loaded_message_id() {
261 if first_id <= message_id {
262 let mut cursor = chat.messages.cursor::<(ChannelMessageId, Count)>();
263 let message_id = ChannelMessageId::Saved(message_id);
264 cursor.seek(&message_id, Bias::Left, &());
265 return ControlFlow::Break(
266 if cursor
267 .item()
268 .map_or(false, |message| message.id == message_id)
269 {
270 Some(cursor.start().1 .0)
271 } else {
272 None
273 },
274 );
275 }
276 }
277 ControlFlow::Continue(chat.load_more_messages(cx))
278 });
279 match step {
280 ControlFlow::Break(ix) => return ix,
281 ControlFlow::Continue(task) => task?.await?,
282 }
283 }
284 }
285
286 pub fn acknowledge_last_message(&mut self, cx: &mut ModelContext<Self>) {
287 if let ChannelMessageId::Saved(latest_message_id) = self.messages.summary().max_id {
288 if self
289 .last_acknowledged_id
290 .map_or(true, |acknowledged_id| acknowledged_id < latest_message_id)
291 {
292 self.rpc
293 .send(proto::AckChannelMessage {
294 channel_id: self.channel_id,
295 message_id: latest_message_id,
296 })
297 .ok();
298 self.last_acknowledged_id = Some(latest_message_id);
299 self.channel_store.update(cx, |store, cx| {
300 store.acknowledge_message_id(self.channel_id, latest_message_id, cx);
301 });
302 }
303 }
304 }
305
306 pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
307 let user_store = self.user_store.clone();
308 let rpc = self.rpc.clone();
309 let channel_id = self.channel_id;
310 cx.spawn(|this, mut cx| {
311 async move {
312 let response = rpc.request(proto::JoinChannelChat { channel_id }).await?;
313 let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
314 let loaded_all_messages = response.done;
315
316 let pending_messages = this.update(&mut cx, |this, cx| {
317 if let Some((first_new_message, last_old_message)) =
318 messages.first().zip(this.messages.last())
319 {
320 if first_new_message.id > last_old_message.id {
321 let old_messages = mem::take(&mut this.messages);
322 cx.emit(ChannelChatEvent::MessagesUpdated {
323 old_range: 0..old_messages.summary().count,
324 new_count: 0,
325 });
326 this.loaded_all_messages = loaded_all_messages;
327 }
328 }
329
330 this.insert_messages(messages, cx);
331 if loaded_all_messages {
332 this.loaded_all_messages = loaded_all_messages;
333 }
334
335 this.pending_messages().cloned().collect::<Vec<_>>()
336 });
337
338 for pending_message in pending_messages {
339 let request = rpc.request(proto::SendChannelMessage {
340 channel_id,
341 body: pending_message.body,
342 mentions: mentions_to_proto(&pending_message.mentions),
343 nonce: Some(pending_message.nonce.into()),
344 });
345 let response = request.await?;
346 let message = ChannelMessage::from_proto(
347 response.message.ok_or_else(|| anyhow!("invalid message"))?,
348 &user_store,
349 &mut cx,
350 )
351 .await?;
352 this.update(&mut cx, |this, cx| {
353 this.insert_messages(SumTree::from_item(message, &()), cx);
354 });
355 }
356
357 anyhow::Ok(())
358 }
359 .log_err()
360 })
361 .detach();
362 }
363
364 pub fn message_count(&self) -> usize {
365 self.messages.summary().count
366 }
367
368 pub fn messages(&self) -> &SumTree<ChannelMessage> {
369 &self.messages
370 }
371
372 pub fn message(&self, ix: usize) -> &ChannelMessage {
373 let mut cursor = self.messages.cursor::<Count>();
374 cursor.seek(&Count(ix), Bias::Right, &());
375 cursor.item().unwrap()
376 }
377
378 pub fn acknowledge_message(&mut self, id: u64) {
379 if self.acknowledged_message_ids.insert(id) {
380 self.rpc
381 .send(proto::AckChannelMessage {
382 channel_id: self.channel_id,
383 message_id: id,
384 })
385 .ok();
386 }
387 }
388
389 pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> {
390 let mut cursor = self.messages.cursor::<Count>();
391 cursor.seek(&Count(range.start), Bias::Right, &());
392 cursor.take(range.len())
393 }
394
395 pub fn pending_messages(&self) -> impl Iterator<Item = &ChannelMessage> {
396 let mut cursor = self.messages.cursor::<ChannelMessageId>();
397 cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &());
398 cursor
399 }
400
401 async fn handle_message_sent(
402 this: ModelHandle<Self>,
403 message: TypedEnvelope<proto::ChannelMessageSent>,
404 _: Arc<Client>,
405 mut cx: AsyncAppContext,
406 ) -> Result<()> {
407 let user_store = this.read_with(&cx, |this, _| this.user_store.clone());
408 let message = message
409 .payload
410 .message
411 .ok_or_else(|| anyhow!("empty message"))?;
412 let message_id = message.id;
413
414 let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?;
415 this.update(&mut cx, |this, cx| {
416 this.insert_messages(SumTree::from_item(message, &()), cx);
417 cx.emit(ChannelChatEvent::NewMessage {
418 channel_id: this.channel_id,
419 message_id,
420 })
421 });
422
423 Ok(())
424 }
425
426 async fn handle_message_removed(
427 this: ModelHandle<Self>,
428 message: TypedEnvelope<proto::RemoveChannelMessage>,
429 _: Arc<Client>,
430 mut cx: AsyncAppContext,
431 ) -> Result<()> {
432 this.update(&mut cx, |this, cx| {
433 this.message_removed(message.payload.message_id, cx)
434 });
435 Ok(())
436 }
437
438 fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
439 if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
440 let nonces = messages
441 .cursor::<()>()
442 .map(|m| m.nonce)
443 .collect::<HashSet<_>>();
444
445 let mut old_cursor = self.messages.cursor::<(ChannelMessageId, Count)>();
446 let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
447 let start_ix = old_cursor.start().1 .0;
448 let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &());
449 let removed_count = removed_messages.summary().count;
450 let new_count = messages.summary().count;
451 let end_ix = start_ix + removed_count;
452
453 new_messages.append(messages, &());
454
455 let mut ranges = Vec::<Range<usize>>::new();
456 if new_messages.last().unwrap().is_pending() {
457 new_messages.append(old_cursor.suffix(&()), &());
458 } else {
459 new_messages.append(
460 old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()),
461 &(),
462 );
463
464 while let Some(message) = old_cursor.item() {
465 let message_ix = old_cursor.start().1 .0;
466 if nonces.contains(&message.nonce) {
467 if ranges.last().map_or(false, |r| r.end == message_ix) {
468 ranges.last_mut().unwrap().end += 1;
469 } else {
470 ranges.push(message_ix..message_ix + 1);
471 }
472 } else {
473 new_messages.push(message.clone(), &());
474 }
475 old_cursor.next(&());
476 }
477 }
478
479 drop(old_cursor);
480 self.messages = new_messages;
481
482 for range in ranges.into_iter().rev() {
483 cx.emit(ChannelChatEvent::MessagesUpdated {
484 old_range: range,
485 new_count: 0,
486 });
487 }
488 cx.emit(ChannelChatEvent::MessagesUpdated {
489 old_range: start_ix..end_ix,
490 new_count,
491 });
492
493 cx.notify();
494 }
495 }
496
497 fn message_removed(&mut self, id: u64, cx: &mut ModelContext<Self>) {
498 let mut cursor = self.messages.cursor::<ChannelMessageId>();
499 let mut messages = cursor.slice(&ChannelMessageId::Saved(id), Bias::Left, &());
500 if let Some(item) = cursor.item() {
501 if item.id == ChannelMessageId::Saved(id) {
502 let ix = messages.summary().count;
503 cursor.next(&());
504 messages.append(cursor.suffix(&()), &());
505 drop(cursor);
506 self.messages = messages;
507 cx.emit(ChannelChatEvent::MessagesUpdated {
508 old_range: ix..ix + 1,
509 new_count: 0,
510 });
511 }
512 }
513 }
514}
515
516async fn messages_from_proto(
517 proto_messages: Vec<proto::ChannelMessage>,
518 user_store: &ModelHandle<UserStore>,
519 cx: &mut AsyncAppContext,
520) -> Result<SumTree<ChannelMessage>> {
521 let messages = ChannelMessage::from_proto_vec(proto_messages, user_store, cx).await?;
522 let mut result = SumTree::new();
523 result.extend(messages, &());
524 Ok(result)
525}
526
527impl ChannelMessage {
528 pub async fn from_proto(
529 message: proto::ChannelMessage,
530 user_store: &ModelHandle<UserStore>,
531 cx: &mut AsyncAppContext,
532 ) -> Result<Self> {
533 let sender = user_store
534 .update(cx, |user_store, cx| {
535 user_store.get_user(message.sender_id, cx)
536 })
537 .await?;
538 Ok(ChannelMessage {
539 id: ChannelMessageId::Saved(message.id),
540 body: message.body,
541 mentions: message
542 .mentions
543 .into_iter()
544 .filter_map(|mention| {
545 let range = mention.range?;
546 Some((range.start as usize..range.end as usize, mention.user_id))
547 })
548 .collect(),
549 timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?,
550 sender,
551 nonce: message
552 .nonce
553 .ok_or_else(|| anyhow!("nonce is required"))?
554 .into(),
555 })
556 }
557
558 pub fn is_pending(&self) -> bool {
559 matches!(self.id, ChannelMessageId::Pending(_))
560 }
561
562 pub async fn from_proto_vec(
563 proto_messages: Vec<proto::ChannelMessage>,
564 user_store: &ModelHandle<UserStore>,
565 cx: &mut AsyncAppContext,
566 ) -> Result<Vec<Self>> {
567 let unique_user_ids = proto_messages
568 .iter()
569 .map(|m| m.sender_id)
570 .collect::<HashSet<_>>()
571 .into_iter()
572 .collect();
573 user_store
574 .update(cx, |user_store, cx| {
575 user_store.get_users(unique_user_ids, cx)
576 })
577 .await?;
578
579 let mut messages = Vec::with_capacity(proto_messages.len());
580 for message in proto_messages {
581 messages.push(ChannelMessage::from_proto(message, user_store, cx).await?);
582 }
583 Ok(messages)
584 }
585}
586
587pub fn mentions_to_proto(mentions: &[(Range<usize>, UserId)]) -> Vec<proto::ChatMention> {
588 mentions
589 .iter()
590 .map(|(range, user_id)| proto::ChatMention {
591 range: Some(proto::Range {
592 start: range.start as u64,
593 end: range.end as u64,
594 }),
595 user_id: *user_id as u64,
596 })
597 .collect()
598}
599
600impl sum_tree::Item for ChannelMessage {
601 type Summary = ChannelMessageSummary;
602
603 fn summary(&self) -> Self::Summary {
604 ChannelMessageSummary {
605 max_id: self.id,
606 count: 1,
607 }
608 }
609}
610
611impl Default for ChannelMessageId {
612 fn default() -> Self {
613 Self::Saved(0)
614 }
615}
616
617impl sum_tree::Summary for ChannelMessageSummary {
618 type Context = ();
619
620 fn add_summary(&mut self, summary: &Self, _: &()) {
621 self.max_id = summary.max_id;
622 self.count += summary.count;
623 }
624}
625
626impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for ChannelMessageId {
627 fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
628 debug_assert!(summary.max_id > *self);
629 *self = summary.max_id;
630 }
631}
632
633impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count {
634 fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
635 self.0 += summary.count;
636 }
637}
638
639impl<'a> From<&'a str> for MessageParams {
640 fn from(value: &'a str) -> Self {
641 Self {
642 text: value.into(),
643 mentions: Vec::new(),
644 }
645 }
646}