1use crate::{Channel, ChannelId, ChannelStore};
2use anyhow::{anyhow, Result};
3use client::{
4 proto,
5 user::{User, UserStore},
6 Client, Subscription, TypedEnvelope, UserId,
7};
8use futures::lock::Mutex;
9use gpui::{AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Task};
10use rand::prelude::*;
11use std::{
12 collections::HashSet,
13 mem,
14 ops::{ControlFlow, Range},
15 sync::Arc,
16};
17use sum_tree::{Bias, SumTree};
18use time::OffsetDateTime;
19use util::{post_inc, ResultExt as _, TryFutureExt};
20
/// Client-side state for one channel's chat: the loaded messages (kept ordered by
/// [`ChannelMessageId`]) plus the handles needed to talk to the server about them.
pub struct ChannelChat {
    /// The channel this chat belongs to.
    pub channel_id: ChannelId,
    /// Loaded messages; saved messages sort before pending (locally sent, unconfirmed) ones.
    messages: SumTree<ChannelMessage>,
    /// Ids already sent via `acknowledge_message`, so each id is acknowledged at most once.
    acknowledged_message_ids: HashSet<u64>,
    channel_store: Model<ChannelStore>,
    /// Mirrors the server's `done` flag; once true, `load_more_messages` returns `None`.
    loaded_all_messages: bool,
    /// Highest id sent by `acknowledge_last_message`, if any.
    last_acknowledged_id: Option<u64>,
    /// Counter used to mint `ChannelMessageId::Pending` ids for messages sent from here.
    next_pending_message_id: usize,
    user_store: Model<UserStore>,
    rpc: Arc<Client>,
    /// Held for the duration of each `SendChannelMessage` request so outgoing messages are
    /// sent one at a time.
    outgoing_messages_lock: Arc<Mutex<()>>,
    /// Source of the random nonces attached to outgoing messages.
    rng: StdRng,
    // Keeps the entity subscription alive for as long as this model exists.
    _subscription: Subscription,
}
35
/// The content of a message the user wants to send.
#[derive(Debug, PartialEq, Eq)]
pub struct MessageParams {
    /// Message text; `send_message` rejects an empty string.
    pub text: String,
    /// Mentioned users, each paired with a range (presumably a byte range into `text` —
    /// confirm against the editor that builds these).
    pub mentions: Vec<(Range<usize>, UserId)>,
}
41
/// A chat message held on the client, either confirmed by the server or still pending.
#[derive(Clone, Debug)]
pub struct ChannelMessage {
    /// `Saved` once the server has assigned an id; `Pending` for a locally sent message.
    pub id: ChannelMessageId,
    pub body: String,
    /// For pending messages this is the local send time (UTC); for saved messages it is
    /// decoded from the server's Unix timestamp.
    pub timestamp: OffsetDateTime,
    pub sender: Arc<User>,
    /// Random value chosen when sending; used to match a pending message with the saved
    /// copy that comes back from the server.
    pub nonce: u128,
    /// Mentioned users, each paired with a range (presumably into `body`).
    pub mentions: Vec<(Range<usize>, UserId)>,
}
51
/// Identifies a message in the chat.
///
/// Variant order is significant: the derived `Ord` places every `Saved` id before every
/// `Pending` id, so seeking to `Pending(0)` finds the start of the pending messages.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ChannelMessageId {
    /// Server-assigned message id.
    Saved(u64),
    /// Locally assigned id for a message that has not been confirmed yet.
    Pending(usize),
}
57
/// Aggregate data for a run of messages in the `SumTree`.
#[derive(Clone, Debug, Default)]
pub struct ChannelMessageSummary {
    /// Id of the last message in the run. `add_summary` takes the later summary's value
    /// instead of comparing, so this relies on messages being kept in id order.
    max_id: ChannelMessageId,
    /// Number of messages in the run.
    count: usize,
}
63
/// Cursor dimension counting messages, used to seek to a message by its index.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);
66
67#[derive(Clone, Debug, PartialEq)]
68pub enum ChannelChatEvent {
69 MessagesUpdated {
70 old_range: Range<usize>,
71 new_count: usize,
72 },
73 NewMessage {
74 channel_id: ChannelId,
75 message_id: u64,
76 },
77}
78
// Lets `ChannelChat` models emit `ChannelChatEvent`s to their subscribers.
impl EventEmitter<ChannelChatEvent> for ChannelChat {}

