1use crate::{Channel, ChannelId, ChannelStore};
2use anyhow::{anyhow, Result};
3use client::{
4 proto,
5 user::{User, UserStore},
6 Client, Subscription, TypedEnvelope, UserId,
7};
8use futures::lock::Mutex;
9use gpui::{AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Task};
10use rand::prelude::*;
11use std::{
12 collections::HashSet,
13 mem,
14 ops::{ControlFlow, Range},
15 sync::Arc,
16};
17use sum_tree::{Bias, SumTree};
18use time::OffsetDateTime;
19use util::{post_inc, ResultExt as _, TryFutureExt};
20
21pub struct ChannelChat {
22 pub channel_id: ChannelId,
23 messages: SumTree<ChannelMessage>,
24 acknowledged_message_ids: HashSet<u64>,
25 channel_store: Model<ChannelStore>,
26 loaded_all_messages: bool,
27 last_acknowledged_id: Option<u64>,
28 next_pending_message_id: usize,
29 user_store: Model<UserStore>,
30 rpc: Arc<Client>,
31 outgoing_messages_lock: Arc<Mutex<()>>,
32 rng: StdRng,
33 _subscription: Subscription,
34}
35
36#[derive(Debug, PartialEq, Eq)]
37pub struct MessageParams {
38 pub text: String,
39 pub mentions: Vec<(Range<usize>, UserId)>,
40}
41
42#[derive(Clone, Debug)]
43pub struct ChannelMessage {
44 pub id: ChannelMessageId,
45 pub body: String,
46 pub timestamp: OffsetDateTime,
47 pub sender: Arc<User>,
48 pub nonce: u128,
49 pub mentions: Vec<(Range<usize>, UserId)>,
50}
51
/// Identifies a message in a [`ChannelChat`].
///
/// The derived ordering places every `Saved` id before every `Pending` id,
/// which keeps pending messages at the end of the message tree.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ChannelMessageId {
    /// Id assigned by the server.
    Saved(u64),
    /// Local id of a message that has not been confirmed by the server yet.
    Pending(usize),
}
57
58#[derive(Clone, Debug, Default)]
59pub struct ChannelMessageSummary {
60 max_id: ChannelMessageId,
61 count: usize,
62}
63
/// Sum-tree dimension counting messages, used to seek by message index.
#[derive(Clone, Copy, Debug, Default, Eq, Ord, PartialEq, PartialOrd)]
struct Count(usize);
66
67#[derive(Clone, Debug, PartialEq)]
68pub enum ChannelChatEvent {
69 MessagesUpdated {
70 old_range: Range<usize>,
71 new_count: usize,
72 },
73 NewMessage {
74 channel_id: ChannelId,
75 message_id: u64,
76 },
77}
78
79impl EventEmitter<ChannelChatEvent> for ChannelChat {}
80pub fn init(client: &Arc<Client>) {
81 client.add_model_message_handler(ChannelChat::handle_message_sent);
82 client.add_model_message_handler(ChannelChat::handle_message_removed);
83}
84
85impl ChannelChat {
86 pub async fn new(
87 channel: Arc<Channel>,
88 channel_store: Model<ChannelStore>,
89 user_store: Model<UserStore>,
90 client: Arc<Client>,
91 mut cx: AsyncAppContext,
92 ) -> Result<Model<Self>> {
93 let channel_id = channel.id;
94 let subscription = client.subscribe_to_entity(channel_id).unwrap();
95
96 let response = client
97 .request(proto::JoinChannelChat { channel_id })
98 .await?;
99 let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
100 let loaded_all_messages = response.done;
101
102 Ok(cx.new_model(|cx| {
103 cx.on_release(Self::release).detach();
104 let mut this = Self {
105 channel_id: channel.id,
106 user_store,
107 channel_store,
108 rpc: client,
109 outgoing_messages_lock: Default::default(),
110 messages: Default::default(),
111 acknowledged_message_ids: Default::default(),
112 loaded_all_messages,
113 next_pending_message_id: 0,
114 last_acknowledged_id: None,
115 rng: StdRng::from_entropy(),
116 _subscription: subscription.set_model(&cx.handle(), &mut cx.to_async()),
117 };
118 this.insert_messages(messages, cx);
119 this
120 })?)
121 }
122
123 fn release(&mut self, _: &mut AppContext) {
124 self.rpc
125 .send(proto::LeaveChannelChat {
126 channel_id: self.channel_id,
127 })
128 .log_err();
129 }
130
131 pub fn channel(&self, cx: &AppContext) -> Option<Arc<Channel>> {
132 self.channel_store
133 .read(cx)
134 .channel_for_id(self.channel_id)
135 .cloned()
136 }
137
138 pub fn client(&self) -> &Arc<Client> {
139 &self.rpc
140 }
141
142 pub fn send_message(
143 &mut self,
144 message: MessageParams,
145 cx: &mut ModelContext<Self>,
146 ) -> Result<Task<Result<u64>>> {
147 if message.text.trim().is_empty() {
148 Err(anyhow!("message body can't be empty"))?;
149 }
150
151 let current_user = self
152 .user_store
153 .read(cx)
154 .current_user()
155 .ok_or_else(|| anyhow!("current_user is not present"))?;
156
157 let channel_id = self.channel_id;
158 let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id));
159 let nonce = self.rng.gen();
160 self.insert_messages(
161 SumTree::from_item(
162 ChannelMessage {
163 id: pending_id,
164 body: message.text.clone(),
165 sender: current_user,
166 timestamp: OffsetDateTime::now_utc(),
167 mentions: message.mentions.clone(),
168 nonce,
169 },
170 &(),
171 ),
172 cx,
173 );
174 let user_store = self.user_store.clone();
175 let rpc = self.rpc.clone();
176 let outgoing_messages_lock = self.outgoing_messages_lock.clone();
177
178 // todo - handle messages that fail to send (e.g. >1024 chars)
179 Ok(cx.spawn(move |this, mut cx| async move {
180 let outgoing_message_guard = outgoing_messages_lock.lock().await;
181 let request = rpc.request(proto::SendChannelMessage {
182 channel_id,
183 body: message.text,
184 nonce: Some(nonce.into()),
185 mentions: mentions_to_proto(&message.mentions),
186 });
187 let response = request.await?;
188 drop(outgoing_message_guard);
189 let response = response.message.ok_or_else(|| anyhow!("invalid message"))?;
190 let id = response.id;
191 let message = ChannelMessage::from_proto(response, &user_store, &mut cx).await?;
192 this.update(&mut cx, |this, cx| {
193 this.insert_messages(SumTree::from_item(message, &()), cx);
194 })?;
195 Ok(id)
196 }))
197 }
198
199 pub fn remove_message(&mut self, id: u64, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
200 let response = self.rpc.request(proto::RemoveChannelMessage {
201 channel_id: self.channel_id,
202 message_id: id,
203 });
204 cx.spawn(move |this, mut cx| async move {
205 response.await?;
206 this.update(&mut cx, |this, cx| {
207 this.message_removed(id, cx);
208 })?;
209 Ok(())
210 })
211 }
212
213 pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Option<()>>> {
214 if self.loaded_all_messages {
215 return None;
216 }
217
218 let rpc = self.rpc.clone();
219 let user_store = self.user_store.clone();
220 let channel_id = self.channel_id;
221 let before_message_id = self.first_loaded_message_id()?;
222 Some(cx.spawn(move |this, mut cx| {
223 async move {
224 let response = rpc
225 .request(proto::GetChannelMessages {
226 channel_id,
227 before_message_id,
228 })
229 .await?;
230 let loaded_all_messages = response.done;
231 let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
232 this.update(&mut cx, |this, cx| {
233 this.loaded_all_messages = loaded_all_messages;
234 this.insert_messages(messages, cx);
235 })?;
236 anyhow::Ok(())
237 }
238 .log_err()
239 }))
240 }
241
242 pub fn first_loaded_message_id(&mut self) -> Option<u64> {
243 self.messages.first().and_then(|message| match message.id {
244 ChannelMessageId::Saved(id) => Some(id),
245 ChannelMessageId::Pending(_) => None,
246 })
247 }
248
249 /// Load all of the chat messages since a certain message id.
250 ///
251 /// For now, we always maintain a suffix of the channel's messages.
252 pub async fn load_history_since_message(
253 chat: Model<Self>,
254 message_id: u64,
255 mut cx: AsyncAppContext,
256 ) -> Option<usize> {
257 loop {
258 let step = chat
259 .update(&mut cx, |chat, cx| {
260 if let Some(first_id) = chat.first_loaded_message_id() {
261 if first_id <= message_id {
262 let mut cursor = chat.messages.cursor::<(ChannelMessageId, Count)>();
263 let message_id = ChannelMessageId::Saved(message_id);
264 cursor.seek(&message_id, Bias::Left, &());
265 return ControlFlow::Break(
266 if cursor
267 .item()
268 .map_or(false, |message| message.id == message_id)
269 {
270 Some(cursor.start().1 .0)
271 } else {
272 None
273 },
274 );
275 }
276 }
277 ControlFlow::Continue(chat.load_more_messages(cx))
278 })
279 .log_err()?;
280 match step {
281 ControlFlow::Break(ix) => return ix,
282 ControlFlow::Continue(task) => task?.await?,
283 }
284 }
285 }
286
287 pub fn acknowledge_last_message(&mut self, cx: &mut ModelContext<Self>) {
288 if let ChannelMessageId::Saved(latest_message_id) = self.messages.summary().max_id {
289 if self
290 .last_acknowledged_id
291 .map_or(true, |acknowledged_id| acknowledged_id < latest_message_id)
292 {
293 self.rpc
294 .send(proto::AckChannelMessage {
295 channel_id: self.channel_id,
296 message_id: latest_message_id,
297 })
298 .ok();
299 self.last_acknowledged_id = Some(latest_message_id);
300 self.channel_store.update(cx, |store, cx| {
301 store.acknowledge_message_id(self.channel_id, latest_message_id, cx);
302 });
303 }
304 }
305 }
306
307 pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
308 let user_store = self.user_store.clone();
309 let rpc = self.rpc.clone();
310 let channel_id = self.channel_id;
311 cx.spawn(move |this, mut cx| {
312 async move {
313 let response = rpc.request(proto::JoinChannelChat { channel_id }).await?;
314 let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
315 let loaded_all_messages = response.done;
316
317 let pending_messages = this.update(&mut cx, |this, cx| {
318 if let Some((first_new_message, last_old_message)) =
319 messages.first().zip(this.messages.last())
320 {
321 if first_new_message.id > last_old_message.id {
322 let old_messages = mem::take(&mut this.messages);
323 cx.emit(ChannelChatEvent::MessagesUpdated {
324 old_range: 0..old_messages.summary().count,
325 new_count: 0,
326 });
327 this.loaded_all_messages = loaded_all_messages;
328 }
329 }
330
331 this.insert_messages(messages, cx);
332 if loaded_all_messages {
333 this.loaded_all_messages = loaded_all_messages;
334 }
335
336 this.pending_messages().cloned().collect::<Vec<_>>()
337 })?;
338
339 for pending_message in pending_messages {
340 let request = rpc.request(proto::SendChannelMessage {
341 channel_id,
342 body: pending_message.body,
343 mentions: mentions_to_proto(&pending_message.mentions),
344 nonce: Some(pending_message.nonce.into()),
345 });
346 let response = request.await?;
347 let message = ChannelMessage::from_proto(
348 response.message.ok_or_else(|| anyhow!("invalid message"))?,
349 &user_store,
350 &mut cx,
351 )
352 .await?;
353 this.update(&mut cx, |this, cx| {
354 this.insert_messages(SumTree::from_item(message, &()), cx);
355 })?;
356 }
357
358 anyhow::Ok(())
359 }
360 .log_err()
361 })
362 .detach();
363 }
364
365 pub fn message_count(&self) -> usize {
366 self.messages.summary().count
367 }
368
369 pub fn messages(&self) -> &SumTree<ChannelMessage> {
370 &self.messages
371 }
372
373 pub fn message(&self, ix: usize) -> &ChannelMessage {
374 let mut cursor = self.messages.cursor::<Count>();
375 cursor.seek(&Count(ix), Bias::Right, &());
376 cursor.item().unwrap()
377 }
378
379 pub fn acknowledge_message(&mut self, id: u64) {
380 if self.acknowledged_message_ids.insert(id) {
381 self.rpc
382 .send(proto::AckChannelMessage {
383 channel_id: self.channel_id,
384 message_id: id,
385 })
386 .ok();
387 }
388 }
389
390 pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> {
391 let mut cursor = self.messages.cursor::<Count>();
392 cursor.seek(&Count(range.start), Bias::Right, &());
393 cursor.take(range.len())
394 }
395
396 pub fn pending_messages(&self) -> impl Iterator<Item = &ChannelMessage> {
397 let mut cursor = self.messages.cursor::<ChannelMessageId>();
398 cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &());
399 cursor
400 }
401
402 async fn handle_message_sent(
403 this: Model<Self>,
404 message: TypedEnvelope<proto::ChannelMessageSent>,
405 _: Arc<Client>,
406 mut cx: AsyncAppContext,
407 ) -> Result<()> {
408 let user_store = this.update(&mut cx, |this, _| this.user_store.clone())?;
409 let message = message
410 .payload
411 .message
412 .ok_or_else(|| anyhow!("empty message"))?;
413 let message_id = message.id;
414
415 let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?;
416 this.update(&mut cx, |this, cx| {
417 this.insert_messages(SumTree::from_item(message, &()), cx);
418 cx.emit(ChannelChatEvent::NewMessage {
419 channel_id: this.channel_id,
420 message_id,
421 })
422 })?;
423
424 Ok(())
425 }
426
427 async fn handle_message_removed(
428 this: Model<Self>,
429 message: TypedEnvelope<proto::RemoveChannelMessage>,
430 _: Arc<Client>,
431 mut cx: AsyncAppContext,
432 ) -> Result<()> {
433 this.update(&mut cx, |this, cx| {
434 this.message_removed(message.payload.message_id, cx)
435 })?;
436 Ok(())
437 }
438
439 fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
440 if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
441 let nonces = messages
442 .cursor::<()>()
443 .map(|m| m.nonce)
444 .collect::<HashSet<_>>();
445
446 let mut old_cursor = self.messages.cursor::<(ChannelMessageId, Count)>();
447 let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
448 let start_ix = old_cursor.start().1 .0;
449 let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &());
450 let removed_count = removed_messages.summary().count;
451 let new_count = messages.summary().count;
452 let end_ix = start_ix + removed_count;
453
454 new_messages.append(messages, &());
455
456 let mut ranges = Vec::<Range<usize>>::new();
457 if new_messages.last().unwrap().is_pending() {
458 new_messages.append(old_cursor.suffix(&()), &());
459 } else {
460 new_messages.append(
461 old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()),
462 &(),
463 );
464
465 while let Some(message) = old_cursor.item() {
466 let message_ix = old_cursor.start().1 .0;
467 if nonces.contains(&message.nonce) {
468 if ranges.last().map_or(false, |r| r.end == message_ix) {
469 ranges.last_mut().unwrap().end += 1;
470 } else {
471 ranges.push(message_ix..message_ix + 1);
472 }
473 } else {
474 new_messages.push(message.clone(), &());
475 }
476 old_cursor.next(&());
477 }
478 }
479
480 drop(old_cursor);
481 self.messages = new_messages;
482
483 for range in ranges.into_iter().rev() {
484 cx.emit(ChannelChatEvent::MessagesUpdated {
485 old_range: range,
486 new_count: 0,
487 });
488 }
489 cx.emit(ChannelChatEvent::MessagesUpdated {
490 old_range: start_ix..end_ix,
491 new_count,
492 });
493
494 cx.notify();
495 }
496 }
497
498 fn message_removed(&mut self, id: u64, cx: &mut ModelContext<Self>) {
499 let mut cursor = self.messages.cursor::<ChannelMessageId>();
500 let mut messages = cursor.slice(&ChannelMessageId::Saved(id), Bias::Left, &());
501 if let Some(item) = cursor.item() {
502 if item.id == ChannelMessageId::Saved(id) {
503 let ix = messages.summary().count;
504 cursor.next(&());
505 messages.append(cursor.suffix(&()), &());
506 drop(cursor);
507 self.messages = messages;
508 cx.emit(ChannelChatEvent::MessagesUpdated {
509 old_range: ix..ix + 1,
510 new_count: 0,
511 });
512 }
513 }
514 }
515}
516
517async fn messages_from_proto(
518 proto_messages: Vec<proto::ChannelMessage>,
519 user_store: &Model<UserStore>,
520 cx: &mut AsyncAppContext,
521) -> Result<SumTree<ChannelMessage>> {
522 let messages = ChannelMessage::from_proto_vec(proto_messages, user_store, cx).await?;
523 let mut result = SumTree::new();
524 result.extend(messages, &());
525 Ok(result)
526}
527
528impl ChannelMessage {
529 pub async fn from_proto(
530 message: proto::ChannelMessage,
531 user_store: &Model<UserStore>,
532 cx: &mut AsyncAppContext,
533 ) -> Result<Self> {
534 let sender = user_store
535 .update(cx, |user_store, cx| {
536 user_store.get_user(message.sender_id, cx)
537 })?
538 .await?;
539 Ok(ChannelMessage {
540 id: ChannelMessageId::Saved(message.id),
541 body: message.body,
542 mentions: message
543 .mentions
544 .into_iter()
545 .filter_map(|mention| {
546 let range = mention.range?;
547 Some((range.start as usize..range.end as usize, mention.user_id))
548 })
549 .collect(),
550 timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?,
551 sender,
552 nonce: message
553 .nonce
554 .ok_or_else(|| anyhow!("nonce is required"))?
555 .into(),
556 })
557 }
558
559 pub fn is_pending(&self) -> bool {
560 matches!(self.id, ChannelMessageId::Pending(_))
561 }
562
563 pub async fn from_proto_vec(
564 proto_messages: Vec<proto::ChannelMessage>,
565 user_store: &Model<UserStore>,
566 cx: &mut AsyncAppContext,
567 ) -> Result<Vec<Self>> {
568 let unique_user_ids = proto_messages
569 .iter()
570 .map(|m| m.sender_id)
571 .collect::<HashSet<_>>()
572 .into_iter()
573 .collect();
574 user_store
575 .update(cx, |user_store, cx| {
576 user_store.get_users(unique_user_ids, cx)
577 })?
578 .await?;
579
580 let mut messages = Vec::with_capacity(proto_messages.len());
581 for message in proto_messages {
582 messages.push(ChannelMessage::from_proto(message, user_store, cx).await?);
583 }
584 Ok(messages)
585 }
586}
587
588pub fn mentions_to_proto(mentions: &[(Range<usize>, UserId)]) -> Vec<proto::ChatMention> {
589 mentions
590 .iter()
591 .map(|(range, user_id)| proto::ChatMention {
592 range: Some(proto::Range {
593 start: range.start as u64,
594 end: range.end as u64,
595 }),
596 user_id: *user_id as u64,
597 })
598 .collect()
599}
600
601impl sum_tree::Item for ChannelMessage {
602 type Summary = ChannelMessageSummary;
603
604 fn summary(&self) -> Self::Summary {
605 ChannelMessageSummary {
606 max_id: self.id,
607 count: 1,
608 }
609 }
610}
611
612impl Default for ChannelMessageId {
613 fn default() -> Self {
614 Self::Saved(0)
615 }
616}
617
618impl sum_tree::Summary for ChannelMessageSummary {
619 type Context = ();
620
621 fn add_summary(&mut self, summary: &Self, _: &()) {
622 self.max_id = summary.max_id;
623 self.count += summary.count;
624 }
625}
626
627impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for ChannelMessageId {
628 fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
629 debug_assert!(summary.max_id > *self);
630 *self = summary.max_id;
631 }
632}
633
634impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count {
635 fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
636 self.0 += summary.count;
637 }
638}
639
640impl<'a> From<&'a str> for MessageParams {
641 fn from(value: &'a str) -> Self {
642 Self {
643 text: value.into(),
644 mentions: Vec::new(),
645 }
646 }
647}