1use crate::{Channel, ChannelId, ChannelStore};
2use anyhow::{anyhow, Result};
3use client::{
4 proto,
5 user::{User, UserStore},
6 Client, Subscription, TypedEnvelope, UserId,
7};
8use futures::lock::Mutex;
9use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
10use rand::prelude::*;
11use std::{
12 collections::HashSet,
13 mem,
14 ops::{ControlFlow, Range},
15 sync::Arc,
16};
17use sum_tree::{Bias, SumTree};
18use time::OffsetDateTime;
19use util::{post_inc, ResultExt as _, TryFutureExt};
20
/// Client-side model of a single channel's chat: a suffix of the channel's
/// saved messages plus any locally-sent messages still awaiting confirmation.
pub struct ChannelChat {
    /// The channel this chat belongs to.
    channel: Arc<Channel>,
    /// Loaded messages ordered by [`ChannelMessageId`]; because of the
    /// derived ordering, all saved messages sort before all pending ones.
    messages: SumTree<ChannelMessage>,
    /// Ids already acknowledged through `acknowledge_message`, so each id is
    /// only sent to the server once.
    acknowledged_message_ids: HashSet<u64>,
    /// Store notified when the latest message is acknowledged.
    channel_store: ModelHandle<ChannelStore>,
    /// Set from the server's `done` flag; once true, no older pages are requested.
    loaded_all_messages: bool,
    /// Newest id sent by `acknowledge_last_message`.
    last_acknowledged_id: Option<u64>,
    /// Counter used to allocate `ChannelMessageId::Pending` ids.
    next_pending_message_id: usize,
    /// Used to resolve message senders.
    user_store: ModelHandle<UserStore>,
    /// RPC client used for all chat requests.
    rpc: Arc<Client>,
    /// Held while a `SendChannelMessage` request is in flight so outgoing
    /// messages are sent one at a time.
    outgoing_messages_lock: Arc<Mutex<()>>,
    /// Source of per-message nonces.
    rng: StdRng,
    /// Entity subscription created in `ChannelChat::new`; stored so it lives
    /// as long as the model. NOTE(review): declared last, so dropped last.
    _subscription: Subscription,
}
35
/// Input to [`ChannelChat::send_message`]: message text plus its mentions.
#[derive(Debug, PartialEq, Eq)]
pub struct MessageParams {
    /// The message body; `send_message` rejects an empty string.
    pub text: String,
    /// Mentioned users paired with their range in `text`
    /// (presumably byte offsets — TODO confirm against the editor).
    pub mentions: Vec<(Range<usize>, UserId)>,
}
41
/// A chat message, either confirmed by the server (`Saved`) or sent locally
/// and not yet confirmed (`Pending`).
#[derive(Clone, Debug)]
pub struct ChannelMessage {
    /// Position key in the chat's message tree.
    pub id: ChannelMessageId,
    /// Message text.
    pub body: String,
    /// Server timestamp for saved messages; local send time for pending ones.
    pub timestamp: OffsetDateTime,
    /// Author of the message.
    pub sender: Arc<User>,
    /// Random value chosen when sending; used to match a pending message with
    /// its saved copy when the latter arrives.
    pub nonce: u128,
    /// Mentioned users paired with their range in `body`.
    pub mentions: Vec<(Range<usize>, UserId)>,
}
51
/// Identifies a message within a [`ChannelChat`].
///
/// Variant order is significant: the derived `Ord` places every `Saved` id
/// before every `Pending` id, which keeps pending messages at the end of the
/// message tree.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ChannelMessageId {
    /// Server-assigned message id.
    Saved(u64),
    /// Locally allocated id for a message not yet confirmed by the server.
    Pending(usize),
}
57
/// Sum-tree summary for a run of messages.
#[derive(Clone, Debug, Default)]
pub struct ChannelMessageSummary {
    /// Id of the last message in the run. Because messages are kept in id
    /// order, this is also the largest id.
    max_id: ChannelMessageId,
    /// Number of messages in the run.
    count: usize,
}
63
/// Sum-tree dimension that counts messages; used to seek to a message by index.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);
66
/// Events emitted by a [`ChannelChat`].
#[derive(Clone, Debug, PartialEq)]
pub enum ChannelChatEvent {
    /// The messages at `old_range` (indices into the list as it was before
    /// this event) were replaced by `new_count` messages starting at
    /// `old_range.start`.
    MessagesUpdated {
        old_range: Range<usize>,
        new_count: usize,
    },
    /// A `ChannelMessageSent` notification delivered a message with id
    /// `message_id` to this chat.
    NewMessage {
        channel_id: ChannelId,
        message_id: u64,
    },
}
78
79pub fn init(client: &Arc<Client>) {
80 client.add_model_message_handler(ChannelChat::handle_message_sent);
81 client.add_model_message_handler(ChannelChat::handle_message_removed);
82}
83
84impl Entity for ChannelChat {
85 type Event = ChannelChatEvent;
86
87 fn release(&mut self, _: &mut AppContext) {
88 self.rpc
89 .send(proto::LeaveChannelChat {
90 channel_id: self.channel.id,
91 })
92 .log_err();
93 }
94}
95
96impl ChannelChat {
97 pub async fn new(
98 channel: Arc<Channel>,
99 channel_store: ModelHandle<ChannelStore>,
100 user_store: ModelHandle<UserStore>,
101 client: Arc<Client>,
102 mut cx: AsyncAppContext,
103 ) -> Result<ModelHandle<Self>> {
104 let channel_id = channel.id;
105 let subscription = client.subscribe_to_entity(channel_id).unwrap();
106
107 let response = client
108 .request(proto::JoinChannelChat { channel_id })
109 .await?;
110 let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
111 let loaded_all_messages = response.done;
112
113 Ok(cx.add_model(|cx| {
114 let mut this = Self {
115 channel,
116 user_store,
117 channel_store,
118 rpc: client,
119 outgoing_messages_lock: Default::default(),
120 messages: Default::default(),
121 acknowledged_message_ids: Default::default(),
122 loaded_all_messages,
123 next_pending_message_id: 0,
124 last_acknowledged_id: None,
125 rng: StdRng::from_entropy(),
126 _subscription: subscription.set_model(&cx.handle(), &mut cx.to_async()),
127 };
128 this.insert_messages(messages, cx);
129 this
130 }))
131 }
132
133 pub fn channel(&self) -> &Arc<Channel> {
134 &self.channel
135 }
136
137 pub fn client(&self) -> &Arc<Client> {
138 &self.rpc
139 }
140
141 pub fn send_message(
142 &mut self,
143 message: MessageParams,
144 cx: &mut ModelContext<Self>,
145 ) -> Result<Task<Result<u64>>> {
146 if message.text.is_empty() {
147 Err(anyhow!("message body can't be empty"))?;
148 }
149
150 let current_user = self
151 .user_store
152 .read(cx)
153 .current_user()
154 .ok_or_else(|| anyhow!("current_user is not present"))?;
155
156 let channel_id = self.channel.id;
157 let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id));
158 let nonce = self.rng.gen();
159 self.insert_messages(
160 SumTree::from_item(
161 ChannelMessage {
162 id: pending_id,
163 body: message.text.clone(),
164 sender: current_user,
165 timestamp: OffsetDateTime::now_utc(),
166 mentions: message.mentions.clone(),
167 nonce,
168 },
169 &(),
170 ),
171 cx,
172 );
173 let user_store = self.user_store.clone();
174 let rpc = self.rpc.clone();
175 let outgoing_messages_lock = self.outgoing_messages_lock.clone();
176 Ok(cx.spawn(|this, mut cx| async move {
177 let outgoing_message_guard = outgoing_messages_lock.lock().await;
178 let request = rpc.request(proto::SendChannelMessage {
179 channel_id,
180 body: message.text,
181 nonce: Some(nonce.into()),
182 mentions: mentions_to_proto(&message.mentions),
183 });
184 let response = request.await?;
185 drop(outgoing_message_guard);
186 let response = response.message.ok_or_else(|| anyhow!("invalid message"))?;
187 let id = response.id;
188 let message = ChannelMessage::from_proto(response, &user_store, &mut cx).await?;
189 this.update(&mut cx, |this, cx| {
190 this.insert_messages(SumTree::from_item(message, &()), cx);
191 Ok(id)
192 })
193 }))
194 }
195
196 pub fn remove_message(&mut self, id: u64, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
197 let response = self.rpc.request(proto::RemoveChannelMessage {
198 channel_id: self.channel.id,
199 message_id: id,
200 });
201 cx.spawn(|this, mut cx| async move {
202 response.await?;
203
204 this.update(&mut cx, |this, cx| {
205 this.message_removed(id, cx);
206 Ok(())
207 })
208 })
209 }
210
211 pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Option<()>>> {
212 if self.loaded_all_messages {
213 return None;
214 }
215
216 let rpc = self.rpc.clone();
217 let user_store = self.user_store.clone();
218 let channel_id = self.channel.id;
219 let before_message_id = self.first_loaded_message_id()?;
220 Some(cx.spawn(|this, mut cx| {
221 async move {
222 let response = rpc
223 .request(proto::GetChannelMessages {
224 channel_id,
225 before_message_id,
226 })
227 .await?;
228 let loaded_all_messages = response.done;
229 let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
230 this.update(&mut cx, |this, cx| {
231 this.loaded_all_messages = loaded_all_messages;
232 this.insert_messages(messages, cx);
233 });
234 anyhow::Ok(())
235 }
236 .log_err()
237 }))
238 }
239
240 pub fn first_loaded_message_id(&mut self) -> Option<u64> {
241 self.messages.first().and_then(|message| match message.id {
242 ChannelMessageId::Saved(id) => Some(id),
243 ChannelMessageId::Pending(_) => None,
244 })
245 }
246
247 /// Load all of the chat messages since a certain message id.
248 ///
249 /// For now, we always maintain a suffix of the channel's messages.
250 pub async fn load_history_since_message(
251 chat: ModelHandle<Self>,
252 message_id: u64,
253 mut cx: AsyncAppContext,
254 ) -> Option<usize> {
255 loop {
256 let step = chat.update(&mut cx, |chat, cx| {
257 if let Some(first_id) = chat.first_loaded_message_id() {
258 if first_id <= message_id {
259 let mut cursor = chat.messages.cursor::<(ChannelMessageId, Count)>();
260 let message_id = ChannelMessageId::Saved(message_id);
261 cursor.seek(&message_id, Bias::Left, &());
262 return ControlFlow::Break(
263 if cursor
264 .item()
265 .map_or(false, |message| message.id == message_id)
266 {
267 Some(cursor.start().1 .0)
268 } else {
269 None
270 },
271 );
272 }
273 }
274 ControlFlow::Continue(chat.load_more_messages(cx))
275 });
276 match step {
277 ControlFlow::Break(ix) => return ix,
278 ControlFlow::Continue(task) => task?.await?,
279 }
280 }
281 }
282
283 pub fn acknowledge_last_message(&mut self, cx: &mut ModelContext<Self>) {
284 if let ChannelMessageId::Saved(latest_message_id) = self.messages.summary().max_id {
285 if self
286 .last_acknowledged_id
287 .map_or(true, |acknowledged_id| acknowledged_id < latest_message_id)
288 {
289 self.rpc
290 .send(proto::AckChannelMessage {
291 channel_id: self.channel.id,
292 message_id: latest_message_id,
293 })
294 .ok();
295 self.last_acknowledged_id = Some(latest_message_id);
296 self.channel_store.update(cx, |store, cx| {
297 store.acknowledge_message_id(self.channel.id, latest_message_id, cx);
298 });
299 }
300 }
301 }
302
303 pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
304 let user_store = self.user_store.clone();
305 let rpc = self.rpc.clone();
306 let channel_id = self.channel.id;
307 cx.spawn(|this, mut cx| {
308 async move {
309 let response = rpc.request(proto::JoinChannelChat { channel_id }).await?;
310 let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
311 let loaded_all_messages = response.done;
312
313 let pending_messages = this.update(&mut cx, |this, cx| {
314 if let Some((first_new_message, last_old_message)) =
315 messages.first().zip(this.messages.last())
316 {
317 if first_new_message.id > last_old_message.id {
318 let old_messages = mem::take(&mut this.messages);
319 cx.emit(ChannelChatEvent::MessagesUpdated {
320 old_range: 0..old_messages.summary().count,
321 new_count: 0,
322 });
323 this.loaded_all_messages = loaded_all_messages;
324 }
325 }
326
327 this.insert_messages(messages, cx);
328 if loaded_all_messages {
329 this.loaded_all_messages = loaded_all_messages;
330 }
331
332 this.pending_messages().cloned().collect::<Vec<_>>()
333 });
334
335 for pending_message in pending_messages {
336 let request = rpc.request(proto::SendChannelMessage {
337 channel_id,
338 body: pending_message.body,
339 mentions: mentions_to_proto(&pending_message.mentions),
340 nonce: Some(pending_message.nonce.into()),
341 });
342 let response = request.await?;
343 let message = ChannelMessage::from_proto(
344 response.message.ok_or_else(|| anyhow!("invalid message"))?,
345 &user_store,
346 &mut cx,
347 )
348 .await?;
349 this.update(&mut cx, |this, cx| {
350 this.insert_messages(SumTree::from_item(message, &()), cx);
351 });
352 }
353
354 anyhow::Ok(())
355 }
356 .log_err()
357 })
358 .detach();
359 }
360
361 pub fn message_count(&self) -> usize {
362 self.messages.summary().count
363 }
364
365 pub fn messages(&self) -> &SumTree<ChannelMessage> {
366 &self.messages
367 }
368
369 pub fn message(&self, ix: usize) -> &ChannelMessage {
370 let mut cursor = self.messages.cursor::<Count>();
371 cursor.seek(&Count(ix), Bias::Right, &());
372 cursor.item().unwrap()
373 }
374
375 pub fn acknowledge_message(&mut self, id: u64) {
376 if self.acknowledged_message_ids.insert(id) {
377 self.rpc
378 .send(proto::AckChannelMessage {
379 channel_id: self.channel.id,
380 message_id: id,
381 })
382 .ok();
383 }
384 }
385
386 pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> {
387 let mut cursor = self.messages.cursor::<Count>();
388 cursor.seek(&Count(range.start), Bias::Right, &());
389 cursor.take(range.len())
390 }
391
392 pub fn pending_messages(&self) -> impl Iterator<Item = &ChannelMessage> {
393 let mut cursor = self.messages.cursor::<ChannelMessageId>();
394 cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &());
395 cursor
396 }
397
398 async fn handle_message_sent(
399 this: ModelHandle<Self>,
400 message: TypedEnvelope<proto::ChannelMessageSent>,
401 _: Arc<Client>,
402 mut cx: AsyncAppContext,
403 ) -> Result<()> {
404 let user_store = this.read_with(&cx, |this, _| this.user_store.clone());
405 let message = message
406 .payload
407 .message
408 .ok_or_else(|| anyhow!("empty message"))?;
409 let message_id = message.id;
410
411 let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?;
412 this.update(&mut cx, |this, cx| {
413 this.insert_messages(SumTree::from_item(message, &()), cx);
414 cx.emit(ChannelChatEvent::NewMessage {
415 channel_id: this.channel.id,
416 message_id,
417 })
418 });
419
420 Ok(())
421 }
422
423 async fn handle_message_removed(
424 this: ModelHandle<Self>,
425 message: TypedEnvelope<proto::RemoveChannelMessage>,
426 _: Arc<Client>,
427 mut cx: AsyncAppContext,
428 ) -> Result<()> {
429 this.update(&mut cx, |this, cx| {
430 this.message_removed(message.payload.message_id, cx)
431 });
432 Ok(())
433 }
434
435 fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
436 if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
437 let nonces = messages
438 .cursor::<()>()
439 .map(|m| m.nonce)
440 .collect::<HashSet<_>>();
441
442 let mut old_cursor = self.messages.cursor::<(ChannelMessageId, Count)>();
443 let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
444 let start_ix = old_cursor.start().1 .0;
445 let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &());
446 let removed_count = removed_messages.summary().count;
447 let new_count = messages.summary().count;
448 let end_ix = start_ix + removed_count;
449
450 new_messages.append(messages, &());
451
452 let mut ranges = Vec::<Range<usize>>::new();
453 if new_messages.last().unwrap().is_pending() {
454 new_messages.append(old_cursor.suffix(&()), &());
455 } else {
456 new_messages.append(
457 old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()),
458 &(),
459 );
460
461 while let Some(message) = old_cursor.item() {
462 let message_ix = old_cursor.start().1 .0;
463 if nonces.contains(&message.nonce) {
464 if ranges.last().map_or(false, |r| r.end == message_ix) {
465 ranges.last_mut().unwrap().end += 1;
466 } else {
467 ranges.push(message_ix..message_ix + 1);
468 }
469 } else {
470 new_messages.push(message.clone(), &());
471 }
472 old_cursor.next(&());
473 }
474 }
475
476 drop(old_cursor);
477 self.messages = new_messages;
478
479 for range in ranges.into_iter().rev() {
480 cx.emit(ChannelChatEvent::MessagesUpdated {
481 old_range: range,
482 new_count: 0,
483 });
484 }
485 cx.emit(ChannelChatEvent::MessagesUpdated {
486 old_range: start_ix..end_ix,
487 new_count,
488 });
489
490 cx.notify();
491 }
492 }
493
494 fn message_removed(&mut self, id: u64, cx: &mut ModelContext<Self>) {
495 let mut cursor = self.messages.cursor::<ChannelMessageId>();
496 let mut messages = cursor.slice(&ChannelMessageId::Saved(id), Bias::Left, &());
497 if let Some(item) = cursor.item() {
498 if item.id == ChannelMessageId::Saved(id) {
499 let ix = messages.summary().count;
500 cursor.next(&());
501 messages.append(cursor.suffix(&()), &());
502 drop(cursor);
503 self.messages = messages;
504 cx.emit(ChannelChatEvent::MessagesUpdated {
505 old_range: ix..ix + 1,
506 new_count: 0,
507 });
508 }
509 }
510 }
511}
512
513async fn messages_from_proto(
514 proto_messages: Vec<proto::ChannelMessage>,
515 user_store: &ModelHandle<UserStore>,
516 cx: &mut AsyncAppContext,
517) -> Result<SumTree<ChannelMessage>> {
518 let messages = ChannelMessage::from_proto_vec(proto_messages, user_store, cx).await?;
519 let mut result = SumTree::new();
520 result.extend(messages, &());
521 Ok(result)
522}
523
524impl ChannelMessage {
525 pub async fn from_proto(
526 message: proto::ChannelMessage,
527 user_store: &ModelHandle<UserStore>,
528 cx: &mut AsyncAppContext,
529 ) -> Result<Self> {
530 let sender = user_store
531 .update(cx, |user_store, cx| {
532 user_store.get_user(message.sender_id, cx)
533 })
534 .await?;
535 Ok(ChannelMessage {
536 id: ChannelMessageId::Saved(message.id),
537 body: message.body,
538 mentions: message
539 .mentions
540 .into_iter()
541 .filter_map(|mention| {
542 let range = mention.range?;
543 Some((range.start as usize..range.end as usize, mention.user_id))
544 })
545 .collect(),
546 timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?,
547 sender,
548 nonce: message
549 .nonce
550 .ok_or_else(|| anyhow!("nonce is required"))?
551 .into(),
552 })
553 }
554
555 pub fn is_pending(&self) -> bool {
556 matches!(self.id, ChannelMessageId::Pending(_))
557 }
558
559 pub async fn from_proto_vec(
560 proto_messages: Vec<proto::ChannelMessage>,
561 user_store: &ModelHandle<UserStore>,
562 cx: &mut AsyncAppContext,
563 ) -> Result<Vec<Self>> {
564 let unique_user_ids = proto_messages
565 .iter()
566 .map(|m| m.sender_id)
567 .collect::<HashSet<_>>()
568 .into_iter()
569 .collect();
570 user_store
571 .update(cx, |user_store, cx| {
572 user_store.get_users(unique_user_ids, cx)
573 })
574 .await?;
575
576 let mut messages = Vec::with_capacity(proto_messages.len());
577 for message in proto_messages {
578 messages.push(ChannelMessage::from_proto(message, user_store, cx).await?);
579 }
580 Ok(messages)
581 }
582}
583
584pub fn mentions_to_proto(mentions: &[(Range<usize>, UserId)]) -> Vec<proto::ChatMention> {
585 mentions
586 .iter()
587 .map(|(range, user_id)| proto::ChatMention {
588 range: Some(proto::Range {
589 start: range.start as u64,
590 end: range.end as u64,
591 }),
592 user_id: *user_id as u64,
593 })
594 .collect()
595}
596
597impl sum_tree::Item for ChannelMessage {
598 type Summary = ChannelMessageSummary;
599
600 fn summary(&self) -> Self::Summary {
601 ChannelMessageSummary {
602 max_id: self.id,
603 count: 1,
604 }
605 }
606}
607
608impl Default for ChannelMessageId {
609 fn default() -> Self {
610 Self::Saved(0)
611 }
612}
613
614impl sum_tree::Summary for ChannelMessageSummary {
615 type Context = ();
616
617 fn add_summary(&mut self, summary: &Self, _: &()) {
618 self.max_id = summary.max_id;
619 self.count += summary.count;
620 }
621}
622
/// Lets the message tree be seeked by [`ChannelMessageId`]. The dimension's
/// value is the largest id accumulated so far.
impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for ChannelMessageId {
    fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
        // Ids must be strictly increasing along the tree; checked in debug builds only.
        debug_assert!(summary.max_id > *self);
        *self = summary.max_id;
    }
}
629
630impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count {
631 fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
632 self.0 += summary.count;
633 }
634}
635
636impl<'a> From<&'a str> for MessageParams {
637 fn from(value: &'a str) -> Self {
638 Self {
639 text: value.into(),
640 mentions: Vec::new(),
641 }
642 }
643}