1use crate::Channel;
2use anyhow::{anyhow, Result};
3use client::{
4 proto,
5 user::{User, UserStore},
6 Client, Subscription, TypedEnvelope,
7};
8use futures::lock::Mutex;
9use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
10use rand::prelude::*;
11use std::{collections::HashSet, mem, ops::Range, sync::Arc};
12use sum_tree::{Bias, SumTree};
13use time::OffsetDateTime;
14use util::{post_inc, ResultExt as _, TryFutureExt};
15
/// Client-side model of a channel's chat: the loaded message history plus any
/// messages sent by this client that the server has not confirmed yet.
pub struct ChannelChat {
    channel: Arc<Channel>,
    /// Loaded messages, keyed by id. Because of the derived ordering on
    /// `ChannelMessageId`, all saved messages come before all pending ones.
    messages: SumTree<ChannelMessage>,
    /// Set from the server's `done` flag once no older messages remain to load.
    loaded_all_messages: bool,
    /// Counter used to hand out `ChannelMessageId::Pending` ids to outgoing messages.
    next_pending_message_id: usize,
    user_store: ModelHandle<UserStore>,
    rpc: Arc<Client>,
    /// Held by `send_message` across each `SendChannelMessage` request so sends
    /// are issued one at a time.
    outgoing_messages_lock: Arc<Mutex<()>>,
    /// Source of the random nonces attached to outgoing messages.
    rng: StdRng,
    /// Kept alive for the model's lifetime; presumably routes the channel's
    /// server messages to this model (see `init`) — confirm in `client`.
    _subscription: Subscription,
}
27
/// A single chat message, either confirmed by the server or still pending.
#[derive(Clone, Debug)]
pub struct ChannelMessage {
    /// `Saved` with the server-assigned id, or `Pending` for a message this
    /// client sent that has not been confirmed yet.
    pub id: ChannelMessageId,
    pub body: String,
    pub timestamp: OffsetDateTime,
    pub sender: Arc<User>,
    /// Random value picked by the sender; `ChannelChat::insert_messages` uses it
    /// to drop a pending message once a saved message with the same nonce arrives.
    pub nonce: u128,
}
36
/// Identifies a message within a [`ChannelChat`].
///
/// The derived ordering compares variants in declaration order, so every
/// `Saved` id sorts before every `Pending` id. The message tree relies on this
/// to keep pending messages after all saved ones (e.g. seeking to
/// `Pending(0)` finds the first pending message), so do not reorder variants.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum ChannelMessageId {
    /// Id assigned by the server.
    Saved(u64),
    /// Local id for a message that has not been confirmed by the server.
    Pending(usize),
}
42
/// Sum-tree summary of a contiguous run of messages.
#[derive(Clone, Debug, Default)]
pub struct ChannelMessageSummary {
    /// Id of the last message in the run (messages are kept in ascending id order).
    max_id: ChannelMessageId,
    /// Number of messages in the run.
    count: usize,
}
48
/// Sum-tree dimension counting messages; used to seek to a message by index.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);
51
/// Events emitted by [`ChannelChat`] when its message list changes.
#[derive(Clone, Debug, PartialEq)]
pub enum ChannelChatEvent {
    /// The messages at indices `old_range` (in the list as it was before this
    /// event) were replaced by `new_count` messages starting at
    /// `old_range.start`. A `new_count` of 0 is a pure removal.
    MessagesUpdated {
        old_range: Range<usize>,
        new_count: usize,
    },
}
59
60pub fn init(client: &Arc<Client>) {
61 client.add_model_message_handler(ChannelChat::handle_message_sent);
62 client.add_model_message_handler(ChannelChat::handle_message_removed);
63}
64
65impl Entity for ChannelChat {
66 type Event = ChannelChatEvent;
67
68 fn release(&mut self, _: &mut AppContext) {
69 self.rpc
70 .send(proto::LeaveChannelChat {
71 channel_id: self.channel.id,
72 })
73 .log_err();
74 }
75}
76
77impl ChannelChat {
78 pub async fn new(
79 channel: Arc<Channel>,
80 user_store: ModelHandle<UserStore>,
81 client: Arc<Client>,
82 mut cx: AsyncAppContext,
83 ) -> Result<ModelHandle<Self>> {
84 let channel_id = channel.id;
85 let subscription = client.subscribe_to_entity(channel_id).unwrap();
86
87 let response = client
88 .request(proto::JoinChannelChat { channel_id })
89 .await?;
90 let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
91 let loaded_all_messages = response.done;
92
93 Ok(cx.add_model(|cx| {
94 let mut this = Self {
95 channel,
96 user_store,
97 rpc: client,
98 outgoing_messages_lock: Default::default(),
99 messages: Default::default(),
100 loaded_all_messages,
101 next_pending_message_id: 0,
102 rng: StdRng::from_entropy(),
103 _subscription: subscription.set_model(&cx.handle(), &mut cx.to_async()),
104 };
105 this.insert_messages(messages, cx);
106 this
107 }))
108 }
109
110 pub fn channel(&self) -> &Arc<Channel> {
111 &self.channel
112 }
113
114 pub fn send_message(
115 &mut self,
116 body: String,
117 cx: &mut ModelContext<Self>,
118 ) -> Result<Task<Result<()>>> {
119 if body.is_empty() {
120 Err(anyhow!("message body can't be empty"))?;
121 }
122
123 let current_user = self
124 .user_store
125 .read(cx)
126 .current_user()
127 .ok_or_else(|| anyhow!("current_user is not present"))?;
128
129 let channel_id = self.channel.id;
130 let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id));
131 let nonce = self.rng.gen();
132 self.insert_messages(
133 SumTree::from_item(
134 ChannelMessage {
135 id: pending_id,
136 body: body.clone(),
137 sender: current_user,
138 timestamp: OffsetDateTime::now_utc(),
139 nonce,
140 },
141 &(),
142 ),
143 cx,
144 );
145 let user_store = self.user_store.clone();
146 let rpc = self.rpc.clone();
147 let outgoing_messages_lock = self.outgoing_messages_lock.clone();
148 Ok(cx.spawn(|this, mut cx| async move {
149 let outgoing_message_guard = outgoing_messages_lock.lock().await;
150 let request = rpc.request(proto::SendChannelMessage {
151 channel_id,
152 body,
153 nonce: Some(nonce.into()),
154 });
155 let response = request.await?;
156 drop(outgoing_message_guard);
157 let message = ChannelMessage::from_proto(
158 response.message.ok_or_else(|| anyhow!("invalid message"))?,
159 &user_store,
160 &mut cx,
161 )
162 .await?;
163 this.update(&mut cx, |this, cx| {
164 this.insert_messages(SumTree::from_item(message, &()), cx);
165 Ok(())
166 })
167 }))
168 }
169
170 pub fn remove_message(&mut self, id: u64, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
171 let response = self.rpc.request(proto::RemoveChannelMessage {
172 channel_id: self.channel.id,
173 message_id: id,
174 });
175 cx.spawn(|this, mut cx| async move {
176 response.await?;
177
178 this.update(&mut cx, |this, cx| {
179 this.message_removed(id, cx);
180 Ok(())
181 })
182 })
183 }
184
185 pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> bool {
186 if !self.loaded_all_messages {
187 let rpc = self.rpc.clone();
188 let user_store = self.user_store.clone();
189 let channel_id = self.channel.id;
190 if let Some(before_message_id) =
191 self.messages.first().and_then(|message| match message.id {
192 ChannelMessageId::Saved(id) => Some(id),
193 ChannelMessageId::Pending(_) => None,
194 })
195 {
196 cx.spawn(|this, mut cx| {
197 async move {
198 let response = rpc
199 .request(proto::GetChannelMessages {
200 channel_id,
201 before_message_id,
202 })
203 .await?;
204 let loaded_all_messages = response.done;
205 let messages =
206 messages_from_proto(response.messages, &user_store, &mut cx).await?;
207 this.update(&mut cx, |this, cx| {
208 this.loaded_all_messages = loaded_all_messages;
209 this.insert_messages(messages, cx);
210 });
211 anyhow::Ok(())
212 }
213 .log_err()
214 })
215 .detach();
216 return true;
217 }
218 }
219 false
220 }
221
222 pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
223 let user_store = self.user_store.clone();
224 let rpc = self.rpc.clone();
225 let channel_id = self.channel.id;
226 cx.spawn(|this, mut cx| {
227 async move {
228 let response = rpc.request(proto::JoinChannelChat { channel_id }).await?;
229 let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
230 let loaded_all_messages = response.done;
231
232 let pending_messages = this.update(&mut cx, |this, cx| {
233 if let Some((first_new_message, last_old_message)) =
234 messages.first().zip(this.messages.last())
235 {
236 if first_new_message.id > last_old_message.id {
237 let old_messages = mem::take(&mut this.messages);
238 cx.emit(ChannelChatEvent::MessagesUpdated {
239 old_range: 0..old_messages.summary().count,
240 new_count: 0,
241 });
242 this.loaded_all_messages = loaded_all_messages;
243 }
244 }
245
246 this.insert_messages(messages, cx);
247 if loaded_all_messages {
248 this.loaded_all_messages = loaded_all_messages;
249 }
250
251 this.pending_messages().cloned().collect::<Vec<_>>()
252 });
253
254 for pending_message in pending_messages {
255 let request = rpc.request(proto::SendChannelMessage {
256 channel_id,
257 body: pending_message.body,
258 nonce: Some(pending_message.nonce.into()),
259 });
260 let response = request.await?;
261 let message = ChannelMessage::from_proto(
262 response.message.ok_or_else(|| anyhow!("invalid message"))?,
263 &user_store,
264 &mut cx,
265 )
266 .await?;
267 this.update(&mut cx, |this, cx| {
268 this.insert_messages(SumTree::from_item(message, &()), cx);
269 });
270 }
271
272 anyhow::Ok(())
273 }
274 .log_err()
275 })
276 .detach();
277 }
278
279 pub fn message_count(&self) -> usize {
280 self.messages.summary().count
281 }
282
283 pub fn messages(&self) -> &SumTree<ChannelMessage> {
284 &self.messages
285 }
286
287 pub fn message(&self, ix: usize) -> &ChannelMessage {
288 let mut cursor = self.messages.cursor::<Count>();
289 cursor.seek(&Count(ix), Bias::Right, &());
290 cursor.item().unwrap()
291 }
292
293 pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> {
294 let mut cursor = self.messages.cursor::<Count>();
295 cursor.seek(&Count(range.start), Bias::Right, &());
296 cursor.take(range.len())
297 }
298
299 pub fn pending_messages(&self) -> impl Iterator<Item = &ChannelMessage> {
300 let mut cursor = self.messages.cursor::<ChannelMessageId>();
301 cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &());
302 cursor
303 }
304
305 async fn handle_message_sent(
306 this: ModelHandle<Self>,
307 message: TypedEnvelope<proto::ChannelMessageSent>,
308 _: Arc<Client>,
309 mut cx: AsyncAppContext,
310 ) -> Result<()> {
311 let user_store = this.read_with(&cx, |this, _| this.user_store.clone());
312 let message = message
313 .payload
314 .message
315 .ok_or_else(|| anyhow!("empty message"))?;
316
317 let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?;
318 this.update(&mut cx, |this, cx| {
319 this.insert_messages(SumTree::from_item(message, &()), cx)
320 });
321
322 Ok(())
323 }
324
325 async fn handle_message_removed(
326 this: ModelHandle<Self>,
327 message: TypedEnvelope<proto::RemoveChannelMessage>,
328 _: Arc<Client>,
329 mut cx: AsyncAppContext,
330 ) -> Result<()> {
331 this.update(&mut cx, |this, cx| {
332 this.message_removed(message.payload.message_id, cx)
333 });
334 Ok(())
335 }
336
337 fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
338 if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
339 let nonces = messages
340 .cursor::<()>()
341 .map(|m| m.nonce)
342 .collect::<HashSet<_>>();
343
344 let mut old_cursor = self.messages.cursor::<(ChannelMessageId, Count)>();
345 let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
346 let start_ix = old_cursor.start().1 .0;
347 let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &());
348 let removed_count = removed_messages.summary().count;
349 let new_count = messages.summary().count;
350 let end_ix = start_ix + removed_count;
351
352 new_messages.append(messages, &());
353
354 let mut ranges = Vec::<Range<usize>>::new();
355 if new_messages.last().unwrap().is_pending() {
356 new_messages.append(old_cursor.suffix(&()), &());
357 } else {
358 new_messages.append(
359 old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()),
360 &(),
361 );
362
363 while let Some(message) = old_cursor.item() {
364 let message_ix = old_cursor.start().1 .0;
365 if nonces.contains(&message.nonce) {
366 if ranges.last().map_or(false, |r| r.end == message_ix) {
367 ranges.last_mut().unwrap().end += 1;
368 } else {
369 ranges.push(message_ix..message_ix + 1);
370 }
371 } else {
372 new_messages.push(message.clone(), &());
373 }
374 old_cursor.next(&());
375 }
376 }
377
378 drop(old_cursor);
379 self.messages = new_messages;
380
381 for range in ranges.into_iter().rev() {
382 cx.emit(ChannelChatEvent::MessagesUpdated {
383 old_range: range,
384 new_count: 0,
385 });
386 }
387 cx.emit(ChannelChatEvent::MessagesUpdated {
388 old_range: start_ix..end_ix,
389 new_count,
390 });
391 cx.notify();
392 }
393 }
394
395 fn message_removed(&mut self, id: u64, cx: &mut ModelContext<Self>) {
396 let mut cursor = self.messages.cursor::<ChannelMessageId>();
397 let mut messages = cursor.slice(&ChannelMessageId::Saved(id), Bias::Left, &());
398 if let Some(item) = cursor.item() {
399 if item.id == ChannelMessageId::Saved(id) {
400 let ix = messages.summary().count;
401 cursor.next(&());
402 messages.append(cursor.suffix(&()), &());
403 drop(cursor);
404 self.messages = messages;
405 cx.emit(ChannelChatEvent::MessagesUpdated {
406 old_range: ix..ix + 1,
407 new_count: 0,
408 });
409 }
410 }
411 }
412}
413
414async fn messages_from_proto(
415 proto_messages: Vec<proto::ChannelMessage>,
416 user_store: &ModelHandle<UserStore>,
417 cx: &mut AsyncAppContext,
418) -> Result<SumTree<ChannelMessage>> {
419 let unique_user_ids = proto_messages
420 .iter()
421 .map(|m| m.sender_id)
422 .collect::<HashSet<_>>()
423 .into_iter()
424 .collect();
425 user_store
426 .update(cx, |user_store, cx| {
427 user_store.get_users(unique_user_ids, cx)
428 })
429 .await?;
430
431 let mut messages = Vec::with_capacity(proto_messages.len());
432 for message in proto_messages {
433 messages.push(ChannelMessage::from_proto(message, user_store, cx).await?);
434 }
435 let mut result = SumTree::new();
436 result.extend(messages, &());
437 Ok(result)
438}
439
440impl ChannelMessage {
441 pub async fn from_proto(
442 message: proto::ChannelMessage,
443 user_store: &ModelHandle<UserStore>,
444 cx: &mut AsyncAppContext,
445 ) -> Result<Self> {
446 let sender = user_store
447 .update(cx, |user_store, cx| {
448 user_store.get_user(message.sender_id, cx)
449 })
450 .await?;
451 Ok(ChannelMessage {
452 id: ChannelMessageId::Saved(message.id),
453 body: message.body,
454 timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?,
455 sender,
456 nonce: message
457 .nonce
458 .ok_or_else(|| anyhow!("nonce is required"))?
459 .into(),
460 })
461 }
462
463 pub fn is_pending(&self) -> bool {
464 matches!(self.id, ChannelMessageId::Pending(_))
465 }
466}
467
468impl sum_tree::Item for ChannelMessage {
469 type Summary = ChannelMessageSummary;
470
471 fn summary(&self) -> Self::Summary {
472 ChannelMessageSummary {
473 max_id: self.id,
474 count: 1,
475 }
476 }
477}
478
479impl Default for ChannelMessageId {
480 fn default() -> Self {
481 Self::Saved(0)
482 }
483}
484
485impl sum_tree::Summary for ChannelMessageSummary {
486 type Context = ();
487
488 fn add_summary(&mut self, summary: &Self, _: &()) {
489 self.max_id = summary.max_id;
490 self.count += summary.count;
491 }
492}
493
494impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for ChannelMessageId {
495 fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
496 debug_assert!(summary.max_id > *self);
497 *self = summary.max_id;
498 }
499}
500
501impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count {
502 fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
503 self.0 += summary.count;
504 }
505}