1mod store;
2
3use super::{
4 auth,
5 db::{ChannelId, MessageId, UserId},
6 AppState,
7};
8use crate::errors::TideResultExt;
9use anyhow::anyhow;
10use async_std::{sync::RwLock, task};
11use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
12use futures::{future::BoxFuture, FutureExt};
13use postage::{broadcast, mpsc, prelude::Sink as _, prelude::Stream as _};
14use sha1::{Digest as _, Sha1};
15use std::{
16 any::TypeId,
17 collections::{HashMap, HashSet},
18 future::Future,
19 mem,
20 sync::Arc,
21 time::Instant,
22};
23use store::{ReplicaId, Store, Worktree};
24use surf::StatusCode;
25use tide::log;
26use tide::{
27 http::headers::{HeaderName, CONNECTION, UPGRADE},
28 Request, Response,
29};
30use time::OffsetDateTime;
31use zrpc::{
32 proto::{self, AnyTypedEnvelope, EnvelopedMessage},
33 Connection, ConnectionId, Peer, TypedEnvelope,
34};
35
/// Type-erased RPC handler stored in `Server::handlers`.
///
/// It receives the server and a boxed envelope of any message type and
/// returns a boxed future resolving to the handler's `tide::Result<()>`.
type MessageHandler = Box<
    dyn Send
        + Sync
        + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, tide::Result<()>>,