/// Registers the RPC handlers that route server-pushed chat messages to `ChannelChat`
/// models: new messages (`ChannelMessageSent`) and deletions (`RemoveChannelMessage`).
pub fn init(client: &Arc<Client>) {
    client.add_model_message_handler(ChannelChat::handle_message_sent);
    client.add_model_message_handler(ChannelChat::handle_message_removed);
}
84
85impl ChannelChat {
86 pub async fn new(
87 channel: Arc<Channel>,
88 channel_store: Model<ChannelStore>,
89 user_store: Model<UserStore>,
90 client: Arc<Client>,
91 mut cx: AsyncAppContext,
92 ) -> Result<Model<Self>> {
93 let channel_id = channel.id;
94 let subscription = client.subscribe_to_entity(channel_id).unwrap();
95
96 let response = client
97 .request(proto::JoinChannelChat { channel_id })
98 .await?;
99 let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
100 let loaded_all_messages = response.done;
101
102 Ok(cx.build_model(|cx| {
103 cx.on_release(Self::release).detach();
104 let mut this = Self {
105 channel_id: channel.id,
106 user_store,
107 channel_store,
108 rpc: client,
109 outgoing_messages_lock: Default::default(),
110 messages: Default::default(),
111 acknowledged_message_ids: Default::default(),
112 loaded_all_messages,
113 next_pending_message_id: 0,
114 last_acknowledged_id: None,
115 rng: StdRng::from_entropy(),
116 _subscription: subscription.set_model(&cx.handle(), &mut cx.to_async()),
117 };
118 this.insert_messages(messages, cx);
119 this
120 })?)
121 }
122
123 fn release(&mut self, _: &mut AppContext) {
124 self.rpc
125 .send(proto::LeaveChannelChat {
126 channel_id: self.channel_id,
127 })
128 .log_err();
129 }
130
131 pub fn channel(&self, cx: &AppContext) -> Option<Arc<Channel>> {
132 self.channel_store
133 .read(cx)
134 .channel_for_id(self.channel_id)
135 .cloned()
136 }
137
138 pub fn client(&self) -> &Arc<Client> {
139 &self.rpc
140 }
141
142 pub fn send_message(
143 &mut self,
144 message: MessageParams,
145 cx: &mut ModelContext<Self>,
146 ) -> Result<Task<Result<u64>>> {
147 if message.text.is_empty() {
148 Err(anyhow!("message body can't be empty"))?;
149 }
150
151 let current_user = self
152 .user_store
153 .read(cx)
154 .current_user()
155 .ok_or_else(|| anyhow!("current_user is not present"))?;
156
157 let channel_id = self.channel_id;
158 let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id));
159 let nonce = self.rng.gen();
160 self.insert_messages(
161 SumTree::from_item(
162 ChannelMessage {
163 id: pending_id,
164 body: message.text.clone(),
165 sender: current_user,
166 timestamp: OffsetDateTime::now_utc(),
167 mentions: message.mentions.clone(),
168 nonce,
169 },
170 &(),
171 ),
172 cx,
173 );
174 let user_store = self.user_store.clone();
175 let rpc = self.rpc.clone();
176 let outgoing_messages_lock = self.outgoing_messages_lock.clone();
177 Ok(cx.spawn(move |this, mut cx| async move {
178 let outgoing_message_guard = outgoing_messages_lock.lock().await;
179 let request = rpc.request(proto::SendChannelMessage {
180 channel_id,
181 body: message.text,
182 nonce: Some(nonce.into()),
183 mentions: mentions_to_proto(&message.mentions),
184 });
185 let response = request.await?;
186 drop(outgoing_message_guard);
187 let response = response.message.ok_or_else(|| anyhow!("invalid message"))?;
188 let id = response.id;
189 let message = ChannelMessage::from_proto(response, &user_store, &mut cx).await?;
190 this.update(&mut cx, |this, cx| {
191 this.insert_messages(SumTree::from_item(message, &()), cx);
192 })?;
193 Ok(id)
194 }))
195 }
196
197 pub fn remove_message(&mut self, id: u64, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
198 let response = self.rpc.request(proto::RemoveChannelMessage {
199 channel_id: self.channel_id,
200 message_id: id,
201 });
202 cx.spawn(move |this, mut cx| async move {
203 response.await?;
204 this.update(&mut cx, |this, cx| {
205 this.message_removed(id, cx);
206 })?;
207 Ok(())
208 })
209 }
210
211 pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Option<()>>> {
212 if self.loaded_all_messages {
213 return None;
214 }
215
216 let rpc = self.rpc.clone();
217 let user_store = self.user_store.clone();
218 let channel_id = self.channel_id;
219 let before_message_id = self.first_loaded_message_id()?;
220 Some(cx.spawn(move |this, mut cx| {
221 async move {
222 let response = rpc
223 .request(proto::GetChannelMessages {
224 channel_id,
225 before_message_id,
226 })
227 .await?;
228 let loaded_all_messages = response.done;
229 let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
230 this.update(&mut cx, |this, cx| {
231 this.loaded_all_messages = loaded_all_messages;
232 this.insert_messages(messages, cx);
233 })?;
234 anyhow::Ok(())
235 }
236 .log_err()
237 }))
238 }
239
240 pub fn first_loaded_message_id(&mut self) -> Option<u64> {
241 self.messages.first().and_then(|message| match message.id {
242 ChannelMessageId::Saved(id) => Some(id),
243 ChannelMessageId::Pending(_) => None,
244 })
245 }
246
247 /// Load all of the chat messages since a certain message id.
248 ///
249 /// For now, we always maintain a suffix of the channel's messages.
250 pub async fn load_history_since_message(
251 chat: Model<Self>,
252 message_id: u64,
253 mut cx: AsyncAppContext,
254 ) -> Option<usize> {
255 loop {
256 let step = chat
257 .update(&mut cx, |chat, cx| {
258 if let Some(first_id) = chat.first_loaded_message_id() {
259 if first_id <= message_id {
260 let mut cursor = chat.messages.cursor::<(ChannelMessageId, Count)>();
261 let message_id = ChannelMessageId::Saved(message_id);
262 cursor.seek(&message_id, Bias::Left, &());
263 return ControlFlow::Break(
264 if cursor
265 .item()
266 .map_or(false, |message| message.id == message_id)
267 {
268 Some(cursor.start().1 .0)
269 } else {
270 None
271 },
272 );
273 }
274 }
275 ControlFlow::Continue(chat.load_more_messages(cx))
276 })
277 .log_err()?;
278 match step {
279 ControlFlow::Break(ix) => return ix,
280 ControlFlow::Continue(task) => task?.await?,
281 }
282 }
283 }
284
285 pub fn acknowledge_last_message(&mut self, cx: &mut ModelContext<Self>) {
286 if let ChannelMessageId::Saved(latest_message_id) = self.messages.summary().max_id {
287 if self
288 .last_acknowledged_id
289 .map_or(true, |acknowledged_id| acknowledged_id < latest_message_id)
290 {
291 self.rpc
292 .send(proto::AckChannelMessage {
293 channel_id: self.channel_id,
294 message_id: latest_message_id,
295 })
296 .ok();
297 self.last_acknowledged_id = Some(latest_message_id);
298 self.channel_store.update(cx, |store, cx| {
299 store.acknowledge_message_id(self.channel_id, latest_message_id, cx);
300 });
301 }
302 }
303 }
304
305 pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
306 let user_store = self.user_store.clone();
307 let rpc = self.rpc.clone();
308 let channel_id = self.channel_id;
309 cx.spawn(move |this, mut cx| {
310 async move {
311 let response = rpc.request(proto::JoinChannelChat { channel_id }).await?;
312 let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
313 let loaded_all_messages = response.done;
314
315 let pending_messages = this.update(&mut cx, |this, cx| {
316 if let Some((first_new_message, last_old_message)) =
317 messages.first().zip(this.messages.last())
318 {
319 if first_new_message.id > last_old_message.id {
320 let old_messages = mem::take(&mut this.messages);
321 cx.emit(ChannelChatEvent::MessagesUpdated {
322 old_range: 0..old_messages.summary().count,
323 new_count: 0,
324 });
325 this.loaded_all_messages = loaded_all_messages;
326 }
327 }
328
329 this.insert_messages(messages, cx);
330 if loaded_all_messages {
331 this.loaded_all_messages = loaded_all_messages;
332 }
333
334 this.pending_messages().cloned().collect::<Vec<_>>()
335 })?;
336
337 for pending_message in pending_messages {
338 let request = rpc.request(proto::SendChannelMessage {
339 channel_id,
340 body: pending_message.body,
341 mentions: mentions_to_proto(&pending_message.mentions),
342 nonce: Some(pending_message.nonce.into()),
343 });
344 let response = request.await?;
345 let message = ChannelMessage::from_proto(
346 response.message.ok_or_else(|| anyhow!("invalid message"))?,
347 &user_store,
348 &mut cx,
349 )
350 .await?;
351 this.update(&mut cx, |this, cx| {
352 this.insert_messages(SumTree::from_item(message, &()), cx);
353 })?;
354 }
355
356 anyhow::Ok(())
357 }
358 .log_err()
359 })
360 .detach();
361 }
362
363 pub fn message_count(&self) -> usize {
364 self.messages.summary().count
365 }
366
367 pub fn messages(&self) -> &SumTree<ChannelMessage> {
368 &self.messages
369 }
370
371 pub fn message(&self, ix: usize) -> &ChannelMessage {
372 let mut cursor = self.messages.cursor::<Count>();
373 cursor.seek(&Count(ix), Bias::Right, &());
374 cursor.item().unwrap()
375 }
376
377 pub fn acknowledge_message(&mut self, id: u64) {
378 if self.acknowledged_message_ids.insert(id) {
379 self.rpc
380 .send(proto::AckChannelMessage {
381 channel_id: self.channel_id,
382 message_id: id,
383 })
384 .ok();
385 }
386 }
387
388 pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> {
389 let mut cursor = self.messages.cursor::<Count>();
390 cursor.seek(&Count(range.start), Bias::Right, &());
391 cursor.take(range.len())
392 }
393
394 pub fn pending_messages(&self) -> impl Iterator<Item = &ChannelMessage> {
395 let mut cursor = self.messages.cursor::<ChannelMessageId>();
396 cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &());
397 cursor
398 }
399
400 async fn handle_message_sent(
401 this: Model<Self>,
402 message: TypedEnvelope<proto::ChannelMessageSent>,
403 _: Arc<Client>,
404 mut cx: AsyncAppContext,
405 ) -> Result<()> {
406 let user_store = this.update(&mut cx, |this, _| this.user_store.clone())?;
407 let message = message
408 .payload
409 .message
410 .ok_or_else(|| anyhow!("empty message"))?;
411 let message_id = message.id;
412
413 let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?;
414 this.update(&mut cx, |this, cx| {
415 this.insert_messages(SumTree::from_item(message, &()), cx);
416 cx.emit(ChannelChatEvent::NewMessage {
417 channel_id: this.channel_id,
418 message_id,
419 })
420 })?;
421
422 Ok(())
423 }
424
425 async fn handle_message_removed(
426 this: Model<Self>,
427 message: TypedEnvelope<proto::RemoveChannelMessage>,
428 _: Arc<Client>,
429 mut cx: AsyncAppContext,
430 ) -> Result<()> {
431 this.update(&mut cx, |this, cx| {
432 this.message_removed(message.payload.message_id, cx)
433 })?;
434 Ok(())
435 }
436
437 fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
438 if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
439 let nonces = messages
440 .cursor::<()>()
441 .map(|m| m.nonce)
442 .collect::<HashSet<_>>();
443
444 let mut old_cursor = self.messages.cursor::<(ChannelMessageId, Count)>();
445 let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
446 let start_ix = old_cursor.start().1 .0;
447 let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &());
448 let removed_count = removed_messages.summary().count;
449 let new_count = messages.summary().count;
450 let end_ix = start_ix + removed_count;
451
452 new_messages.append(messages, &());
453
454 let mut ranges = Vec::<Range<usize>>::new();
455 if new_messages.last().unwrap().is_pending() {
456 new_messages.append(old_cursor.suffix(&()), &());
457 } else {
458 new_messages.append(
459 old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()),
460 &(),
461 );
462
463 while let Some(message) = old_cursor.item() {
464 let message_ix = old_cursor.start().1 .0;
465 if nonces.contains(&message.nonce) {
466 if ranges.last().map_or(false, |r| r.end == message_ix) {
467 ranges.last_mut().unwrap().end += 1;
468 } else {
469 ranges.push(message_ix..message_ix + 1);
470 }
471 } else {
472 new_messages.push(message.clone(), &());
473 }
474 old_cursor.next(&());
475 }
476 }
477
478 drop(old_cursor);
479 self.messages = new_messages;
480
481 for range in ranges.into_iter().rev() {
482 cx.emit(ChannelChatEvent::MessagesUpdated {
483 old_range: range,
484 new_count: 0,
485 });
486 }
487 cx.emit(ChannelChatEvent::MessagesUpdated {
488 old_range: start_ix..end_ix,
489 new_count,
490 });
491
492 cx.notify();
493 }
494 }
495
496 fn message_removed(&mut self, id: u64, cx: &mut ModelContext<Self>) {
497 let mut cursor = self.messages.cursor::<ChannelMessageId>();
498 let mut messages = cursor.slice(&ChannelMessageId::Saved(id), Bias::Left, &());
499 if let Some(item) = cursor.item() {
500 if item.id == ChannelMessageId::Saved(id) {
501 let ix = messages.summary().count;
502 cursor.next(&());
503 messages.append(cursor.suffix(&()), &());
504 drop(cursor);
505 self.messages = messages;
506 cx.emit(ChannelChatEvent::MessagesUpdated {
507 old_range: ix..ix + 1,
508 new_count: 0,
509 });
510 }
511 }
512 }
513}
514
515async fn messages_from_proto(
516 proto_messages: Vec<proto::ChannelMessage>,
517 user_store: &Model<UserStore>,
518 cx: &mut AsyncAppContext,
519) -> Result<SumTree<ChannelMessage>> {
520 let messages = ChannelMessage::from_proto_vec(proto_messages, user_store, cx).await?;
521 let mut result = SumTree::new();
522 result.extend(messages, &());
523 Ok(result)
524}
525
526impl ChannelMessage {
527 pub async fn from_proto(
528 message: proto::ChannelMessage,
529 user_store: &Model<UserStore>,
530 cx: &mut AsyncAppContext,
531 ) -> Result<Self> {
532 let sender = user_store
533 .update(cx, |user_store, cx| {
534 user_store.get_user(message.sender_id, cx)
535 })?
536 .await?;
537 Ok(ChannelMessage {
538 id: ChannelMessageId::Saved(message.id),
539 body: message.body,
540 mentions: message
541 .mentions
542 .into_iter()
543 .filter_map(|mention| {
544 let range = mention.range?;
545 Some((range.start as usize..range.end as usize, mention.user_id))
546 })
547 .collect(),
548 timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?,
549 sender,
550 nonce: message
551 .nonce
552 .ok_or_else(|| anyhow!("nonce is required"))?
553 .into(),
554 })
555 }
556
557 pub fn is_pending(&self) -> bool {
558 matches!(self.id, ChannelMessageId::Pending(_))
559 }
560
561 pub async fn from_proto_vec(
562 proto_messages: Vec<proto::ChannelMessage>,
563 user_store: &Model<UserStore>,
564 cx: &mut AsyncAppContext,
565 ) -> Result<Vec<Self>> {
566 let unique_user_ids = proto_messages
567 .iter()
568 .map(|m| m.sender_id)
569 .collect::<HashSet<_>>()
570 .into_iter()
571 .collect();
572 user_store
573 .update(cx, |user_store, cx| {
574 user_store.get_users(unique_user_ids, cx)
575 })?
576 .await?;
577
578 let mut messages = Vec::with_capacity(proto_messages.len());
579 for message in proto_messages {
580 messages.push(ChannelMessage::from_proto(message, user_store, cx).await?);
581 }
582 Ok(messages)
583 }
584}
585
586pub fn mentions_to_proto(mentions: &[(Range<usize>, UserId)]) -> Vec<proto::ChatMention> {
587 mentions
588 .iter()
589 .map(|(range, user_id)| proto::ChatMention {
590 range: Some(proto::Range {
591 start: range.start as u64,
592 end: range.end as u64,
593 }),
594 user_id: *user_id as u64,
595 })
596 .collect()
597}
598
599impl sum_tree::Item for ChannelMessage {
600 type Summary = ChannelMessageSummary;
601
602 fn summary(&self) -> Self::Summary {
603 ChannelMessageSummary {
604 max_id: self.id,
605 count: 1,
606 }
607 }
608}
609
610impl Default for ChannelMessageId {
611 fn default() -> Self {
612 Self::Saved(0)
613 }
614}
615
616impl sum_tree::Summary for ChannelMessageSummary {
617 type Context = ();
618
619 fn add_summary(&mut self, summary: &Self, _: &()) {
620 self.max_id = summary.max_id;
621 self.count += summary.count;
622 }
623}
624
impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for ChannelMessageId {
    /// Advances the running id to the last id covered by `summary`.
    fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
        // In debug builds, check that ids strictly increase while traversing the tree.
        debug_assert!(summary.max_id > *self);
        *self = summary.max_id;
    }
}
631
632impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count {
633 fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
634 self.0 += summary.count;
635 }
636}
637
638impl<'a> From<&'a str> for MessageParams {
639 fn from(value: &'a str) -> Self {
640 Self {
641 text: value.into(),
642 mentions: Vec::new(),
643 }
644 }
645}