>;
41
/// RPC server state shared by every connection task.
pub struct Server {
    /// Peer used to send, forward and respond to RPC messages.
    peer: Arc<Peer>,
    /// In-memory connection/worktree/channel bookkeeping, behind an async lock.
    store: RwLock<Store>,
    /// Application state; the handlers use it to reach the database (`app_state.db`).
    app_state: Arc<AppState>,
    /// Registered message handlers, keyed by the `TypeId` of the message type.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Optional channel that receives a `()` after each handled message.
    notifications: Option<mpsc::Sender<()>>,
}
49
/// Number of channel messages requested from the database per page; a page
/// shorter than this is reported to the client as the last one (`done`).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Upper bound on a trimmed chat message body, compared against `str::len`
/// (i.e. measured in bytes).
const MAX_MESSAGE_LEN: usize = 1024;
52
53impl Server {
54 pub fn new(
55 app_state: Arc<AppState>,
56 peer: Arc<Peer>,
57 notifications: Option<mpsc::Sender<()>>,
58 ) -> Arc<Self> {
59 let mut server = Self {
60 peer,
61 app_state,
62 store: Default::default(),
63 handlers: Default::default(),
64 notifications,
65 };
66
67 server
68 .add_handler(Server::ping)
69 .add_handler(Server::open_worktree)
70 .add_handler(Server::close_worktree)
71 .add_handler(Server::share_worktree)
72 .add_handler(Server::unshare_worktree)
73 .add_handler(Server::join_worktree)
74 .add_handler(Server::update_worktree)
75 .add_handler(Server::open_buffer)
76 .add_handler(Server::close_buffer)
77 .add_handler(Server::update_buffer)
78 .add_handler(Server::buffer_saved)
79 .add_handler(Server::save_buffer)
80 .add_handler(Server::get_channels)
81 .add_handler(Server::get_users)
82 .add_handler(Server::join_channel)
83 .add_handler(Server::leave_channel)
84 .add_handler(Server::send_channel_message)
85 .add_handler(Server::get_channel_messages);
86
87 Arc::new(server)
88 }
89
90 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
91 where
92 F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
93 Fut: 'static + Send + Future<Output = tide::Result<()>>,
94 M: EnvelopedMessage,
95 {
96 let prev_handler = self.handlers.insert(
97 TypeId::of::<M>(),
98 Box::new(move |server, envelope| {
99 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
100 (handler)(server, *envelope).boxed()
101 }),
102 );
103 if prev_handler.is_some() {
104 panic!("registered a handler for the same message twice");
105 }
106 self
107 }
108
109 pub fn handle_connection(
110 self: &Arc<Self>,
111 connection: Connection,
112 addr: String,
113 user_id: UserId,
114 ) -> impl Future<Output = ()> {
115 let this = self.clone();
116 async move {
117 let (connection_id, handle_io, mut incoming_rx) =
118 this.peer.add_connection(connection).await;
119 this.store
120 .write()
121 .await
122 .add_connection(connection_id, user_id);
123 if let Err(err) = this.update_collaborators_for_users(&[user_id]).await {
124 log::error!("error updating collaborators for {:?}: {}", user_id, err);
125 }
126
127 let handle_io = handle_io.fuse();
128 futures::pin_mut!(handle_io);
129 loop {
130 let next_message = incoming_rx.recv().fuse();
131 futures::pin_mut!(next_message);
132 futures::select_biased! {
133 message = next_message => {
134 if let Some(message) = message {
135 let start_time = Instant::now();
136 log::info!("RPC message received: {}", message.payload_type_name());
137 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
138 if let Err(err) = (handler)(this.clone(), message).await {
139 log::error!("error handling message: {:?}", err);
140 } else {
141 log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
142 }
143
144 if let Some(mut notifications) = this.notifications.clone() {
145 let _ = notifications.send(()).await;
146 }
147 } else {
148 log::warn!("unhandled message: {}", message.payload_type_name());
149 }
150 } else {
151 log::info!("rpc connection closed {:?}", addr);
152 break;
153 }
154 }
155 handle_io = handle_io => {
156 if let Err(err) = handle_io {
157 log::error!("error handling rpc connection {:?} - {:?}", addr, err);
158 }
159 break;
160 }
161 }
162 }
163
164 if let Err(err) = this.sign_out(connection_id).await {
165 log::error!("error signing out connection {:?} - {:?}", addr, err);
166 }
167 }
168 }
169
170 async fn sign_out(self: &Arc<Self>, connection_id: ConnectionId) -> tide::Result<()> {
171 self.peer.disconnect(connection_id).await;
172 let removed_connection = self.store.write().await.remove_connection(connection_id)?;
173
174 for (worktree_id, worktree) in removed_connection.hosted_worktrees {
175 if let Some(share) = worktree.share {
176 broadcast(
177 connection_id,
178 share.guest_connection_ids.keys().copied().collect(),
179 |conn_id| {
180 self.peer
181 .send(conn_id, proto::UnshareWorktree { worktree_id })
182 },
183 )
184 .await?;
185 }
186 }
187
188 for (worktree_id, peer_ids) in removed_connection.guest_worktree_ids {
189 broadcast(connection_id, peer_ids, |conn_id| {
190 self.peer.send(
191 conn_id,
192 proto::RemovePeer {
193 worktree_id,
194 peer_id: connection_id.0,
195 },
196 )
197 })
198 .await?;
199 }
200
201 self.update_collaborators_for_users(removed_connection.collaborator_ids.iter())
202 .await;
203
204 Ok(())
205 }
206
207 async fn ping(self: Arc<Server>, request: TypedEnvelope<proto::Ping>) -> tide::Result<()> {
208 self.peer.respond(request.receipt(), proto::Ack {}).await?;
209 Ok(())
210 }
211
212 async fn open_worktree(
213 self: Arc<Server>,
214 request: TypedEnvelope<proto::OpenWorktree>,
215 ) -> tide::Result<()> {
216 let receipt = request.receipt();
217 let host_user_id = self
218 .store
219 .read()
220 .await
221 .user_id_for_connection(request.sender_id)?;
222
223 let mut collaborator_user_ids = HashSet::new();
224 collaborator_user_ids.insert(host_user_id);
225 for github_login in request.payload.collaborator_logins {
226 match self.app_state.db.create_user(&github_login, false).await {
227 Ok(collaborator_user_id) => {
228 collaborator_user_ids.insert(collaborator_user_id);
229 }
230 Err(err) => {
231 let message = err.to_string();
232 self.peer
233 .respond_with_error(receipt, proto::Error { message })
234 .await?;
235 return Ok(());
236 }
237 }
238 }
239
240 let collaborator_user_ids = collaborator_user_ids.into_iter().collect::<Vec<_>>();
241 let worktree_id = self.store.write().await.add_worktree(Worktree {
242 host_connection_id: request.sender_id,
243 collaborator_user_ids: collaborator_user_ids.clone(),
244 root_name: request.payload.root_name,
245 share: None,
246 });
247
248 self.peer
249 .respond(receipt, proto::OpenWorktreeResponse { worktree_id })
250 .await?;
251 self.update_collaborators_for_users(&collaborator_user_ids)
252 .await?;
253
254 Ok(())
255 }
256
257 async fn close_worktree(
258 self: Arc<Server>,
259 request: TypedEnvelope<proto::CloseWorktree>,
260 ) -> tide::Result<()> {
261 let worktree_id = request.payload.worktree_id;
262 let worktree = self
263 .store
264 .write()
265 .await
266 .remove_worktree(worktree_id, request.sender_id)?;
267
268 if let Some(share) = worktree.share {
269 broadcast(
270 request.sender_id,
271 share.guest_connection_ids.keys().copied().collect(),
272 |conn_id| {
273 self.peer
274 .send(conn_id, proto::UnshareWorktree { worktree_id })
275 },
276 )
277 .await?;
278 }
279 self.update_collaborators_for_users(&worktree.collaborator_user_ids)
280 .await?;
281 Ok(())
282 }
283
284 async fn share_worktree(
285 self: Arc<Server>,
286 mut request: TypedEnvelope<proto::ShareWorktree>,
287 ) -> tide::Result<()> {
288 let worktree = request
289 .payload
290 .worktree
291 .as_mut()
292 .ok_or_else(|| anyhow!("missing worktree"))?;
293 let entries = mem::take(&mut worktree.entries)
294 .into_iter()
295 .map(|entry| (entry.id, entry))
296 .collect();
297
298 if let Some(collaborator_user_ids) =
299 self.store
300 .write()
301 .await
302 .share_worktree(worktree.id, request.sender_id, entries)
303 {
304 self.peer
305 .respond(request.receipt(), proto::ShareWorktreeResponse {})
306 .await?;
307 self.update_collaborators_for_users(&collaborator_user_ids)
308 .await?;
309 } else {
310 self.peer
311 .respond_with_error(
312 request.receipt(),
313 proto::Error {
314 message: "no such worktree".to_string(),
315 },
316 )
317 .await?;
318 }
319 Ok(())
320 }
321
322 async fn unshare_worktree(
323 self: Arc<Server>,
324 request: TypedEnvelope<proto::UnshareWorktree>,
325 ) -> tide::Result<()> {
326 let worktree_id = request.payload.worktree_id;
327 let (connection_ids, collaborator_user_ids) = self
328 .store
329 .write()
330 .await
331 .unshare_worktree(worktree_id, request.sender_id)?;
332
333 broadcast(request.sender_id, connection_ids, |conn_id| {
334 self.peer
335 .send(conn_id, proto::UnshareWorktree { worktree_id })
336 })
337 .await?;
338 self.update_collaborators_for_users(&collaborator_user_ids)
339 .await?;
340
341 Ok(())
342 }
343
344 async fn join_worktree(
345 self: Arc<Server>,
346 request: TypedEnvelope<proto::JoinWorktree>,
347 ) -> tide::Result<()> {
348 let worktree_id = request.payload.worktree_id;
349 let user_id = self
350 .store
351 .read()
352 .await
353 .user_id_for_connection(request.sender_id)?;
354
355 let response;
356 let connection_ids;
357 let collaborator_user_ids;
358 let mut state = self.store.write().await;
359 match state.join_worktree(request.sender_id, user_id, worktree_id) {
360 Ok((peer_replica_id, worktree)) => {
361 let share = worktree.share()?;
362 let peer_count = share.guest_connection_ids.len();
363 let mut peers = Vec::with_capacity(peer_count);
364 peers.push(proto::Peer {
365 peer_id: worktree.host_connection_id.0,
366 replica_id: 0,
367 });
368 for (peer_conn_id, peer_replica_id) in &share.guest_connection_ids {
369 if *peer_conn_id != request.sender_id {
370 peers.push(proto::Peer {
371 peer_id: peer_conn_id.0,
372 replica_id: *peer_replica_id as u32,
373 });
374 }
375 }
376 response = proto::JoinWorktreeResponse {
377 worktree: Some(proto::Worktree {
378 id: worktree_id,
379 root_name: worktree.root_name.clone(),
380 entries: share.entries.values().cloned().collect(),
381 }),
382 replica_id: peer_replica_id as u32,
383 peers,
384 };
385 connection_ids = worktree.connection_ids();
386 collaborator_user_ids = worktree.collaborator_user_ids.clone();
387 }
388 Err(error) => {
389 self.peer
390 .respond_with_error(
391 request.receipt(),
392 proto::Error {
393 message: error.to_string(),
394 },
395 )
396 .await?;
397 return Ok(());
398 }
399 }
400
401 drop(state);
402 broadcast(request.sender_id, connection_ids, |conn_id| {
403 self.peer.send(
404 conn_id,
405 proto::AddPeer {
406 worktree_id,
407 peer: Some(proto::Peer {
408 peer_id: request.sender_id.0,
409 replica_id: response.replica_id,
410 }),
411 },
412 )
413 })
414 .await?;
415 self.peer.respond(request.receipt(), response).await?;
416 self.update_collaborators_for_users(&collaborator_user_ids)
417 .await?;
418
419 Ok(())
420 }
421
422 async fn leave_worktree(
423 self: &Arc<Server>,
424 worktree_id: u64,
425 sender_conn_id: ConnectionId,
426 ) -> tide::Result<()> {
427 if let Some((connection_ids, collaborator_ids)) = self
428 .store
429 .write()
430 .await
431 .leave_worktree(sender_conn_id, worktree_id)
432 {
433 broadcast(sender_conn_id, connection_ids, |conn_id| {
434 self.peer.send(
435 conn_id,
436 proto::RemovePeer {
437 worktree_id,
438 peer_id: sender_conn_id.0,
439 },
440 )
441 })
442 .await?;
443 self.update_collaborators_for_users(&collaborator_ids)
444 .await?;
445 }
446 Ok(())
447 }
448
449 async fn update_worktree(
450 self: Arc<Server>,
451 request: TypedEnvelope<proto::UpdateWorktree>,
452 ) -> tide::Result<()> {
453 let connection_ids = self.store.write().await.update_worktree(
454 request.sender_id,
455 request.payload.worktree_id,
456 &request.payload.removed_entries,
457 &request.payload.updated_entries,
458 )?;
459
460 broadcast(request.sender_id, connection_ids, |connection_id| {
461 self.peer
462 .forward_send(request.sender_id, connection_id, request.payload.clone())
463 })
464 .await?;
465
466 Ok(())
467 }
468
469 async fn open_buffer(
470 self: Arc<Server>,
471 request: TypedEnvelope<proto::OpenBuffer>,
472 ) -> tide::Result<()> {
473 let receipt = request.receipt();
474 let host_connection_id = self
475 .store
476 .read()
477 .await
478 .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
479 let response = self
480 .peer
481 .forward_request(request.sender_id, host_connection_id, request.payload)
482 .await?;
483 self.peer.respond(receipt, response).await?;
484 Ok(())
485 }
486
487 async fn close_buffer(
488 self: Arc<Server>,
489 request: TypedEnvelope<proto::CloseBuffer>,
490 ) -> tide::Result<()> {
491 let host_connection_id = self
492 .store
493 .read()
494 .await
495 .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
496 self.peer
497 .forward_send(request.sender_id, host_connection_id, request.payload)
498 .await?;
499 Ok(())
500 }
501
502 async fn save_buffer(
503 self: Arc<Server>,
504 request: TypedEnvelope<proto::SaveBuffer>,
505 ) -> tide::Result<()> {
506 let host;
507 let guests;
508 {
509 let state = self.store.read().await;
510 host = state
511 .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
512 guests = state
513 .worktree_guest_connection_ids(request.sender_id, request.payload.worktree_id)?;
514 }
515
516 let sender = request.sender_id;
517 let receipt = request.receipt();
518 let response = self
519 .peer
520 .forward_request(sender, host, request.payload.clone())
521 .await?;
522
523 broadcast(host, guests, |conn_id| {
524 let response = response.clone();
525 let peer = &self.peer;
526 async move {
527 if conn_id == sender {
528 peer.respond(receipt, response).await
529 } else {
530 peer.forward_send(host, conn_id, response).await
531 }
532 }
533 })
534 .await?;
535
536 Ok(())
537 }
538
539 async fn update_buffer(
540 self: Arc<Server>,
541 request: TypedEnvelope<proto::UpdateBuffer>,
542 ) -> tide::Result<()> {
543 broadcast(
544 request.sender_id,
545 self.store
546 .read()
547 .await
548 .worktree_connection_ids(request.sender_id, request.payload.worktree_id)?,
549 |connection_id| {
550 self.peer
551 .forward_send(request.sender_id, connection_id, request.payload.clone())
552 },
553 )
554 .await?;
555 self.peer.respond(request.receipt(), proto::Ack {}).await?;
556 Ok(())
557 }
558
559 async fn buffer_saved(
560 self: Arc<Server>,
561 request: TypedEnvelope<proto::BufferSaved>,
562 ) -> tide::Result<()> {
563 broadcast(
564 request.sender_id,
565 self.store
566 .read()
567 .await
568 .worktree_connection_ids(request.sender_id, request.payload.worktree_id)?,
569 |connection_id| {
570 self.peer
571 .forward_send(request.sender_id, connection_id, request.payload.clone())
572 },
573 )
574 .await?;
575 Ok(())
576 }
577
578 async fn get_channels(
579 self: Arc<Server>,
580 request: TypedEnvelope<proto::GetChannels>,
581 ) -> tide::Result<()> {
582 let user_id = self
583 .store
584 .read()
585 .await
586 .user_id_for_connection(request.sender_id)?;
587 let channels = self.app_state.db.get_accessible_channels(user_id).await?;
588 self.peer
589 .respond(
590 request.receipt(),
591 proto::GetChannelsResponse {
592 channels: channels
593 .into_iter()
594 .map(|chan| proto::Channel {
595 id: chan.id.to_proto(),
596 name: chan.name,
597 })
598 .collect(),
599 },
600 )
601 .await?;
602 Ok(())
603 }
604
605 async fn get_users(
606 self: Arc<Server>,
607 request: TypedEnvelope<proto::GetUsers>,
608 ) -> tide::Result<()> {
609 let receipt = request.receipt();
610 let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
611 let users = self
612 .app_state
613 .db
614 .get_users_by_ids(user_ids)
615 .await?
616 .into_iter()
617 .map(|user| proto::User {
618 id: user.id.to_proto(),
619 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
620 github_login: user.github_login,
621 })
622 .collect();
623 self.peer
624 .respond(receipt, proto::GetUsersResponse { users })
625 .await?;
626 Ok(())
627 }
628
629 async fn update_collaborators_for_users<'a>(
630 self: &Arc<Server>,
631 user_ids: impl IntoIterator<Item = &'a UserId>,
632 ) -> tide::Result<()> {
633 let mut send_futures = Vec::new();
634
635 let state = self.store.read().await;
636 for user_id in user_ids {
637 let collaborators = state.collaborators_for_user(*user_id);
638 for connection_id in state.connection_ids_for_user(*user_id) {
639 send_futures.push(self.peer.send(
640 connection_id,
641 proto::UpdateCollaborators {
642 collaborators: collaborators.clone(),
643 },
644 ));
645 }
646 }
647
648 drop(state);
649 futures::future::try_join_all(send_futures).await?;
650
651 Ok(())
652 }
653
654 async fn join_channel(
655 self: Arc<Self>,
656 request: TypedEnvelope<proto::JoinChannel>,
657 ) -> tide::Result<()> {
658 let user_id = self
659 .store
660 .read()
661 .await
662 .user_id_for_connection(request.sender_id)?;
663 let channel_id = ChannelId::from_proto(request.payload.channel_id);
664 if !self
665 .app_state
666 .db
667 .can_user_access_channel(user_id, channel_id)
668 .await?
669 {
670 Err(anyhow!("access denied"))?;
671 }
672
673 self.store
674 .write()
675 .await
676 .join_channel(request.sender_id, channel_id);
677 let messages = self
678 .app_state
679 .db
680 .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
681 .await?
682 .into_iter()
683 .map(|msg| proto::ChannelMessage {
684 id: msg.id.to_proto(),
685 body: msg.body,
686 timestamp: msg.sent_at.unix_timestamp() as u64,
687 sender_id: msg.sender_id.to_proto(),
688 nonce: Some(msg.nonce.as_u128().into()),
689 })
690 .collect::<Vec<_>>();
691 self.peer
692 .respond(
693 request.receipt(),
694 proto::JoinChannelResponse {
695 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
696 messages,
697 },
698 )
699 .await?;
700 Ok(())
701 }
702
703 async fn leave_channel(
704 self: Arc<Self>,
705 request: TypedEnvelope<proto::LeaveChannel>,
706 ) -> tide::Result<()> {
707 let user_id = self
708 .store
709 .read()
710 .await
711 .user_id_for_connection(request.sender_id)?;
712 let channel_id = ChannelId::from_proto(request.payload.channel_id);
713 if !self
714 .app_state
715 .db
716 .can_user_access_channel(user_id, channel_id)
717 .await?
718 {
719 Err(anyhow!("access denied"))?;
720 }
721
722 self.store
723 .write()
724 .await
725 .leave_channel(request.sender_id, channel_id);
726
727 Ok(())
728 }
729
730 async fn send_channel_message(
731 self: Arc<Self>,
732 request: TypedEnvelope<proto::SendChannelMessage>,
733 ) -> tide::Result<()> {
734 let receipt = request.receipt();
735 let channel_id = ChannelId::from_proto(request.payload.channel_id);
736 let user_id;
737 let connection_ids;
738 {
739 let state = self.store.read().await;
740 user_id = state.user_id_for_connection(request.sender_id)?;
741 if let Some(ids) = state.channel_connection_ids(channel_id) {
742 connection_ids = ids;
743 } else {
744 return Ok(());
745 }
746 }
747
748 // Validate the message body.
749 let body = request.payload.body.trim().to_string();
750 if body.len() > MAX_MESSAGE_LEN {
751 self.peer
752 .respond_with_error(
753 receipt,
754 proto::Error {
755 message: "message is too long".to_string(),
756 },
757 )
758 .await?;
759 return Ok(());
760 }
761 if body.is_empty() {
762 self.peer
763 .respond_with_error(
764 receipt,
765 proto::Error {
766 message: "message can't be blank".to_string(),
767 },
768 )
769 .await?;
770 return Ok(());
771 }
772
773 let timestamp = OffsetDateTime::now_utc();
774 let nonce = if let Some(nonce) = request.payload.nonce {
775 nonce
776 } else {
777 self.peer
778 .respond_with_error(
779 receipt,
780 proto::Error {
781 message: "nonce can't be blank".to_string(),
782 },
783 )
784 .await?;
785 return Ok(());
786 };
787
788 let message_id = self
789 .app_state
790 .db
791 .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
792 .await?
793 .to_proto();
794 let message = proto::ChannelMessage {
795 sender_id: user_id.to_proto(),
796 id: message_id,
797 body,
798 timestamp: timestamp.unix_timestamp() as u64,
799 nonce: Some(nonce),
800 };
801 broadcast(request.sender_id, connection_ids, |conn_id| {
802 self.peer.send(
803 conn_id,
804 proto::ChannelMessageSent {
805 channel_id: channel_id.to_proto(),
806 message: Some(message.clone()),
807 },
808 )
809 })
810 .await?;
811 self.peer
812 .respond(
813 receipt,
814 proto::SendChannelMessageResponse {
815 message: Some(message),
816 },
817 )
818 .await?;
819 Ok(())
820 }
821
822 async fn get_channel_messages(
823 self: Arc<Self>,
824 request: TypedEnvelope<proto::GetChannelMessages>,
825 ) -> tide::Result<()> {
826 let user_id = self
827 .store
828 .read()
829 .await
830 .user_id_for_connection(request.sender_id)?;
831 let channel_id = ChannelId::from_proto(request.payload.channel_id);
832 if !self
833 .app_state
834 .db
835 .can_user_access_channel(user_id, channel_id)
836 .await?
837 {
838 Err(anyhow!("access denied"))?;
839 }
840
841 let messages = self
842 .app_state
843 .db
844 .get_channel_messages(
845 channel_id,
846 MESSAGE_COUNT_PER_PAGE,
847 Some(MessageId::from_proto(request.payload.before_message_id)),
848 )
849 .await?
850 .into_iter()
851 .map(|msg| proto::ChannelMessage {
852 id: msg.id.to_proto(),
853 body: msg.body,
854 timestamp: msg.sent_at.unix_timestamp() as u64,
855 sender_id: msg.sender_id.to_proto(),
856 nonce: Some(msg.nonce.as_u128().into()),
857 })
858 .collect::<Vec<_>>();
859 self.peer
860 .respond(
861 request.receipt(),
862 proto::GetChannelMessagesResponse {
863 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
864 messages,
865 },
866 )
867 .await?;
868 Ok(())
869 }
870}
871
872pub async fn broadcast<F, T>(
873 sender_id: ConnectionId,
874 receiver_ids: Vec<ConnectionId>,
875 mut f: F,
876) -> anyhow::Result<()>
877where
878 F: FnMut(ConnectionId) -> T,
879 T: Future<Output = anyhow::Result<()>>,
880{
881 let futures = receiver_ids
882 .into_iter()
883 .filter(|id| *id != sender_id)
884 .map(|id| f(id));
885 futures::future::try_join_all(futures).await?;
886 Ok(())
887}
888
889pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
890 let server = Server::new(app.state().clone(), rpc.clone(), None);
891 app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
892 let user_id = request.ext::<UserId>().copied();
893 let server = server.clone();
894 async move {
895 const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
896
897 let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
898 let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
899 let upgrade_requested = connection_upgrade && upgrade_to_websocket;
900
901 if !upgrade_requested {
902 return Ok(Response::new(StatusCode::UpgradeRequired));
903 }
904
905 let header = match request.header("Sec-Websocket-Key") {
906 Some(h) => h.as_str(),
907 None => return Err(anyhow!("expected sec-websocket-key"))?,
908 };
909
910 let mut response = Response::new(StatusCode::SwitchingProtocols);
911 response.insert_header(UPGRADE, "websocket");
912 response.insert_header(CONNECTION, "Upgrade");
913 let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
914 response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
915 response.insert_header("Sec-Websocket-Version", "13");
916
917 let http_res: &mut tide::http::Response = response.as_mut();
918 let upgrade_receiver = http_res.recv_upgrade().await;
919 let addr = request.remote().unwrap_or("unknown").to_string();
920 let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
921 task::spawn(async move {
922 if let Some(stream) = upgrade_receiver.await {
923 server.handle_connection(Connection::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
924 }
925 });
926
927 Ok(response)
928 }
929 });
930}
931
932fn header_contains_ignore_case<T>(
933 request: &tide::Request<T>,
934 header_name: HeaderName,
935 value: &str,
936) -> bool {
937 request
938 .header(header_name)
939 .map(|h| {
940 h.as_str()
941 .split(',')
942 .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
943 })
944 .unwrap_or(false)
945}
946
947#[cfg(test)]
948mod tests {
949 use super::*;
950 use crate::{
951 auth,
952 db::{tests::TestDb, UserId},
953 github, AppState, Config,
954 };
955 use async_std::{sync::RwLockReadGuard, task};
956 use gpui::{ModelHandle, TestAppContext};
957 use parking_lot::Mutex;
958 use postage::{mpsc, watch};
959 use serde_json::json;
960 use sqlx::types::time::OffsetDateTime;
961 use std::{
962 path::Path,
963 sync::{
964 atomic::{AtomicBool, Ordering::SeqCst},
965 Arc,
966 },
967 time::Duration,
968 };
969 use zed::{
970 channel::{Channel, ChannelDetails, ChannelList},
971 editor::{Editor, Insert},
972 fs::{FakeFs, Fs as _},
973 language::LanguageRegistry,
974 rpc::{self, Client, Credentials, EstablishConnectionError},
975 settings,
976 test::FakeHttpClient,
977 user::UserStore,
978 worktree::Worktree,
979 };
980 use zrpc::Peer;
981
982 #[gpui::test]
983 async fn test_share_worktree(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
984 let (window_b, _) = cx_b.add_window(|_| EmptyView);
985 let settings = cx_b.read(settings::test).1;
986 let lang_registry = Arc::new(LanguageRegistry::new());
987
988 // Connect to a server as 2 clients.
989 let mut server = TestServer::start().await;
990 let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
991 let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
992
993 cx_a.foreground().forbid_parking();
994
995 // Share a local worktree as client A
996 let fs = Arc::new(FakeFs::new());
997 fs.insert_tree(
998 "/a",
999 json!({
1000 ".zed.toml": r#"collaborators = ["user_b"]"#,
1001 "a.txt": "a-contents",
1002 "b.txt": "b-contents",
1003 }),
1004 )
1005 .await;
1006 let worktree_a = Worktree::open_local(
1007 client_a.clone(),
1008 "/a".as_ref(),
1009 fs,
1010 lang_registry.clone(),
1011 &mut cx_a.to_async(),
1012 )
1013 .await
1014 .unwrap();
1015 worktree_a
1016 .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1017 .await;
1018 let worktree_id = worktree_a
1019 .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1020 .await
1021 .unwrap();
1022
1023 // Join that worktree as client B, and see that a guest has joined as client A.
1024 let worktree_b = Worktree::open_remote(
1025 client_b.clone(),
1026 worktree_id,
1027 lang_registry.clone(),
1028 &mut cx_b.to_async(),
1029 )
1030 .await
1031 .unwrap();
1032 let replica_id_b = worktree_b.read_with(&cx_b, |tree, _| tree.replica_id());
1033 worktree_a
1034 .condition(&cx_a, |tree, _| {
1035 tree.peers()
1036 .values()
1037 .any(|replica_id| *replica_id == replica_id_b)
1038 })
1039 .await;
1040
1041 // Open the same file as client B and client A.
1042 let buffer_b = worktree_b
1043 .update(&mut cx_b, |worktree, cx| worktree.open_buffer("b.txt", cx))
1044 .await
1045 .unwrap();
1046 buffer_b.read_with(&cx_b, |buf, _| assert_eq!(buf.text(), "b-contents"));
1047 worktree_a.read_with(&cx_a, |tree, cx| assert!(tree.has_open_buffer("b.txt", cx)));
1048 let buffer_a = worktree_a
1049 .update(&mut cx_a, |tree, cx| tree.open_buffer("b.txt", cx))
1050 .await
1051 .unwrap();
1052
1053 // Create a selection set as client B and see that selection set as client A.
1054 let editor_b = cx_b.add_view(window_b, |cx| Editor::for_buffer(buffer_b, settings, cx));
1055 buffer_a
1056 .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 1)
1057 .await;
1058
1059 // Edit the buffer as client B and see that edit as client A.
1060 editor_b.update(&mut cx_b, |editor, cx| {
1061 editor.insert(&Insert("ok, ".into()), cx)
1062 });
1063 buffer_a
1064 .condition(&cx_a, |buffer, _| buffer.text() == "ok, b-contents")
1065 .await;
1066
1067 // Remove the selection set as client B, see those selections disappear as client A.
1068 cx_b.update(move |_| drop(editor_b));
1069 buffer_a
1070 .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 0)
1071 .await;
1072
1073 // Close the buffer as client A, see that the buffer is closed.
1074 cx_a.update(move |_| drop(buffer_a));
1075 worktree_a
1076 .condition(&cx_a, |tree, cx| !tree.has_open_buffer("b.txt", cx))
1077 .await;
1078
1079 // Dropping the worktree removes client B from client A's peers.
1080 cx_b.update(move |_| drop(worktree_b));
1081 worktree_a
1082 .condition(&cx_a, |tree, _| tree.peers().is_empty())
1083 .await;
1084 }
1085
1086 #[gpui::test]
1087 async fn test_propagate_saves_and_fs_changes_in_shared_worktree(
1088 mut cx_a: TestAppContext,
1089 mut cx_b: TestAppContext,
1090 mut cx_c: TestAppContext,
1091 ) {
1092 cx_a.foreground().forbid_parking();
1093 let lang_registry = Arc::new(LanguageRegistry::new());
1094
1095 // Connect to a server as 3 clients.
1096 let mut server = TestServer::start().await;
1097 let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1098 let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1099 let (client_c, _) = server.create_client(&mut cx_c, "user_c").await;
1100
1101 let fs = Arc::new(FakeFs::new());
1102
1103 // Share a worktree as client A.
1104 fs.insert_tree(
1105 "/a",
1106 json!({
1107 ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
1108 "file1": "",
1109 "file2": ""
1110 }),
1111 )
1112 .await;
1113
1114 let worktree_a = Worktree::open_local(
1115 client_a.clone(),
1116 "/a".as_ref(),
1117 fs.clone(),
1118 lang_registry.clone(),
1119 &mut cx_a.to_async(),
1120 )
1121 .await
1122 .unwrap();
1123 worktree_a
1124 .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1125 .await;
1126 let worktree_id = worktree_a
1127 .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1128 .await
1129 .unwrap();
1130
1131 // Join that worktree as clients B and C.
1132 let worktree_b = Worktree::open_remote(
1133 client_b.clone(),
1134 worktree_id,
1135 lang_registry.clone(),
1136 &mut cx_b.to_async(),
1137 )
1138 .await
1139 .unwrap();
1140 let worktree_c = Worktree::open_remote(
1141 client_c.clone(),
1142 worktree_id,
1143 lang_registry.clone(),
1144 &mut cx_c.to_async(),
1145 )
1146 .await
1147 .unwrap();
1148
1149 // Open and edit a buffer as both guests B and C.
1150 let buffer_b = worktree_b
1151 .update(&mut cx_b, |tree, cx| tree.open_buffer("file1", cx))
1152 .await
1153 .unwrap();
1154 let buffer_c = worktree_c
1155 .update(&mut cx_c, |tree, cx| tree.open_buffer("file1", cx))
1156 .await
1157 .unwrap();
1158 buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "i-am-b, ", cx));
1159 buffer_c.update(&mut cx_c, |buf, cx| buf.edit([0..0], "i-am-c, ", cx));
1160
1161 // Open and edit that buffer as the host.
1162 let buffer_a = worktree_a
1163 .update(&mut cx_a, |tree, cx| tree.open_buffer("file1", cx))
1164 .await
1165 .unwrap();
1166
1167 buffer_a
1168 .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, ")
1169 .await;
1170 buffer_a.update(&mut cx_a, |buf, cx| {
1171 buf.edit([buf.len()..buf.len()], "i-am-a", cx)
1172 });
1173
1174 // Wait for edits to propagate
1175 buffer_a
1176 .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
1177 .await;
1178 buffer_b
1179 .condition(&mut cx_b, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
1180 .await;
1181 buffer_c
1182 .condition(&mut cx_c, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
1183 .await;
1184
1185 // Edit the buffer as the host and concurrently save as guest B.
1186 let save_b = buffer_b.update(&mut cx_b, |buf, cx| buf.save(cx).unwrap());
1187 buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "hi-a, ", cx));
1188 save_b.await.unwrap();
1189 assert_eq!(
1190 fs.load("/a/file1".as_ref()).await.unwrap(),
1191 "hi-a, i-am-c, i-am-b, i-am-a"
1192 );
1193 buffer_a.read_with(&cx_a, |buf, _| assert!(!buf.is_dirty()));
1194 buffer_b.read_with(&cx_b, |buf, _| assert!(!buf.is_dirty()));
1195 buffer_c.condition(&cx_c, |buf, _| !buf.is_dirty()).await;
1196
1197 // Make changes on host's file system, see those changes on the guests.
1198 fs.rename("/a/file2".as_ref(), "/a/file3".as_ref())
1199 .await
1200 .unwrap();
1201 fs.insert_file(Path::new("/a/file4"), "4".into())
1202 .await
1203 .unwrap();
1204
1205 worktree_b
1206 .condition(&cx_b, |tree, _| tree.file_count() == 4)
1207 .await;
1208 worktree_c
1209 .condition(&cx_c, |tree, _| tree.file_count() == 4)
1210 .await;
1211 worktree_b.read_with(&cx_b, |tree, _| {
1212 assert_eq!(
1213 tree.paths()
1214 .map(|p| p.to_string_lossy())
1215 .collect::<Vec<_>>(),
1216 &[".zed.toml", "file1", "file3", "file4"]
1217 )
1218 });
1219 worktree_c.read_with(&cx_c, |tree, _| {
1220 assert_eq!(
1221 tree.paths()
1222 .map(|p| p.to_string_lossy())
1223 .collect::<Vec<_>>(),
1224 &[".zed.toml", "file1", "file3", "file4"]
1225 )
1226 });
1227 }
1228
1229 #[gpui::test]
1230 async fn test_buffer_conflict_after_save(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1231 cx_a.foreground().forbid_parking();
1232 let lang_registry = Arc::new(LanguageRegistry::new());
1233
1234 // Connect to a server as 2 clients.
1235 let mut server = TestServer::start().await;
1236 let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1237 let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1238
1239 // Share a local worktree as client A
1240 let fs = Arc::new(FakeFs::new());
1241 fs.insert_tree(
1242 "/dir",
1243 json!({
1244 ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
1245 "a.txt": "a-contents",
1246 }),
1247 )
1248 .await;
1249
1250 let worktree_a = Worktree::open_local(
1251 client_a.clone(),
1252 "/dir".as_ref(),
1253 fs,
1254 lang_registry.clone(),
1255 &mut cx_a.to_async(),
1256 )
1257 .await
1258 .unwrap();
1259 worktree_a
1260 .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1261 .await;
1262 let worktree_id = worktree_a
1263 .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1264 .await
1265 .unwrap();
1266
1267 // Join that worktree as client B, and see that a guest has joined as client A.
1268 let worktree_b = Worktree::open_remote(
1269 client_b.clone(),
1270 worktree_id,
1271 lang_registry.clone(),
1272 &mut cx_b.to_async(),
1273 )
1274 .await
1275 .unwrap();
1276
1277 let buffer_b = worktree_b
1278 .update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx))
1279 .await
1280 .unwrap();
1281 let mtime = buffer_b.read_with(&cx_b, |buf, _| buf.file().unwrap().mtime);
1282
1283 buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "world ", cx));
1284 buffer_b.read_with(&cx_b, |buf, _| {
1285 assert!(buf.is_dirty());
1286 assert!(!buf.has_conflict());
1287 });
1288
1289 buffer_b
1290 .update(&mut cx_b, |buf, cx| buf.save(cx))
1291 .unwrap()
1292 .await
1293 .unwrap();
1294 worktree_b
1295 .condition(&cx_b, |_, cx| {
1296 buffer_b.read(cx).file().unwrap().mtime != mtime
1297 })
1298 .await;
1299 buffer_b.read_with(&cx_b, |buf, _| {
1300 assert!(!buf.is_dirty());
1301 assert!(!buf.has_conflict());
1302 });
1303
1304 buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "hello ", cx));
1305 buffer_b.read_with(&cx_b, |buf, _| {
1306 assert!(buf.is_dirty());
1307 assert!(!buf.has_conflict());
1308 });
1309 }
1310
1311 #[gpui::test]
1312 async fn test_editing_while_guest_opens_buffer(
1313 mut cx_a: TestAppContext,
1314 mut cx_b: TestAppContext,
1315 ) {
1316 cx_a.foreground().forbid_parking();
1317 let lang_registry = Arc::new(LanguageRegistry::new());
1318
1319 // Connect to a server as 2 clients.
1320 let mut server = TestServer::start().await;
1321 let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1322 let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1323
1324 // Share a local worktree as client A
1325 let fs = Arc::new(FakeFs::new());
1326 fs.insert_tree(
1327 "/dir",
1328 json!({
1329 ".zed.toml": r#"collaborators = ["user_b"]"#,
1330 "a.txt": "a-contents",
1331 }),
1332 )
1333 .await;
1334 let worktree_a = Worktree::open_local(
1335 client_a.clone(),
1336 "/dir".as_ref(),
1337 fs,
1338 lang_registry.clone(),
1339 &mut cx_a.to_async(),
1340 )
1341 .await
1342 .unwrap();
1343 worktree_a
1344 .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1345 .await;
1346 let worktree_id = worktree_a
1347 .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1348 .await
1349 .unwrap();
1350
1351 // Join that worktree as client B, and see that a guest has joined as client A.
1352 let worktree_b = Worktree::open_remote(
1353 client_b.clone(),
1354 worktree_id,
1355 lang_registry.clone(),
1356 &mut cx_b.to_async(),
1357 )
1358 .await
1359 .unwrap();
1360
1361 let buffer_a = worktree_a
1362 .update(&mut cx_a, |tree, cx| tree.open_buffer("a.txt", cx))
1363 .await
1364 .unwrap();
1365 let buffer_b = cx_b
1366 .background()
1367 .spawn(worktree_b.update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx)));
1368
1369 task::yield_now().await;
1370 buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "z", cx));
1371
1372 let text = buffer_a.read_with(&cx_a, |buf, _| buf.text());
1373 let buffer_b = buffer_b.await.unwrap();
1374 buffer_b.condition(&cx_b, |buf, _| buf.text() == text).await;
1375 }
1376
1377 #[gpui::test]
1378 async fn test_peer_disconnection(mut cx_a: TestAppContext, cx_b: TestAppContext) {
1379 cx_a.foreground().forbid_parking();
1380 let lang_registry = Arc::new(LanguageRegistry::new());
1381
1382 // Connect to a server as 2 clients.
1383 let mut server = TestServer::start().await;
1384 let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1385 let (client_b, _) = server.create_client(&mut cx_a, "user_b").await;
1386
1387 // Share a local worktree as client A
1388 let fs = Arc::new(FakeFs::new());
1389 fs.insert_tree(
1390 "/a",
1391 json!({
1392 ".zed.toml": r#"collaborators = ["user_b"]"#,
1393 "a.txt": "a-contents",
1394 "b.txt": "b-contents",
1395 }),
1396 )
1397 .await;
1398 let worktree_a = Worktree::open_local(
1399 client_a.clone(),
1400 "/a".as_ref(),
1401 fs,
1402 lang_registry.clone(),
1403 &mut cx_a.to_async(),
1404 )
1405 .await
1406 .unwrap();
1407 worktree_a
1408 .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1409 .await;
1410 let worktree_id = worktree_a
1411 .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1412 .await
1413 .unwrap();
1414
1415 // Join that worktree as client B, and see that a guest has joined as client A.
1416 let _worktree_b = Worktree::open_remote(
1417 client_b.clone(),
1418 worktree_id,
1419 lang_registry.clone(),
1420 &mut cx_b.to_async(),
1421 )
1422 .await
1423 .unwrap();
1424 worktree_a
1425 .condition(&cx_a, |tree, _| tree.peers().len() == 1)
1426 .await;
1427
1428 // Drop client B's connection and ensure client A observes client B leaving the worktree.
1429 client_b.disconnect(&cx_b.to_async()).await.unwrap();
1430 worktree_a
1431 .condition(&cx_a, |tree, _| tree.peers().len() == 0)
1432 .await;
1433 }
1434
1435 #[gpui::test]
1436 async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1437 cx_a.foreground().forbid_parking();
1438
1439 // Connect to a server as 2 clients.
1440 let mut server = TestServer::start().await;
1441 let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1442 let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
1443
1444 // Create an org that includes these 2 users.
1445 let db = &server.app_state.db;
1446 let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1447 db.add_org_member(org_id, current_user_id(&user_store_a, &cx_a), false)
1448 .await
1449 .unwrap();
1450 db.add_org_member(org_id, current_user_id(&user_store_b, &cx_b), false)
1451 .await
1452 .unwrap();
1453
1454 // Create a channel that includes all the users.
1455 let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1456 db.add_channel_member(channel_id, current_user_id(&user_store_a, &cx_a), false)
1457 .await
1458 .unwrap();
1459 db.add_channel_member(channel_id, current_user_id(&user_store_b, &cx_b), false)
1460 .await
1461 .unwrap();
1462 db.create_channel_message(
1463 channel_id,
1464 current_user_id(&user_store_b, &cx_b),
1465 "hello A, it's B.",
1466 OffsetDateTime::now_utc(),
1467 1,
1468 )
1469 .await
1470 .unwrap();
1471
1472 let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1473 channels_a
1474 .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1475 .await;
1476 channels_a.read_with(&cx_a, |list, _| {
1477 assert_eq!(
1478 list.available_channels().unwrap(),
1479 &[ChannelDetails {
1480 id: channel_id.to_proto(),
1481 name: "test-channel".to_string()
1482 }]
1483 )
1484 });
1485 let channel_a = channels_a.update(&mut cx_a, |this, cx| {
1486 this.get_channel(channel_id.to_proto(), cx).unwrap()
1487 });
1488 channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
1489 channel_a
1490 .condition(&cx_a, |channel, _| {
1491 channel_messages(channel)
1492 == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
1493 })
1494 .await;
1495
1496 let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
1497 channels_b
1498 .condition(&mut cx_b, |list, _| list.available_channels().is_some())
1499 .await;
1500 channels_b.read_with(&cx_b, |list, _| {
1501 assert_eq!(
1502 list.available_channels().unwrap(),
1503 &[ChannelDetails {
1504 id: channel_id.to_proto(),
1505 name: "test-channel".to_string()
1506 }]
1507 )
1508 });
1509
1510 let channel_b = channels_b.update(&mut cx_b, |this, cx| {
1511 this.get_channel(channel_id.to_proto(), cx).unwrap()
1512 });
1513 channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
1514 channel_b
1515 .condition(&cx_b, |channel, _| {
1516 channel_messages(channel)
1517 == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
1518 })
1519 .await;
1520
1521 channel_a
1522 .update(&mut cx_a, |channel, cx| {
1523 channel
1524 .send_message("oh, hi B.".to_string(), cx)
1525 .unwrap()
1526 .detach();
1527 let task = channel.send_message("sup".to_string(), cx).unwrap();
1528 assert_eq!(
1529 channel_messages(channel),
1530 &[
1531 ("user_b".to_string(), "hello A, it's B.".to_string(), false),
1532 ("user_a".to_string(), "oh, hi B.".to_string(), true),
1533 ("user_a".to_string(), "sup".to_string(), true)
1534 ]
1535 );
1536 task
1537 })
1538 .await
1539 .unwrap();
1540
1541 channel_b
1542 .condition(&cx_b, |channel, _| {
1543 channel_messages(channel)
1544 == [
1545 ("user_b".to_string(), "hello A, it's B.".to_string(), false),
1546 ("user_a".to_string(), "oh, hi B.".to_string(), false),
1547 ("user_a".to_string(), "sup".to_string(), false),
1548 ]
1549 })
1550 .await;
1551
1552 assert_eq!(
1553 server.state().await.channels[&channel_id]
1554 .connection_ids
1555 .len(),
1556 2
1557 );
1558 cx_b.update(|_| drop(channel_b));
1559 server
1560 .condition(|state| state.channels[&channel_id].connection_ids.len() == 1)
1561 .await;
1562
1563 cx_a.update(|_| drop(channel_a));
1564 server
1565 .condition(|state| !state.channels.contains_key(&channel_id))
1566 .await;
1567 }
1568
1569 #[gpui::test]
1570 async fn test_chat_message_validation(mut cx_a: TestAppContext) {
1571 cx_a.foreground().forbid_parking();
1572
1573 let mut server = TestServer::start().await;
1574 let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1575
1576 let db = &server.app_state.db;
1577 let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1578 let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1579 db.add_org_member(org_id, current_user_id(&user_store_a, &cx_a), false)
1580 .await
1581 .unwrap();
1582 db.add_channel_member(channel_id, current_user_id(&user_store_a, &cx_a), false)
1583 .await
1584 .unwrap();
1585
1586 let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1587 channels_a
1588 .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1589 .await;
1590 let channel_a = channels_a.update(&mut cx_a, |this, cx| {
1591 this.get_channel(channel_id.to_proto(), cx).unwrap()
1592 });
1593
1594 // Messages aren't allowed to be too long.
1595 channel_a
1596 .update(&mut cx_a, |channel, cx| {
1597 let long_body = "this is long.\n".repeat(1024);
1598 channel.send_message(long_body, cx).unwrap()
1599 })
1600 .await
1601 .unwrap_err();
1602
1603 // Messages aren't allowed to be blank.
1604 channel_a.update(&mut cx_a, |channel, cx| {
1605 channel.send_message(String::new(), cx).unwrap_err()
1606 });
1607
1608 // Leading and trailing whitespace are trimmed.
1609 channel_a
1610 .update(&mut cx_a, |channel, cx| {
1611 channel
1612 .send_message("\n surrounded by whitespace \n".to_string(), cx)
1613 .unwrap()
1614 })
1615 .await
1616 .unwrap();
1617 assert_eq!(
1618 db.get_channel_messages(channel_id, 10, None)
1619 .await
1620 .unwrap()
1621 .iter()
1622 .map(|m| &m.body)
1623 .collect::<Vec<_>>(),
1624 &["surrounded by whitespace"]
1625 );
1626 }
1627
1628 #[gpui::test]
1629 async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1630 cx_a.foreground().forbid_parking();
1631
1632 // Connect to a server as 2 clients.
1633 let mut server = TestServer::start().await;
1634 let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1635 let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
1636 let mut status_b = client_b.status();
1637
1638 // Create an org that includes these 2 users.
1639 let db = &server.app_state.db;
1640 let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1641 db.add_org_member(org_id, current_user_id(&user_store_a, &cx_a), false)
1642 .await
1643 .unwrap();
1644 db.add_org_member(org_id, current_user_id(&user_store_b, &cx_b), false)
1645 .await
1646 .unwrap();
1647
1648 // Create a channel that includes all the users.
1649 let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1650 db.add_channel_member(channel_id, current_user_id(&user_store_a, &cx_a), false)
1651 .await
1652 .unwrap();
1653 db.add_channel_member(channel_id, current_user_id(&user_store_b, &cx_b), false)
1654 .await
1655 .unwrap();
1656 db.create_channel_message(
1657 channel_id,
1658 current_user_id(&user_store_b, &cx_b),
1659 "hello A, it's B.",
1660 OffsetDateTime::now_utc(),
1661 2,
1662 )
1663 .await
1664 .unwrap();
1665
1666 let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1667 channels_a
1668 .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1669 .await;
1670
1671 channels_a.read_with(&cx_a, |list, _| {
1672 assert_eq!(
1673 list.available_channels().unwrap(),
1674 &[ChannelDetails {
1675 id: channel_id.to_proto(),
1676 name: "test-channel".to_string()
1677 }]
1678 )
1679 });
1680 let channel_a = channels_a.update(&mut cx_a, |this, cx| {
1681 this.get_channel(channel_id.to_proto(), cx).unwrap()
1682 });
1683 channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
1684 channel_a
1685 .condition(&cx_a, |channel, _| {
1686 channel_messages(channel)
1687 == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
1688 })
1689 .await;
1690
1691 let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b.clone(), client_b, cx));
1692 channels_b
1693 .condition(&mut cx_b, |list, _| list.available_channels().is_some())
1694 .await;
1695 channels_b.read_with(&cx_b, |list, _| {
1696 assert_eq!(
1697 list.available_channels().unwrap(),
1698 &[ChannelDetails {
1699 id: channel_id.to_proto(),
1700 name: "test-channel".to_string()
1701 }]
1702 )
1703 });
1704
1705 let channel_b = channels_b.update(&mut cx_b, |this, cx| {
1706 this.get_channel(channel_id.to_proto(), cx).unwrap()
1707 });
1708 channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
1709 channel_b
1710 .condition(&cx_b, |channel, _| {
1711 channel_messages(channel)
1712 == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
1713 })
1714 .await;
1715
1716 // Disconnect client B, ensuring we can still access its cached channel data.
1717 server.forbid_connections();
1718 server.disconnect_client(current_user_id(&user_store_b, &cx_b));
1719 while !matches!(
1720 status_b.recv().await,
1721 Some(rpc::Status::ReconnectionError { .. })
1722 ) {}
1723
1724 channels_b.read_with(&cx_b, |channels, _| {
1725 assert_eq!(
1726 channels.available_channels().unwrap(),
1727 [ChannelDetails {
1728 id: channel_id.to_proto(),
1729 name: "test-channel".to_string()
1730 }]
1731 )
1732 });
1733 channel_b.read_with(&cx_b, |channel, _| {
1734 assert_eq!(
1735 channel_messages(channel),
1736 [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
1737 )
1738 });
1739
1740 // Send a message from client B while it is disconnected.
1741 channel_b
1742 .update(&mut cx_b, |channel, cx| {
1743 let task = channel
1744 .send_message("can you see this?".to_string(), cx)
1745 .unwrap();
1746 assert_eq!(
1747 channel_messages(channel),
1748 &[
1749 ("user_b".to_string(), "hello A, it's B.".to_string(), false),
1750 ("user_b".to_string(), "can you see this?".to_string(), true)
1751 ]
1752 );
1753 task
1754 })
1755 .await
1756 .unwrap_err();
1757
1758 // Send a message from client A while B is disconnected.
1759 channel_a
1760 .update(&mut cx_a, |channel, cx| {
1761 channel
1762 .send_message("oh, hi B.".to_string(), cx)
1763 .unwrap()
1764 .detach();
1765 let task = channel.send_message("sup".to_string(), cx).unwrap();
1766 assert_eq!(
1767 channel_messages(channel),
1768 &[
1769 ("user_b".to_string(), "hello A, it's B.".to_string(), false),
1770 ("user_a".to_string(), "oh, hi B.".to_string(), true),
1771 ("user_a".to_string(), "sup".to_string(), true)
1772 ]
1773 );
1774 task
1775 })
1776 .await
1777 .unwrap();
1778
1779 // Give client B a chance to reconnect.
1780 server.allow_connections();
1781 cx_b.foreground().advance_clock(Duration::from_secs(10));
1782
1783 // Verify that B sees the new messages upon reconnection, as well as the message client B
1784 // sent while offline.
1785 channel_b
1786 .condition(&cx_b, |channel, _| {
1787 channel_messages(channel)
1788 == [
1789 ("user_b".to_string(), "hello A, it's B.".to_string(), false),
1790 ("user_a".to_string(), "oh, hi B.".to_string(), false),
1791 ("user_a".to_string(), "sup".to_string(), false),
1792 ("user_b".to_string(), "can you see this?".to_string(), false),
1793 ]
1794 })
1795 .await;
1796
1797 // Ensure client A and B can communicate normally after reconnection.
1798 channel_a
1799 .update(&mut cx_a, |channel, cx| {
1800 channel.send_message("you online?".to_string(), cx).unwrap()
1801 })
1802 .await
1803 .unwrap();
1804 channel_b
1805 .condition(&cx_b, |channel, _| {
1806 channel_messages(channel)
1807 == [
1808 ("user_b".to_string(), "hello A, it's B.".to_string(), false),
1809 ("user_a".to_string(), "oh, hi B.".to_string(), false),
1810 ("user_a".to_string(), "sup".to_string(), false),
1811 ("user_b".to_string(), "can you see this?".to_string(), false),
1812 ("user_a".to_string(), "you online?".to_string(), false),
1813 ]
1814 })
1815 .await;
1816
1817 channel_b
1818 .update(&mut cx_b, |channel, cx| {
1819 channel.send_message("yep".to_string(), cx).unwrap()
1820 })
1821 .await
1822 .unwrap();
1823 channel_a
1824 .condition(&cx_a, |channel, _| {
1825 channel_messages(channel)
1826 == [
1827 ("user_b".to_string(), "hello A, it's B.".to_string(), false),
1828 ("user_a".to_string(), "oh, hi B.".to_string(), false),
1829 ("user_a".to_string(), "sup".to_string(), false),
1830 ("user_b".to_string(), "can you see this?".to_string(), false),
1831 ("user_a".to_string(), "you online?".to_string(), false),
1832 ("user_b".to_string(), "yep".to_string(), false),
1833 ]
1834 })
1835 .await;
1836 }
1837
1838 #[gpui::test]
1839 async fn test_collaborators(
1840 mut cx_a: TestAppContext,
1841 mut cx_b: TestAppContext,
1842 mut cx_c: TestAppContext,
1843 ) {
1844 cx_a.foreground().forbid_parking();
1845 let lang_registry = Arc::new(LanguageRegistry::new());
1846
1847 // Connect to a server as 3 clients.
1848 let mut server = TestServer::start().await;
1849 let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1850 let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
1851 let (_client_c, user_store_c) = server.create_client(&mut cx_c, "user_c").await;
1852
1853 let fs = Arc::new(FakeFs::new());
1854
1855 // Share a worktree as client A.
1856 fs.insert_tree(
1857 "/a",
1858 json!({
1859 ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
1860 }),
1861 )
1862 .await;
1863
1864 let worktree_a = Worktree::open_local(
1865 client_a.clone(),
1866 "/a".as_ref(),
1867 fs.clone(),
1868 lang_registry.clone(),
1869 &mut cx_a.to_async(),
1870 )
1871 .await
1872 .unwrap();
1873
1874 user_store_a
1875 .condition(&cx_a, |user_store, _| {
1876 collaborators(user_store) == vec![("user_a", vec![("a", vec![])])]
1877 })
1878 .await;
1879 user_store_b
1880 .condition(&cx_b, |user_store, _| {
1881 collaborators(user_store) == vec![("user_a", vec![("a", vec![])])]
1882 })
1883 .await;
1884 user_store_c
1885 .condition(&cx_c, |user_store, _| {
1886 collaborators(user_store) == vec![("user_a", vec![("a", vec![])])]
1887 })
1888 .await;
1889
1890 let worktree_id = worktree_a
1891 .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1892 .await
1893 .unwrap();
1894
1895 let _worktree_b = Worktree::open_remote(
1896 client_b.clone(),
1897 worktree_id,
1898 lang_registry.clone(),
1899 &mut cx_b.to_async(),
1900 )
1901 .await
1902 .unwrap();
1903
1904 user_store_a
1905 .condition(&cx_a, |user_store, _| {
1906 collaborators(user_store) == vec![("user_a", vec![("a", vec!["user_b"])])]
1907 })
1908 .await;
1909 user_store_b
1910 .condition(&cx_b, |user_store, _| {
1911 collaborators(user_store) == vec![("user_a", vec![("a", vec!["user_b"])])]
1912 })
1913 .await;
1914 user_store_c
1915 .condition(&cx_c, |user_store, _| {
1916 collaborators(user_store) == vec![("user_a", vec![("a", vec!["user_b"])])]
1917 })
1918 .await;
1919
1920 cx_a.update(move |_| drop(worktree_a));
1921 user_store_a
1922 .condition(&cx_a, |user_store, _| collaborators(user_store) == vec![])
1923 .await;
1924 user_store_b
1925 .condition(&cx_b, |user_store, _| collaborators(user_store) == vec![])
1926 .await;
1927 user_store_c
1928 .condition(&cx_c, |user_store, _| collaborators(user_store) == vec![])
1929 .await;
1930
1931 fn collaborators(user_store: &UserStore) -> Vec<(&str, Vec<(&str, Vec<&str>)>)> {
1932 user_store
1933 .collaborators()
1934 .iter()
1935 .map(|collaborator| {
1936 let worktrees = collaborator
1937 .worktrees
1938 .iter()
1939 .map(|w| {
1940 (
1941 w.root_name.as_str(),
1942 w.participants
1943 .iter()
1944 .map(|p| p.github_login.as_str())
1945 .collect(),
1946 )
1947 })
1948 .collect();
1949 (collaborator.user.github_login.as_str(), worktrees)
1950 })
1951 .collect()
1952 }
1953 }
1954
    /// Test harness that owns a `Server` instance together with everything the tests
    /// need to create clients against it and to sever or block their connections.
    struct TestServer {
        /// Peer handed to the server at construction; reset when the harness is dropped.
        peer: Arc<Peer>,
        /// Application state built by `build_app_state`; tests reach the database through it.
        app_state: Arc<AppState>,
        /// The server under test.
        server: Arc<Server>,
        /// Receiving half of the channel whose sending half is given to `Server::new`;
        /// `condition` waits on it between predicate checks.
        notifications: mpsc::Receiver<()>,
        /// Per-user senders recorded when a client's in-memory connection is established;
        /// `disconnect_client` removes and fires the entry for a user.
        connection_killers: Arc<Mutex<HashMap<UserId, watch::Sender<Option<()>>>>>,
        /// While `true`, the overridden establish-connection hook refuses new connections.
        forbid_connections: Arc<AtomicBool>,
        /// Test database handle, never read after construction.
        /// NOTE(review): presumably held so the database outlives the server; confirm
        /// against `TestDb`'s drop behavior.
        _test_db: TestDb,
    }
1964
1965 impl TestServer {
1966 async fn start() -> Self {
1967 let test_db = TestDb::new();
1968 let app_state = Self::build_app_state(&test_db).await;
1969 let peer = Peer::new();
1970 let notifications = mpsc::channel(128);
1971 let server = Server::new(app_state.clone(), peer.clone(), Some(notifications.0));
1972 Self {
1973 peer,
1974 app_state,
1975 server,
1976 notifications: notifications.1,
1977 connection_killers: Default::default(),
1978 forbid_connections: Default::default(),
1979 _test_db: test_db,
1980 }
1981 }
1982
1983 async fn create_client(
1984 &mut self,
1985 cx: &mut TestAppContext,
1986 name: &str,
1987 ) -> (Arc<Client>, ModelHandle<UserStore>) {
1988 let user_id = self.app_state.db.create_user(name, false).await.unwrap();
1989 let client_name = name.to_string();
1990 let mut client = Client::new();
1991 let server = self.server.clone();
1992 let connection_killers = self.connection_killers.clone();
1993 let forbid_connections = self.forbid_connections.clone();
1994 Arc::get_mut(&mut client)
1995 .unwrap()
1996 .override_authenticate(move |cx| {
1997 cx.spawn(|_| async move {
1998 let access_token = "the-token".to_string();
1999 Ok(Credentials {
2000 user_id: user_id.0 as u64,
2001 access_token,
2002 })
2003 })
2004 })
2005 .override_establish_connection(move |credentials, cx| {
2006 assert_eq!(credentials.user_id, user_id.0 as u64);
2007 assert_eq!(credentials.access_token, "the-token");
2008
2009 let server = server.clone();
2010 let connection_killers = connection_killers.clone();
2011 let forbid_connections = forbid_connections.clone();
2012 let client_name = client_name.clone();
2013 cx.spawn(move |cx| async move {
2014 if forbid_connections.load(SeqCst) {
2015 Err(EstablishConnectionError::other(anyhow!(
2016 "server is forbidding connections"
2017 )))
2018 } else {
2019 let (client_conn, server_conn, kill_conn) = Connection::in_memory();
2020 connection_killers.lock().insert(user_id, kill_conn);
2021 cx.background()
2022 .spawn(server.handle_connection(server_conn, client_name, user_id))
2023 .detach();
2024 Ok(client_conn)
2025 }
2026 })
2027 });
2028
2029 let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) });
2030 client
2031 .authenticate_and_connect(&cx.to_async())
2032 .await
2033 .unwrap();
2034
2035 let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx));
2036 let mut authed_user =
2037 user_store.read_with(cx, |user_store, _| user_store.watch_current_user());
2038 while authed_user.recv().await.unwrap().is_none() {}
2039
2040 (client, user_store)
2041 }
2042
2043 fn disconnect_client(&self, user_id: UserId) {
2044 if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) {
2045 let _ = kill_conn.try_send(Some(()));
2046 }
2047 }
2048
2049 fn forbid_connections(&self) {
2050 self.forbid_connections.store(true, SeqCst);
2051 }
2052
2053 fn allow_connections(&self) {
2054 self.forbid_connections.store(false, SeqCst);
2055 }
2056
2057 async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
2058 let mut config = Config::default();
2059 config.session_secret = "a".repeat(32);
2060 config.database_url = test_db.url.clone();
2061 let github_client = github::AppClient::test();
2062 Arc::new(AppState {
2063 db: test_db.db().clone(),
2064 handlebars: Default::default(),
2065 auth_client: auth::build_client("", ""),
2066 repo_client: github::RepoClient::test(&github_client),
2067 github_client,
2068 config,
2069 })
2070 }
2071
2072 async fn state<'a>(&'a self) -> RwLockReadGuard<'a, Store> {
2073 self.server.store.read().await
2074 }
2075
2076 async fn condition<F>(&mut self, mut predicate: F)
2077 where
2078 F: FnMut(&Store) -> bool,
2079 {
2080 async_std::future::timeout(Duration::from_millis(500), async {
2081 while !(predicate)(&*self.server.store.read().await) {
2082 self.notifications.recv().await;
2083 }
2084 })
2085 .await
2086 .expect("condition timed out");
2087 }
2088 }
2089
2090 impl Drop for TestServer {
2091 fn drop(&mut self) {
2092 task::block_on(self.peer.reset());
2093 }
2094 }
2095
2096 fn current_user_id(user_store: &ModelHandle<UserStore>, cx: &TestAppContext) -> UserId {
2097 UserId::from_proto(
2098 user_store.read_with(cx, |user_store, _| user_store.current_user().unwrap().id),
2099 )
2100 }
2101
2102 fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> {
2103 channel
2104 .messages()
2105 .cursor::<(), ()>()
2106 .map(|m| {
2107 (
2108 m.sender.github_login.clone(),
2109 m.body.clone(),
2110 m.is_pending(),
2111 )
2112 })
2113 .collect()
2114 }
2115
    /// A stateless placeholder view for tests; its `render` produces an empty element.
    struct EmptyView;
2117
    impl gpui::Entity for EmptyView {
        // The event type is the unit type: events from this view carry no data.
        type Event = ();
    }
2121
    impl gpui::View for EmptyView {
        /// Name reported for this view.
        fn ui_name() -> &'static str {
            "empty view"
        }

        /// Renders nothing: returns a boxed `gpui::elements::Empty` element.
        fn render(&mut self, _: &mut gpui::RenderContext<Self>) -> gpui::ElementBox {
            gpui::Element::boxed(gpui::elements::Empty)
        }
    }
2131}