1use super::{
2 auth,
3 db::{ChannelId, MessageId, UserId},
4 AppState,
5};
6use crate::errors::TideResultExt;
7use anyhow::anyhow;
8use async_std::{sync::RwLock, task};
9use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
10use futures::{future::BoxFuture, FutureExt};
11use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
12use sha1::{Digest as _, Sha1};
13use std::{
14 any::TypeId,
15 collections::{hash_map, HashMap, HashSet},
16 future::Future,
17 mem,
18 sync::Arc,
19 time::Instant,
20};
21use surf::StatusCode;
22use tide::log;
23use tide::{
24 http::headers::{HeaderName, CONNECTION, UPGRADE},
25 Request, Response,
26};
27use time::OffsetDateTime;
28use zrpc::{
29 proto::{self, AnyTypedEnvelope, EnvelopedMessage},
30 Connection, ConnectionId, Peer, TypedEnvelope,
31};
32
/// Identifier of a participant's replica within a shared worktree. The host
/// is reported to peers as replica `0`; guests are assigned ids starting at `1`.
type ReplicaId = u16;
34
/// Type-erased handler for one kind of incoming RPC message. It is called
/// with the server and the boxed envelope, and returns a boxed future that
/// resolves to the outcome of handling that message.
type MessageHandler = Box<
    dyn Send
        + Sync
        + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, tide::Result<()>>,
>;
40
/// RPC server that tracks connections, worktrees and chat channels, and
/// dispatches incoming messages to the handlers registered in `Server::new`.
pub struct Server {
    /// RPC peer used to send, forward and respond to messages.
    peer: Arc<Peer>,
    /// In-memory bookkeeping for connections, worktrees and channels.
    state: RwLock<ServerState>,
    /// Shared application state; the handlers reach the database through it.
    app_state: Arc<AppState>,
    /// Message handlers keyed by the `TypeId` of the message payload.
    handlers: HashMap<TypeId, MessageHandler>,
    /// When set, a unit value is sent on this channel after each incoming
    /// message that had a registered handler has been processed.
    notifications: Option<mpsc::Sender<()>>,
}
48
/// Mutable server bookkeeping, kept behind `Server::state`.
#[derive(Default)]
struct ServerState {
    /// Per-connection state, keyed by connection id.
    connections: HashMap<ConnectionId, ConnectionState>,
    /// Index from a user to all of that user's live connections.
    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
    /// Every open worktree, keyed by worktree id.
    worktrees: HashMap<u64, Worktree>,
    /// For each user, the ids of the worktrees that list them as a collaborator.
    visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
    /// Chat channels that connections have joined.
    channels: HashMap<ChannelId, Channel>,
    /// Id that will be assigned to the next worktree added.
    next_worktree_id: u64,
}
58
/// What the server remembers about a single connection.
struct ConnectionState {
    /// User that owns this connection.
    user_id: UserId,
    /// Ids of worktrees this connection hosts or has joined as a guest.
    worktrees: HashSet<u64>,
    /// Channels this connection has joined.
    channels: HashSet<ChannelId>,
}
64
/// A worktree opened by a host connection.
struct Worktree {
    /// Connection that opened (and therefore hosts) the worktree.
    host_connection_id: ConnectionId,
    /// Users allowed to see and join the worktree; includes the host's user.
    collaborator_user_ids: Vec<UserId>,
    /// Root name supplied by the host when the worktree was opened.
    root_name: String,
    /// Sharing state; `None` while the worktree is not shared.
    share: Option<WorktreeShare>,
}
71
/// State that exists only while a worktree is shared.
struct WorktreeShare {
    /// Guests that have joined, mapped to the replica id assigned to each.
    guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
    /// Replica ids currently in use by guests.
    active_replica_ids: HashSet<ReplicaId>,
    /// Worktree entries keyed by entry id.
    entries: HashMap<u64, proto::Entry>,
}
77
/// In-memory state of a chat channel.
#[derive(Default)]
struct Channel {
    /// Connections that have joined the channel.
    connection_ids: HashSet<ConnectionId>,
}
82
/// Number of channel messages requested from the database per page.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum accepted length of a trimmed chat message body, in bytes
/// (it is compared against `String::len`).
const MAX_MESSAGE_LEN: usize = 1024;
85
86impl Server {
87 pub fn new(
88 app_state: Arc<AppState>,
89 peer: Arc<Peer>,
90 notifications: Option<mpsc::Sender<()>>,
91 ) -> Arc<Self> {
92 let mut server = Self {
93 peer,
94 app_state,
95 state: Default::default(),
96 handlers: Default::default(),
97 notifications,
98 };
99
100 server
101 .add_handler(Server::ping)
102 .add_handler(Server::open_worktree)
103 .add_handler(Server::handle_close_worktree)
104 .add_handler(Server::share_worktree)
105 .add_handler(Server::unshare_worktree)
106 .add_handler(Server::join_worktree)
107 .add_handler(Server::update_worktree)
108 .add_handler(Server::open_buffer)
109 .add_handler(Server::close_buffer)
110 .add_handler(Server::update_buffer)
111 .add_handler(Server::buffer_saved)
112 .add_handler(Server::save_buffer)
113 .add_handler(Server::get_channels)
114 .add_handler(Server::get_users)
115 .add_handler(Server::join_channel)
116 .add_handler(Server::leave_channel)
117 .add_handler(Server::send_channel_message)
118 .add_handler(Server::get_channel_messages);
119
120 Arc::new(server)
121 }
122
123 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
124 where
125 F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
126 Fut: 'static + Send + Future<Output = tide::Result<()>>,
127 M: EnvelopedMessage,
128 {
129 let prev_handler = self.handlers.insert(
130 TypeId::of::<M>(),
131 Box::new(move |server, envelope| {
132 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
133 (handler)(server, *envelope).boxed()
134 }),
135 );
136 if prev_handler.is_some() {
137 panic!("registered a handler for the same message twice");
138 }
139 self
140 }
141
    /// Returns a future that serves a single RPC connection until it closes.
    ///
    /// The future registers the connection with the peer and with the server
    /// state, pushes an initial collaborators update to `user_id`, then
    /// dispatches incoming messages to the registered handlers until either
    /// the incoming stream ends or the connection's I/O future completes.
    /// Finally it signs the connection out. `addr` is only used in log output.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        addr: String,
        user_id: UserId,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        async move {
            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection).await;
            this.add_connection(connection_id, user_id).await;
            // A failure here is logged but does not stop the connection from
            // being served.
            if let Err(err) = this.update_collaborators_for_users(&[user_id]).await {
                log::error!("error updating collaborators for {:?}: {}", user_id, err);
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);
            loop {
                let next_message = incoming_rx.recv().fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls the branches in the order written, so
                // a ready incoming message is taken before the I/O future.
                futures::select_biased! {
                    message = next_message => {
                        if let Some(message) = message {
                            let start_time = Instant::now();
                            log::info!("RPC message received: {}", message.payload_type_name());
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                // Handler errors are logged; the loop keeps going.
                                if let Err(err) = (handler)(this.clone(), message).await {
                                    log::error!("error handling message: {:?}", err);
                                } else {
                                    log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
                                }

                                // Signal that a message was processed; a send
                                // failure is deliberately ignored.
                                if let Some(mut notifications) = this.notifications.clone() {
                                    let _ = notifications.send(()).await;
                                }
                            } else {
                                log::warn!("unhandled message: {}", message.payload_type_name());
                            }
                        } else {
                            // The incoming stream ended.
                            log::info!("rpc connection closed {:?}", addr);
                            break;
                        }
                    }
                    handle_io = handle_io => {
                        // The I/O future finished, with or without an error.
                        if let Err(err) = handle_io {
                            log::error!("error handling rpc connection {:?} - {:?}", addr, err);
                        }
                        break;
                    }
                }
            }

            if let Err(err) = this.sign_out(connection_id).await {
                log::error!("error signing out connection {:?} - {:?}", addr, err);
            }
        }
    }
199
200 async fn sign_out(self: &Arc<Self>, connection_id: zrpc::ConnectionId) -> tide::Result<()> {
201 self.peer.disconnect(connection_id).await;
202 self.remove_connection(connection_id).await?;
203 Ok(())
204 }
205
206 // Add a new connection associated with a given user.
207 async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
208 let mut state = self.state.write().await;
209 state.connections.insert(
210 connection_id,
211 ConnectionState {
212 user_id,
213 worktrees: Default::default(),
214 channels: Default::default(),
215 },
216 );
217 state
218 .connections_by_user_id
219 .entry(user_id)
220 .or_default()
221 .insert(connection_id);
222 }
223
224 // Remove the given connection and its association with any worktrees.
225 async fn remove_connection(
226 self: &Arc<Server>,
227 connection_id: ConnectionId,
228 ) -> tide::Result<()> {
229 let mut worktree_ids = Vec::new();
230 let mut state = self.state.write().await;
231 if let Some(connection) = state.connections.remove(&connection_id) {
232 worktree_ids = connection.worktrees.into_iter().collect();
233
234 for channel_id in connection.channels {
235 if let Some(channel) = state.channels.get_mut(&channel_id) {
236 channel.connection_ids.remove(&connection_id);
237 }
238 }
239
240 let user_connections = state
241 .connections_by_user_id
242 .get_mut(&connection.user_id)
243 .unwrap();
244 user_connections.remove(&connection_id);
245 if user_connections.is_empty() {
246 state.connections_by_user_id.remove(&connection.user_id);
247 }
248 }
249
250 drop(state);
251 for worktree_id in worktree_ids {
252 self.close_worktree(worktree_id, connection_id).await?;
253 }
254
255 Ok(())
256 }
257
258 async fn ping(self: Arc<Server>, request: TypedEnvelope<proto::Ping>) -> tide::Result<()> {
259 self.peer.respond(request.receipt(), proto::Ack {}).await?;
260 Ok(())
261 }
262
263 async fn open_worktree(
264 self: Arc<Server>,
265 request: TypedEnvelope<proto::OpenWorktree>,
266 ) -> tide::Result<()> {
267 let receipt = request.receipt();
268 let host_user_id = self
269 .state
270 .read()
271 .await
272 .user_id_for_connection(request.sender_id)?;
273
274 let mut collaborator_user_ids = HashSet::new();
275 collaborator_user_ids.insert(host_user_id);
276 for github_login in request.payload.collaborator_logins {
277 match self.app_state.db.create_user(&github_login, false).await {
278 Ok(collaborator_user_id) => {
279 collaborator_user_ids.insert(collaborator_user_id);
280 }
281 Err(err) => {
282 let message = err.to_string();
283 self.peer
284 .respond_with_error(receipt, proto::Error { message })
285 .await?;
286 return Ok(());
287 }
288 }
289 }
290
291 let collaborator_user_ids = collaborator_user_ids.into_iter().collect::<Vec<_>>();
292 let worktree_id = self.state.write().await.add_worktree(Worktree {
293 host_connection_id: request.sender_id,
294 collaborator_user_ids: collaborator_user_ids.clone(),
295 root_name: request.payload.root_name,
296 share: None,
297 });
298
299 self.peer
300 .respond(receipt, proto::OpenWorktreeResponse { worktree_id })
301 .await?;
302 self.update_collaborators_for_users(&collaborator_user_ids)
303 .await?;
304
305 Ok(())
306 }
307
308 async fn share_worktree(
309 self: Arc<Server>,
310 mut request: TypedEnvelope<proto::ShareWorktree>,
311 ) -> tide::Result<()> {
312 let worktree = request
313 .payload
314 .worktree
315 .as_mut()
316 .ok_or_else(|| anyhow!("missing worktree"))?;
317 let entries = mem::take(&mut worktree.entries)
318 .into_iter()
319 .map(|entry| (entry.id, entry))
320 .collect();
321
322 let mut state = self.state.write().await;
323 if let Some(worktree) = state.worktrees.get_mut(&worktree.id) {
324 worktree.share = Some(WorktreeShare {
325 guest_connection_ids: Default::default(),
326 active_replica_ids: Default::default(),
327 entries,
328 });
329 let collaborator_user_ids = worktree.collaborator_user_ids.clone();
330
331 drop(state);
332 self.peer
333 .respond(request.receipt(), proto::ShareWorktreeResponse {})
334 .await?;
335 self.update_collaborators_for_users(&collaborator_user_ids)
336 .await?;
337 } else {
338 self.peer
339 .respond_with_error(
340 request.receipt(),
341 proto::Error {
342 message: "no such worktree".to_string(),
343 },
344 )
345 .await?;
346 }
347 Ok(())
348 }
349
350 async fn unshare_worktree(
351 self: Arc<Server>,
352 request: TypedEnvelope<proto::UnshareWorktree>,
353 ) -> tide::Result<()> {
354 let worktree_id = request.payload.worktree_id;
355
356 let connection_ids;
357 let collaborator_user_ids;
358 {
359 let mut state = self.state.write().await;
360 let worktree = state.write_worktree(worktree_id, request.sender_id)?;
361 if worktree.host_connection_id != request.sender_id {
362 return Err(anyhow!("no such worktree"))?;
363 }
364
365 connection_ids = worktree.connection_ids();
366 collaborator_user_ids = worktree.collaborator_user_ids.clone();
367
368 worktree.share.take();
369 for connection_id in &connection_ids {
370 if let Some(connection) = state.connections.get_mut(connection_id) {
371 connection.worktrees.remove(&worktree_id);
372 }
373 }
374 }
375
376 broadcast(request.sender_id, connection_ids, |conn_id| {
377 self.peer
378 .send(conn_id, proto::UnshareWorktree { worktree_id })
379 })
380 .await?;
381 self.update_collaborators_for_users(&collaborator_user_ids)
382 .await?;
383
384 Ok(())
385 }
386
387 async fn join_worktree(
388 self: Arc<Server>,
389 request: TypedEnvelope<proto::JoinWorktree>,
390 ) -> tide::Result<()> {
391 let worktree_id = request.payload.worktree_id;
392 let user_id = self
393 .state
394 .read()
395 .await
396 .user_id_for_connection(request.sender_id)?;
397
398 let response;
399 let connection_ids;
400 let collaborator_user_ids;
401 let mut state = self.state.write().await;
402 match state.join_worktree(request.sender_id, user_id, worktree_id) {
403 Ok((peer_replica_id, worktree)) => {
404 let share = worktree.share()?;
405 let peer_count = share.guest_connection_ids.len();
406 let mut peers = Vec::with_capacity(peer_count);
407 peers.push(proto::Peer {
408 peer_id: worktree.host_connection_id.0,
409 replica_id: 0,
410 });
411 for (peer_conn_id, peer_replica_id) in &share.guest_connection_ids {
412 if *peer_conn_id != request.sender_id {
413 peers.push(proto::Peer {
414 peer_id: peer_conn_id.0,
415 replica_id: *peer_replica_id as u32,
416 });
417 }
418 }
419 response = proto::JoinWorktreeResponse {
420 worktree: Some(proto::Worktree {
421 id: worktree_id,
422 root_name: worktree.root_name.clone(),
423 entries: share.entries.values().cloned().collect(),
424 }),
425 replica_id: peer_replica_id as u32,
426 peers,
427 };
428 connection_ids = worktree.connection_ids();
429 collaborator_user_ids = worktree.collaborator_user_ids.clone();
430 }
431 Err(error) => {
432 self.peer
433 .respond_with_error(
434 request.receipt(),
435 proto::Error {
436 message: error.to_string(),
437 },
438 )
439 .await?;
440 return Ok(());
441 }
442 }
443
444 drop(state);
445 broadcast(request.sender_id, connection_ids, |conn_id| {
446 self.peer.send(
447 conn_id,
448 proto::AddPeer {
449 worktree_id,
450 peer: Some(proto::Peer {
451 peer_id: request.sender_id.0,
452 replica_id: response.replica_id,
453 }),
454 },
455 )
456 })
457 .await?;
458 self.peer.respond(request.receipt(), response).await?;
459 self.update_collaborators_for_users(&collaborator_user_ids)
460 .await?;
461
462 Ok(())
463 }
464
465 async fn handle_close_worktree(
466 self: Arc<Server>,
467 request: TypedEnvelope<proto::CloseWorktree>,
468 ) -> tide::Result<()> {
469 self.close_worktree(request.payload.worktree_id, request.sender_id)
470 .await
471 }
472
473 async fn close_worktree(
474 self: &Arc<Server>,
475 worktree_id: u64,
476 sender_conn_id: ConnectionId,
477 ) -> tide::Result<()> {
478 let connection_ids;
479 let collaborator_user_ids;
480 let mut is_host = false;
481 let mut is_guest = false;
482 {
483 let mut state = self.state.write().await;
484 let worktree = state.write_worktree(worktree_id, sender_conn_id)?;
485 connection_ids = worktree.connection_ids();
486 collaborator_user_ids = worktree.collaborator_user_ids.clone();
487
488 if worktree.host_connection_id == sender_conn_id {
489 is_host = true;
490 state.remove_worktree(worktree_id);
491 } else {
492 let share = worktree.share_mut()?;
493 if let Some(replica_id) = share.guest_connection_ids.remove(&sender_conn_id) {
494 is_guest = true;
495 share.active_replica_ids.remove(&replica_id);
496 }
497 }
498 }
499
500 if is_host {
501 broadcast(sender_conn_id, connection_ids, |conn_id| {
502 self.peer
503 .send(conn_id, proto::UnshareWorktree { worktree_id })
504 })
505 .await?;
506 } else if is_guest {
507 broadcast(sender_conn_id, connection_ids, |conn_id| {
508 self.peer.send(
509 conn_id,
510 proto::RemovePeer {
511 worktree_id,
512 peer_id: sender_conn_id.0,
513 },
514 )
515 })
516 .await?
517 }
518 self.update_collaborators_for_users(&collaborator_user_ids)
519 .await?;
520 Ok(())
521 }
522
523 async fn update_worktree(
524 self: Arc<Server>,
525 request: TypedEnvelope<proto::UpdateWorktree>,
526 ) -> tide::Result<()> {
527 {
528 let mut state = self.state.write().await;
529 let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
530 let share = worktree.share_mut()?;
531
532 for entry_id in &request.payload.removed_entries {
533 share.entries.remove(&entry_id);
534 }
535
536 for entry in &request.payload.updated_entries {
537 share.entries.insert(entry.id, entry.clone());
538 }
539 }
540
541 self.broadcast_in_worktree(request.payload.worktree_id, &request)
542 .await?;
543 Ok(())
544 }
545
546 async fn open_buffer(
547 self: Arc<Server>,
548 request: TypedEnvelope<proto::OpenBuffer>,
549 ) -> tide::Result<()> {
550 let receipt = request.receipt();
551 let worktree_id = request.payload.worktree_id;
552 let host_connection_id = self
553 .state
554 .read()
555 .await
556 .read_worktree(worktree_id, request.sender_id)?
557 .host_connection_id;
558
559 let response = self
560 .peer
561 .forward_request(request.sender_id, host_connection_id, request.payload)
562 .await?;
563 self.peer.respond(receipt, response).await?;
564 Ok(())
565 }
566
567 async fn close_buffer(
568 self: Arc<Server>,
569 request: TypedEnvelope<proto::CloseBuffer>,
570 ) -> tide::Result<()> {
571 let host_connection_id = self
572 .state
573 .read()
574 .await
575 .read_worktree(request.payload.worktree_id, request.sender_id)?
576 .host_connection_id;
577
578 self.peer
579 .forward_send(request.sender_id, host_connection_id, request.payload)
580 .await?;
581
582 Ok(())
583 }
584
585 async fn save_buffer(
586 self: Arc<Server>,
587 request: TypedEnvelope<proto::SaveBuffer>,
588 ) -> tide::Result<()> {
589 let host;
590 let guests;
591 {
592 let state = self.state.read().await;
593 let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
594 host = worktree.host_connection_id;
595 guests = worktree
596 .share()?
597 .guest_connection_ids
598 .keys()
599 .copied()
600 .collect::<Vec<_>>();
601 }
602
603 let sender = request.sender_id;
604 let receipt = request.receipt();
605 let response = self
606 .peer
607 .forward_request(sender, host, request.payload.clone())
608 .await?;
609
610 broadcast(host, guests, |conn_id| {
611 let response = response.clone();
612 let peer = &self.peer;
613 async move {
614 if conn_id == sender {
615 peer.respond(receipt, response).await
616 } else {
617 peer.forward_send(host, conn_id, response).await
618 }
619 }
620 })
621 .await?;
622
623 Ok(())
624 }
625
626 async fn update_buffer(
627 self: Arc<Server>,
628 request: TypedEnvelope<proto::UpdateBuffer>,
629 ) -> tide::Result<()> {
630 self.broadcast_in_worktree(request.payload.worktree_id, &request)
631 .await?;
632 self.peer.respond(request.receipt(), proto::Ack {}).await?;
633 Ok(())
634 }
635
636 async fn buffer_saved(
637 self: Arc<Server>,
638 request: TypedEnvelope<proto::BufferSaved>,
639 ) -> tide::Result<()> {
640 self.broadcast_in_worktree(request.payload.worktree_id, &request)
641 .await
642 }
643
644 async fn get_channels(
645 self: Arc<Server>,
646 request: TypedEnvelope<proto::GetChannels>,
647 ) -> tide::Result<()> {
648 let user_id = self
649 .state
650 .read()
651 .await
652 .user_id_for_connection(request.sender_id)?;
653 let channels = self.app_state.db.get_accessible_channels(user_id).await?;
654 self.peer
655 .respond(
656 request.receipt(),
657 proto::GetChannelsResponse {
658 channels: channels
659 .into_iter()
660 .map(|chan| proto::Channel {
661 id: chan.id.to_proto(),
662 name: chan.name,
663 })
664 .collect(),
665 },
666 )
667 .await?;
668 Ok(())
669 }
670
671 async fn get_users(
672 self: Arc<Server>,
673 request: TypedEnvelope<proto::GetUsers>,
674 ) -> tide::Result<()> {
675 let receipt = request.receipt();
676 let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
677 let users = self
678 .app_state
679 .db
680 .get_users_by_ids(user_ids)
681 .await?
682 .into_iter()
683 .map(|user| proto::User {
684 id: user.id.to_proto(),
685 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
686 github_login: user.github_login,
687 })
688 .collect();
689 self.peer
690 .respond(receipt, proto::GetUsersResponse { users })
691 .await?;
692 Ok(())
693 }
694
695 async fn update_collaborators_for_users<'a>(
696 self: &Arc<Server>,
697 user_ids: impl IntoIterator<Item = &'a UserId>,
698 ) -> tide::Result<()> {
699 let mut send_futures = Vec::new();
700
701 let state = self.state.read().await;
702 for user_id in user_ids {
703 let mut collaborators = HashMap::new();
704 for worktree_id in state
705 .visible_worktrees_by_user_id
706 .get(&user_id)
707 .unwrap_or(&HashSet::new())
708 {
709 let worktree = &state.worktrees[worktree_id];
710
711 let mut guests = HashSet::new();
712 if let Ok(share) = worktree.share() {
713 for guest_connection_id in share.guest_connection_ids.keys() {
714 let user_id = state
715 .user_id_for_connection(*guest_connection_id)
716 .context("stale worktree guest connection")?;
717 guests.insert(user_id.to_proto());
718 }
719 }
720
721 let host_user_id = state
722 .user_id_for_connection(worktree.host_connection_id)
723 .context("stale worktree host connection")?;
724 let host =
725 collaborators
726 .entry(host_user_id)
727 .or_insert_with(|| proto::Collaborator {
728 user_id: host_user_id.to_proto(),
729 worktrees: Vec::new(),
730 });
731 host.worktrees.push(proto::WorktreeMetadata {
732 root_name: worktree.root_name.clone(),
733 is_shared: worktree.share().is_ok(),
734 participants: guests.into_iter().collect(),
735 });
736 }
737
738 let collaborators = collaborators.into_values().collect::<Vec<_>>();
739 for connection_id in state.user_connection_ids(*user_id) {
740 send_futures.push(self.peer.send(
741 connection_id,
742 proto::UpdateCollaborators {
743 collaborators: collaborators.clone(),
744 },
745 ));
746 }
747 }
748
749 drop(state);
750 futures::future::try_join_all(send_futures).await?;
751
752 Ok(())
753 }
754
755 async fn join_channel(
756 self: Arc<Self>,
757 request: TypedEnvelope<proto::JoinChannel>,
758 ) -> tide::Result<()> {
759 let user_id = self
760 .state
761 .read()
762 .await
763 .user_id_for_connection(request.sender_id)?;
764 let channel_id = ChannelId::from_proto(request.payload.channel_id);
765 if !self
766 .app_state
767 .db
768 .can_user_access_channel(user_id, channel_id)
769 .await?
770 {
771 Err(anyhow!("access denied"))?;
772 }
773
774 self.state
775 .write()
776 .await
777 .join_channel(request.sender_id, channel_id);
778 let messages = self
779 .app_state
780 .db
781 .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
782 .await?
783 .into_iter()
784 .map(|msg| proto::ChannelMessage {
785 id: msg.id.to_proto(),
786 body: msg.body,
787 timestamp: msg.sent_at.unix_timestamp() as u64,
788 sender_id: msg.sender_id.to_proto(),
789 nonce: Some(msg.nonce.as_u128().into()),
790 })
791 .collect::<Vec<_>>();
792 self.peer
793 .respond(
794 request.receipt(),
795 proto::JoinChannelResponse {
796 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
797 messages,
798 },
799 )
800 .await?;
801 Ok(())
802 }
803
804 async fn leave_channel(
805 self: Arc<Self>,
806 request: TypedEnvelope<proto::LeaveChannel>,
807 ) -> tide::Result<()> {
808 let user_id = self
809 .state
810 .read()
811 .await
812 .user_id_for_connection(request.sender_id)?;
813 let channel_id = ChannelId::from_proto(request.payload.channel_id);
814 if !self
815 .app_state
816 .db
817 .can_user_access_channel(user_id, channel_id)
818 .await?
819 {
820 Err(anyhow!("access denied"))?;
821 }
822
823 self.state
824 .write()
825 .await
826 .leave_channel(request.sender_id, channel_id);
827
828 Ok(())
829 }
830
831 async fn send_channel_message(
832 self: Arc<Self>,
833 request: TypedEnvelope<proto::SendChannelMessage>,
834 ) -> tide::Result<()> {
835 let receipt = request.receipt();
836 let channel_id = ChannelId::from_proto(request.payload.channel_id);
837 let user_id;
838 let connection_ids;
839 {
840 let state = self.state.read().await;
841 user_id = state.user_id_for_connection(request.sender_id)?;
842 if let Some(channel) = state.channels.get(&channel_id) {
843 connection_ids = channel.connection_ids();
844 } else {
845 return Ok(());
846 }
847 }
848
849 // Validate the message body.
850 let body = request.payload.body.trim().to_string();
851 if body.len() > MAX_MESSAGE_LEN {
852 self.peer
853 .respond_with_error(
854 receipt,
855 proto::Error {
856 message: "message is too long".to_string(),
857 },
858 )
859 .await?;
860 return Ok(());
861 }
862 if body.is_empty() {
863 self.peer
864 .respond_with_error(
865 receipt,
866 proto::Error {
867 message: "message can't be blank".to_string(),
868 },
869 )
870 .await?;
871 return Ok(());
872 }
873
874 let timestamp = OffsetDateTime::now_utc();
875 let nonce = if let Some(nonce) = request.payload.nonce {
876 nonce
877 } else {
878 self.peer
879 .respond_with_error(
880 receipt,
881 proto::Error {
882 message: "nonce can't be blank".to_string(),
883 },
884 )
885 .await?;
886 return Ok(());
887 };
888
889 let message_id = self
890 .app_state
891 .db
892 .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
893 .await?
894 .to_proto();
895 let message = proto::ChannelMessage {
896 sender_id: user_id.to_proto(),
897 id: message_id,
898 body,
899 timestamp: timestamp.unix_timestamp() as u64,
900 nonce: Some(nonce),
901 };
902 broadcast(request.sender_id, connection_ids, |conn_id| {
903 self.peer.send(
904 conn_id,
905 proto::ChannelMessageSent {
906 channel_id: channel_id.to_proto(),
907 message: Some(message.clone()),
908 },
909 )
910 })
911 .await?;
912 self.peer
913 .respond(
914 receipt,
915 proto::SendChannelMessageResponse {
916 message: Some(message),
917 },
918 )
919 .await?;
920 Ok(())
921 }
922
923 async fn get_channel_messages(
924 self: Arc<Self>,
925 request: TypedEnvelope<proto::GetChannelMessages>,
926 ) -> tide::Result<()> {
927 let user_id = self
928 .state
929 .read()
930 .await
931 .user_id_for_connection(request.sender_id)?;
932 let channel_id = ChannelId::from_proto(request.payload.channel_id);
933 if !self
934 .app_state
935 .db
936 .can_user_access_channel(user_id, channel_id)
937 .await?
938 {
939 Err(anyhow!("access denied"))?;
940 }
941
942 let messages = self
943 .app_state
944 .db
945 .get_channel_messages(
946 channel_id,
947 MESSAGE_COUNT_PER_PAGE,
948 Some(MessageId::from_proto(request.payload.before_message_id)),
949 )
950 .await?
951 .into_iter()
952 .map(|msg| proto::ChannelMessage {
953 id: msg.id.to_proto(),
954 body: msg.body,
955 timestamp: msg.sent_at.unix_timestamp() as u64,
956 sender_id: msg.sender_id.to_proto(),
957 nonce: Some(msg.nonce.as_u128().into()),
958 })
959 .collect::<Vec<_>>();
960 self.peer
961 .respond(
962 request.receipt(),
963 proto::GetChannelMessagesResponse {
964 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
965 messages,
966 },
967 )
968 .await?;
969 Ok(())
970 }
971
972 async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
973 &self,
974 worktree_id: u64,
975 message: &TypedEnvelope<T>,
976 ) -> tide::Result<()> {
977 let connection_ids = self
978 .state
979 .read()
980 .await
981 .read_worktree(worktree_id, message.sender_id)?
982 .connection_ids();
983
984 broadcast(message.sender_id, connection_ids, |conn_id| {
985 self.peer
986 .forward_send(message.sender_id, conn_id, message.payload.clone())
987 })
988 .await?;
989
990 Ok(())
991 }
992}
993
994pub async fn broadcast<F, T>(
995 sender_id: ConnectionId,
996 receiver_ids: Vec<ConnectionId>,
997 mut f: F,
998) -> anyhow::Result<()>
999where
1000 F: FnMut(ConnectionId) -> T,
1001 T: Future<Output = anyhow::Result<()>>,
1002{
1003 let futures = receiver_ids
1004 .into_iter()
1005 .filter(|id| *id != sender_id)
1006 .map(|id| f(id));
1007 futures::future::try_join_all(futures).await?;
1008 Ok(())
1009}
1010
1011impl ServerState {
1012 fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
1013 if let Some(connection) = self.connections.get_mut(&connection_id) {
1014 connection.channels.insert(channel_id);
1015 self.channels
1016 .entry(channel_id)
1017 .or_default()
1018 .connection_ids
1019 .insert(connection_id);
1020 }
1021 }
1022
1023 fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
1024 if let Some(connection) = self.connections.get_mut(&connection_id) {
1025 connection.channels.remove(&channel_id);
1026 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
1027 entry.get_mut().connection_ids.remove(&connection_id);
1028 if entry.get_mut().connection_ids.is_empty() {
1029 entry.remove();
1030 }
1031 }
1032 }
1033 }
1034
1035 fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
1036 Ok(self
1037 .connections
1038 .get(&connection_id)
1039 .ok_or_else(|| anyhow!("unknown connection"))?
1040 .user_id)
1041 }
1042
1043 fn user_connection_ids<'a>(
1044 &'a self,
1045 user_id: UserId,
1046 ) -> impl 'a + Iterator<Item = ConnectionId> {
1047 self.connections_by_user_id
1048 .get(&user_id)
1049 .into_iter()
1050 .flatten()
1051 .copied()
1052 }
1053
1054 // Add the given connection as a guest of the given worktree
1055 fn join_worktree(
1056 &mut self,
1057 connection_id: ConnectionId,
1058 user_id: UserId,
1059 worktree_id: u64,
1060 ) -> tide::Result<(ReplicaId, &Worktree)> {
1061 let connection = self
1062 .connections
1063 .get_mut(&connection_id)
1064 .ok_or_else(|| anyhow!("no such connection"))?;
1065 let worktree = self
1066 .worktrees
1067 .get_mut(&worktree_id)
1068 .ok_or_else(|| anyhow!("no such worktree"))?;
1069 if !worktree.collaborator_user_ids.contains(&user_id) {
1070 Err(anyhow!("no such worktree"))?;
1071 }
1072
1073 let share = worktree.share_mut()?;
1074 connection.worktrees.insert(worktree_id);
1075
1076 let mut replica_id = 1;
1077 while share.active_replica_ids.contains(&replica_id) {
1078 replica_id += 1;
1079 }
1080 share.active_replica_ids.insert(replica_id);
1081 share.guest_connection_ids.insert(connection_id, replica_id);
1082 return Ok((replica_id, worktree));
1083 }
1084
1085 fn read_worktree(
1086 &self,
1087 worktree_id: u64,
1088 connection_id: ConnectionId,
1089 ) -> tide::Result<&Worktree> {
1090 let worktree = self
1091 .worktrees
1092 .get(&worktree_id)
1093 .ok_or_else(|| anyhow!("worktree not found"))?;
1094
1095 if worktree.host_connection_id == connection_id
1096 || worktree
1097 .share()?
1098 .guest_connection_ids
1099 .contains_key(&connection_id)
1100 {
1101 Ok(worktree)
1102 } else {
1103 Err(anyhow!(
1104 "{} is not a member of worktree {}",
1105 connection_id,
1106 worktree_id
1107 ))?
1108 }
1109 }
1110
1111 fn write_worktree(
1112 &mut self,
1113 worktree_id: u64,
1114 connection_id: ConnectionId,
1115 ) -> tide::Result<&mut Worktree> {
1116 let worktree = self
1117 .worktrees
1118 .get_mut(&worktree_id)
1119 .ok_or_else(|| anyhow!("worktree not found"))?;
1120
1121 if worktree.host_connection_id == connection_id
1122 || worktree.share.as_ref().map_or(false, |share| {
1123 share.guest_connection_ids.contains_key(&connection_id)
1124 })
1125 {
1126 Ok(worktree)
1127 } else {
1128 Err(anyhow!(
1129 "{} is not a member of worktree {}",
1130 connection_id,
1131 worktree_id
1132 ))?
1133 }
1134 }
1135
1136 fn add_worktree(&mut self, worktree: Worktree) -> u64 {
1137 let worktree_id = self.next_worktree_id;
1138 for collaborator_user_id in &worktree.collaborator_user_ids {
1139 self.visible_worktrees_by_user_id
1140 .entry(*collaborator_user_id)
1141 .or_default()
1142 .insert(worktree_id);
1143 }
1144 self.next_worktree_id += 1;
1145 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
1146 connection.worktrees.insert(worktree_id);
1147 }
1148 self.worktrees.insert(worktree_id, worktree);
1149
1150 #[cfg(test)]
1151 self.check_invariants();
1152
1153 worktree_id
1154 }
1155
1156 fn remove_worktree(&mut self, worktree_id: u64) {
1157 let worktree = self.worktrees.remove(&worktree_id).unwrap();
1158 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
1159 connection.worktrees.remove(&worktree_id);
1160 }
1161 if let Some(share) = worktree.share {
1162 for connection_id in share.guest_connection_ids.keys() {
1163 if let Some(connection) = self.connections.get_mut(connection_id) {
1164 connection.worktrees.remove(&worktree_id);
1165 }
1166 }
1167 }
1168 for collaborator_user_id in worktree.collaborator_user_ids {
1169 if let Some(visible_worktrees) = self
1170 .visible_worktrees_by_user_id
1171 .get_mut(&collaborator_user_id)
1172 {
1173 visible_worktrees.remove(&worktree_id);
1174 }
1175 }
1176
1177 #[cfg(test)]
1178 self.check_invariants();
1179 }
1180
    /// Test-only consistency check: panics if the indexes in this state
    /// disagree with each other.
    #[cfg(test)]
    fn check_invariants(&self) {
        // Everything a connection refers to must exist and refer back to it.
        for (connection_id, connection) in &self.connections {
            for worktree_id in &connection.worktrees {
                let worktree = &self.worktrees.get(&worktree_id).unwrap();
                // A connection that lists a worktree is either its host or
                // one of its guests.
                if worktree.host_connection_id != *connection_id {
                    assert!(worktree
                        .share()
                        .unwrap()
                        .guest_connection_ids
                        .contains_key(connection_id));
                }
            }
            for channel_id in &connection.channels {
                let channel = self.channels.get(channel_id).unwrap();
                assert!(channel.connection_ids.contains(connection_id));
            }
            assert!(self
                .connections_by_user_id
                .get(&connection.user_id)
                .unwrap()
                .contains(connection_id));
        }

        // The per-user index only lists connections owned by that user.
        for (user_id, connection_ids) in &self.connections_by_user_id {
            for connection_id in connection_ids {
                assert_eq!(
                    self.connections.get(connection_id).unwrap().user_id,
                    *user_id
                );
            }
        }

        for (worktree_id, worktree) in &self.worktrees {
            // The host connection exists and lists the worktree.
            let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
            assert!(host_connection.worktrees.contains(worktree_id));

            // Every collaborator can see the worktree.
            for collaborator_id in &worktree.collaborator_user_ids {
                let visible_worktree_ids = self
                    .visible_worktrees_by_user_id
                    .get(collaborator_id)
                    .unwrap();
                assert!(visible_worktree_ids.contains(worktree_id));
            }

            if let Some(share) = &worktree.share {
                // Every guest connection exists and lists the worktree.
                for guest_connection_id in share.guest_connection_ids.keys() {
                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
                    assert!(guest_connection.worktrees.contains(worktree_id));
                }
                // The active replica ids are exactly the ids assigned to
                // guests, with no duplicates.
                assert_eq!(
                    share.active_replica_ids.len(),
                    share.guest_connection_ids.len(),
                );
                assert_eq!(
                    share.active_replica_ids,
                    share
                        .guest_connection_ids
                        .values()
                        .copied()
                        .collect::<HashSet<_>>(),
                );
            }
        }

        // A worktree is only visible to users who are its collaborators.
        for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
            for worktree_id in visible_worktree_ids {
                let worktree = self.worktrees.get(worktree_id).unwrap();
                assert!(worktree.collaborator_user_ids.contains(user_id));
            }
        }

        // Every connection listed by a channel exists and lists that channel.
        for (channel_id, channel) in &self.channels {
            for connection_id in &channel.connection_ids {
                let connection = self.connections.get(connection_id).unwrap();
                assert!(connection.channels.contains(channel_id));
            }
        }
    }
1260}
1261
1262impl Worktree {
1263 pub fn connection_ids(&self) -> Vec<ConnectionId> {
1264 if let Some(share) = &self.share {
1265 share
1266 .guest_connection_ids
1267 .keys()
1268 .copied()
1269 .chain(Some(self.host_connection_id))
1270 .collect()
1271 } else {
1272 vec![self.host_connection_id]
1273 }
1274 }
1275
1276 fn share(&self) -> tide::Result<&WorktreeShare> {
1277 Ok(self
1278 .share
1279 .as_ref()
1280 .ok_or_else(|| anyhow!("worktree is not shared"))?)
1281 }
1282
1283 fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
1284 Ok(self
1285 .share
1286 .as_mut()
1287 .ok_or_else(|| anyhow!("worktree is not shared"))?)
1288 }
1289}
1290
1291impl Channel {
1292 fn connection_ids(&self) -> Vec<ConnectionId> {
1293 self.connection_ids.iter().copied().collect()
1294 }
1295}
1296
1297pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
1298 let server = Server::new(app.state().clone(), rpc.clone(), None);
1299 app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
1300 let user_id = request.ext::<UserId>().copied();
1301 let server = server.clone();
1302 async move {
1303 const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
1304
1305 let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
1306 let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
1307 let upgrade_requested = connection_upgrade && upgrade_to_websocket;
1308
1309 if !upgrade_requested {
1310 return Ok(Response::new(StatusCode::UpgradeRequired));
1311 }
1312
1313 let header = match request.header("Sec-Websocket-Key") {
1314 Some(h) => h.as_str(),
1315 None => return Err(anyhow!("expected sec-websocket-key"))?,
1316 };
1317
1318 let mut response = Response::new(StatusCode::SwitchingProtocols);
1319 response.insert_header(UPGRADE, "websocket");
1320 response.insert_header(CONNECTION, "Upgrade");
1321 let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
1322 response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
1323 response.insert_header("Sec-Websocket-Version", "13");
1324
1325 let http_res: &mut tide::http::Response = response.as_mut();
1326 let upgrade_receiver = http_res.recv_upgrade().await;
1327 let addr = request.remote().unwrap_or("unknown").to_string();
1328 let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
1329 task::spawn(async move {
1330 if let Some(stream) = upgrade_receiver.await {
1331 server.handle_connection(Connection::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
1332 }
1333 });
1334
1335 Ok(response)
1336 }
1337 });
1338}
1339
1340fn header_contains_ignore_case<T>(
1341 request: &tide::Request<T>,
1342 header_name: HeaderName,
1343 value: &str,
1344) -> bool {
1345 request
1346 .header(header_name)
1347 .map(|h| {
1348 h.as_str()
1349 .split(',')
1350 .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
1351 })
1352 .unwrap_or(false)
1353}
1354
1355#[cfg(test)]
1356mod tests {
1357 use super::*;
1358 use crate::{
1359 auth,
1360 db::{tests::TestDb, UserId},
1361 github, AppState, Config,
1362 };
1363 use async_std::{sync::RwLockReadGuard, task};
1364 use gpui::{ModelHandle, TestAppContext};
1365 use parking_lot::Mutex;
1366 use postage::{mpsc, watch};
1367 use serde_json::json;
1368 use sqlx::types::time::OffsetDateTime;
1369 use std::{
1370 path::Path,
1371 sync::{
1372 atomic::{AtomicBool, Ordering::SeqCst},
1373 Arc,
1374 },
1375 time::Duration,
1376 };
1377 use zed::{
1378 channel::{Channel, ChannelDetails, ChannelList},
1379 editor::{Editor, Insert},
1380 fs::{FakeFs, Fs as _},
1381 language::LanguageRegistry,
1382 rpc::{self, Client, Credentials, EstablishConnectionError},
1383 settings,
1384 test::FakeHttpClient,
1385 user::UserStore,
1386 worktree::Worktree,
1387 };
1388 use zrpc::Peer;
1389
// End-to-end test of sharing a worktree between two clients: client A shares
// a local worktree, client B joins it, and buffer contents, selection sets,
// edits, buffer closing and guest departure are observed on the host side.
#[gpui::test]
async fn test_share_worktree(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
    let (window_b, _) = cx_b.add_window(|_| EmptyView);
    let settings = cx_b.read(settings::test).1;
    let lang_registry = Arc::new(LanguageRegistry::new());

    // Connect to a server as 2 clients.
    let mut server = TestServer::start().await;
    let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
    let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;

    cx_a.foreground().forbid_parking();

    // Share a local worktree as client A
    let fs = Arc::new(FakeFs::new());
    fs.insert_tree(
        "/a",
        json!({
            ".zed.toml": r#"collaborators = ["user_b"]"#,
            "a.txt": "a-contents",
            "b.txt": "b-contents",
        }),
    )
    .await;
    let worktree_a = Worktree::open_local(
        client_a.clone(),
        "/a".as_ref(),
        fs,
        lang_registry.clone(),
        &mut cx_a.to_async(),
    )
    .await
    .unwrap();
    // Wait for the initial scan to finish before sharing.
    worktree_a
        .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
        .await;
    let worktree_id = worktree_a
        .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
        .await
        .unwrap();

    // Join that worktree as client B, and see that a guest has joined as client A.
    let worktree_b = Worktree::open_remote(
        client_b.clone(),
        worktree_id,
        lang_registry.clone(),
        &mut cx_b.to_async(),
    )
    .await
    .unwrap();
    let replica_id_b = worktree_b.read_with(&cx_b, |tree, _| tree.replica_id());
    worktree_a
        .condition(&cx_a, |tree, _| {
            tree.peers()
                .values()
                .any(|replica_id| *replica_id == replica_id_b)
        })
        .await;

    // Open the same file as client B and client A.
    let buffer_b = worktree_b
        .update(&mut cx_b, |worktree, cx| worktree.open_buffer("b.txt", cx))
        .await
        .unwrap();
    buffer_b.read_with(&cx_b, |buf, _| assert_eq!(buf.text(), "b-contents"));
    // The host reports the buffer as open once the guest has opened it.
    worktree_a.read_with(&cx_a, |tree, cx| assert!(tree.has_open_buffer("b.txt", cx)));
    let buffer_a = worktree_a
        .update(&mut cx_a, |tree, cx| tree.open_buffer("b.txt", cx))
        .await
        .unwrap();

    // Create a selection set as client B and see that selection set as client A.
    let editor_b = cx_b.add_view(window_b, |cx| Editor::for_buffer(buffer_b, settings, cx));
    buffer_a
        .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 1)
        .await;

    // Edit the buffer as client B and see that edit as client A.
    editor_b.update(&mut cx_b, |editor, cx| {
        editor.insert(&Insert("ok, ".into()), cx)
    });
    buffer_a
        .condition(&cx_a, |buffer, _| buffer.text() == "ok, b-contents")
        .await;

    // Remove the selection set as client B, see those selections disappear as client A.
    cx_b.update(move |_| drop(editor_b));
    buffer_a
        .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 0)
        .await;

    // Close the buffer as client A, see that the buffer is closed.
    cx_a.update(move |_| drop(buffer_a));
    worktree_a
        .condition(&cx_a, |tree, cx| !tree.has_open_buffer("b.txt", cx))
        .await;

    // Dropping the worktree removes client B from client A's peers.
    cx_b.update(move |_| drop(worktree_b));
    worktree_a
        .condition(&cx_a, |tree, _| tree.peers().is_empty())
        .await;
}
1493
// Checks that, in a worktree shared by host A with guests B and C, buffer
// edits reach every participant, a save issued by a guest writes the host's
// file, and changes on the host's file system show up in the guests' trees.
#[gpui::test]
async fn test_propagate_saves_and_fs_changes_in_shared_worktree(
    mut cx_a: TestAppContext,
    mut cx_b: TestAppContext,
    mut cx_c: TestAppContext,
) {
    cx_a.foreground().forbid_parking();
    let lang_registry = Arc::new(LanguageRegistry::new());

    // Connect to a server as 3 clients.
    let mut server = TestServer::start().await;
    let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
    let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
    let (client_c, _) = server.create_client(&mut cx_c, "user_c").await;

    let fs = Arc::new(FakeFs::new());

    // Share a worktree as client A.
    fs.insert_tree(
        "/a",
        json!({
            ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
            "file1": "",
            "file2": ""
        }),
    )
    .await;

    let worktree_a = Worktree::open_local(
        client_a.clone(),
        "/a".as_ref(),
        fs.clone(),
        lang_registry.clone(),
        &mut cx_a.to_async(),
    )
    .await
    .unwrap();
    // Wait for the initial scan to finish before sharing.
    worktree_a
        .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
        .await;
    let worktree_id = worktree_a
        .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
        .await
        .unwrap();

    // Join that worktree as clients B and C.
    let worktree_b = Worktree::open_remote(
        client_b.clone(),
        worktree_id,
        lang_registry.clone(),
        &mut cx_b.to_async(),
    )
    .await
    .unwrap();
    let worktree_c = Worktree::open_remote(
        client_c.clone(),
        worktree_id,
        lang_registry.clone(),
        &mut cx_c.to_async(),
    )
    .await
    .unwrap();

    // Open and edit a buffer as both guests B and C.
    let buffer_b = worktree_b
        .update(&mut cx_b, |tree, cx| tree.open_buffer("file1", cx))
        .await
        .unwrap();
    let buffer_c = worktree_c
        .update(&mut cx_c, |tree, cx| tree.open_buffer("file1", cx))
        .await
        .unwrap();
    buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "i-am-b, ", cx));
    buffer_c.update(&mut cx_c, |buf, cx| buf.edit([0..0], "i-am-c, ", cx));

    // Open and edit that buffer as the host.
    let buffer_a = worktree_a
        .update(&mut cx_a, |tree, cx| tree.open_buffer("file1", cx))
        .await
        .unwrap();

    // Both guest edits must have arrived before the host appends its own text.
    buffer_a
        .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, ")
        .await;
    buffer_a.update(&mut cx_a, |buf, cx| {
        buf.edit([buf.len()..buf.len()], "i-am-a", cx)
    });

    // Wait for edits to propagate
    buffer_a
        .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
        .await;
    buffer_b
        .condition(&mut cx_b, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
        .await;
    buffer_c
        .condition(&mut cx_c, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
        .await;

    // Edit the buffer as the host and concurrently save as guest B.
    let save_b = buffer_b.update(&mut cx_b, |buf, cx| buf.save(cx).unwrap());
    buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "hi-a, ", cx));
    save_b.await.unwrap();
    // The file on disk is expected to contain the host's concurrent edit too.
    assert_eq!(
        fs.load("/a/file1".as_ref()).await.unwrap(),
        "hi-a, i-am-c, i-am-b, i-am-a"
    );
    buffer_a.read_with(&cx_a, |buf, _| assert!(!buf.is_dirty()));
    buffer_b.read_with(&cx_b, |buf, _| assert!(!buf.is_dirty()));
    buffer_c.condition(&cx_c, |buf, _| !buf.is_dirty()).await;

    // Make changes on host's file system, see those changes on the guests.
    fs.rename("/a/file2".as_ref(), "/a/file3".as_ref())
        .await
        .unwrap();
    fs.insert_file(Path::new("/a/file4"), "4".into())
        .await
        .unwrap();

    worktree_b
        .condition(&cx_b, |tree, _| tree.file_count() == 4)
        .await;
    worktree_c
        .condition(&cx_c, |tree, _| tree.file_count() == 4)
        .await;
    worktree_b.read_with(&cx_b, |tree, _| {
        assert_eq!(
            tree.paths()
                .map(|p| p.to_string_lossy())
                .collect::<Vec<_>>(),
            &[".zed.toml", "file1", "file3", "file4"]
        )
    });
    worktree_c.read_with(&cx_c, |tree, _| {
        assert_eq!(
            tree.paths()
                .map(|p| p.to_string_lossy())
                .collect::<Vec<_>>(),
            &[".zed.toml", "file1", "file3", "file4"]
        )
    });
}
1636
// Checks that a guest's buffer is neither dirty nor in conflict after it has
// been saved, and that a later edit makes it dirty again without reporting a
// conflict.
#[gpui::test]
async fn test_buffer_conflict_after_save(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
    cx_a.foreground().forbid_parking();
    let lang_registry = Arc::new(LanguageRegistry::new());

    // Connect to a server as 2 clients.
    let mut server = TestServer::start().await;
    let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
    let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;

    // Share a local worktree as client A
    // NOTE(review): "user_c" is listed as a collaborator although this test
    // never creates a client for it — confirm this is intentional.
    let fs = Arc::new(FakeFs::new());
    fs.insert_tree(
        "/dir",
        json!({
            ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
            "a.txt": "a-contents",
        }),
    )
    .await;

    let worktree_a = Worktree::open_local(
        client_a.clone(),
        "/dir".as_ref(),
        fs,
        lang_registry.clone(),
        &mut cx_a.to_async(),
    )
    .await
    .unwrap();
    // Wait for the initial scan to finish before sharing.
    worktree_a
        .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
        .await;
    let worktree_id = worktree_a
        .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
        .await
        .unwrap();

    // Join that worktree as client B.
    let worktree_b = Worktree::open_remote(
        client_b.clone(),
        worktree_id,
        lang_registry.clone(),
        &mut cx_b.to_async(),
    )
    .await
    .unwrap();

    let buffer_b = worktree_b
        .update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx))
        .await
        .unwrap();
    // Remember the file's mtime so the effect of the save can be detected below.
    let mtime = buffer_b.read_with(&cx_b, |buf, _| buf.file().unwrap().mtime);

    // An unsaved edit makes the buffer dirty but not conflicted.
    buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "world ", cx));
    buffer_b.read_with(&cx_b, |buf, _| {
        assert!(buf.is_dirty());
        assert!(!buf.has_conflict());
    });

    // Save as the guest and wait until the buffer's file reports a new mtime.
    buffer_b
        .update(&mut cx_b, |buf, cx| buf.save(cx))
        .unwrap()
        .await
        .unwrap();
    worktree_b
        .condition(&cx_b, |_, cx| {
            buffer_b.read(cx).file().unwrap().mtime != mtime
        })
        .await;
    buffer_b.read_with(&cx_b, |buf, _| {
        assert!(!buf.is_dirty());
        assert!(!buf.has_conflict());
    });

    // Editing again after the save must not be reported as a conflict.
    buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "hello ", cx));
    buffer_b.read_with(&cx_b, |buf, _| {
        assert!(buf.is_dirty());
        assert!(!buf.has_conflict());
    });
}
1718
// Checks that an edit made by the host while a guest is still in the process
// of opening the same buffer ends up in the guest's copy of that buffer.
#[gpui::test]
async fn test_editing_while_guest_opens_buffer(
    mut cx_a: TestAppContext,
    mut cx_b: TestAppContext,
) {
    cx_a.foreground().forbid_parking();
    let lang_registry = Arc::new(LanguageRegistry::new());

    // Connect to a server as 2 clients.
    let mut server = TestServer::start().await;
    let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
    let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;

    // Share a local worktree as client A
    let fs = Arc::new(FakeFs::new());
    fs.insert_tree(
        "/dir",
        json!({
            ".zed.toml": r#"collaborators = ["user_b"]"#,
            "a.txt": "a-contents",
        }),
    )
    .await;
    let worktree_a = Worktree::open_local(
        client_a.clone(),
        "/dir".as_ref(),
        fs,
        lang_registry.clone(),
        &mut cx_a.to_async(),
    )
    .await
    .unwrap();
    // Wait for the initial scan to finish before sharing.
    worktree_a
        .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
        .await;
    let worktree_id = worktree_a
        .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
        .await
        .unwrap();

    // Join that worktree as client B.
    let worktree_b = Worktree::open_remote(
        client_b.clone(),
        worktree_id,
        lang_registry.clone(),
        &mut cx_b.to_async(),
    )
    .await
    .unwrap();

    let buffer_a = worktree_a
        .update(&mut cx_a, |tree, cx| tree.open_buffer("a.txt", cx))
        .await
        .unwrap();
    // Start opening the buffer as the guest in the background, without
    // awaiting it yet.
    let buffer_b = cx_b
        .background()
        .spawn(worktree_b.update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx)));

    // Yield once, then edit as the host while the guest's open is in flight.
    task::yield_now().await;
    buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "z", cx));

    // The guest's buffer must converge on the host's text, including the edit.
    let text = buffer_a.read_with(&cx_a, |buf, _| buf.text());
    let buffer_b = buffer_b.await.unwrap();
    buffer_b.condition(&cx_b, |buf, _| buf.text() == text).await;
}
1784
// Checks that when a guest's connection is dropped, the host observes the
// guest disappearing from the shared worktree's peers.
#[gpui::test]
async fn test_peer_disconnection(mut cx_a: TestAppContext, cx_b: TestAppContext) {
    cx_a.foreground().forbid_parking();
    let lang_registry = Arc::new(LanguageRegistry::new());

    // Connect to a server as 2 clients.
    // NOTE(review): client B is created with `cx_a`, while every later use of
    // client B goes through `cx_b` — confirm whether `cx_b` was intended here.
    let mut server = TestServer::start().await;
    let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
    let (client_b, _) = server.create_client(&mut cx_a, "user_b").await;

    // Share a local worktree as client A
    let fs = Arc::new(FakeFs::new());
    fs.insert_tree(
        "/a",
        json!({
            ".zed.toml": r#"collaborators = ["user_b"]"#,
            "a.txt": "a-contents",
            "b.txt": "b-contents",
        }),
    )
    .await;
    let worktree_a = Worktree::open_local(
        client_a.clone(),
        "/a".as_ref(),
        fs,
        lang_registry.clone(),
        &mut cx_a.to_async(),
    )
    .await
    .unwrap();
    // Wait for the initial scan to finish before sharing.
    worktree_a
        .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
        .await;
    let worktree_id = worktree_a
        .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
        .await
        .unwrap();

    // Join that worktree as client B, and see that a guest has joined as client A.
    // The handle is kept alive (but otherwise unused) until the end of the test.
    let _worktree_b = Worktree::open_remote(
        client_b.clone(),
        worktree_id,
        lang_registry.clone(),
        &mut cx_b.to_async(),
    )
    .await
    .unwrap();
    worktree_a
        .condition(&cx_a, |tree, _| tree.peers().len() == 1)
        .await;

    // Drop client B's connection and ensure client A observes client B leaving the worktree.
    client_b.disconnect(&cx_b.to_async()).await.unwrap();
    worktree_a
        .condition(&cx_a, |tree, _| tree.peers().len() == 0)
        .await;
}
1842
// Exercises the basic chat flow between two clients: both see the channel and
// its existing message, messages sent by A reach B, and the server's channel
// state shrinks and finally disappears as the clients drop their channels.
#[gpui::test]
async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
    cx_a.foreground().forbid_parking();

    // Connect to a server as 2 clients.
    let mut server = TestServer::start().await;
    let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
    let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;

    // Create an org that includes these 2 users.
    let db = &server.app_state.db;
    let org_id = db.create_org("Test Org", "test-org").await.unwrap();
    db.add_org_member(org_id, current_user_id(&user_store_a, &cx_a), false)
        .await
        .unwrap();
    db.add_org_member(org_id, current_user_id(&user_store_b, &cx_b), false)
        .await
        .unwrap();

    // Create a channel that includes all the users.
    let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
    db.add_channel_member(channel_id, current_user_id(&user_store_a, &cx_a), false)
        .await
        .unwrap();
    db.add_channel_member(channel_id, current_user_id(&user_store_b, &cx_b), false)
        .await
        .unwrap();
    // Seed the channel with one message from user B, written directly to the database.
    db.create_channel_message(
        channel_id,
        current_user_id(&user_store_b, &cx_b),
        "hello A, it's B.",
        OffsetDateTime::now_utc(),
        1,
    )
    .await
    .unwrap();

    // Client A: wait for the channel list, then open the channel and wait for
    // the seeded message to appear.
    let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
    channels_a
        .condition(&mut cx_a, |list, _| list.available_channels().is_some())
        .await;
    channels_a.read_with(&cx_a, |list, _| {
        assert_eq!(
            list.available_channels().unwrap(),
            &[ChannelDetails {
                id: channel_id.to_proto(),
                name: "test-channel".to_string()
            }]
        )
    });
    let channel_a = channels_a.update(&mut cx_a, |this, cx| {
        this.get_channel(channel_id.to_proto(), cx).unwrap()
    });
    // The channel starts out empty and is populated afterwards.
    // NOTE(review): the trailing bool in the `channel_messages` tuples
    // presumably marks a message that is still pending — confirm against the
    // helper's definition.
    channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
    channel_a
        .condition(&cx_a, |channel, _| {
            channel_messages(channel)
                == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
        })
        .await;

    // Client B: same sequence as client A.
    let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
    channels_b
        .condition(&mut cx_b, |list, _| list.available_channels().is_some())
        .await;
    channels_b.read_with(&cx_b, |list, _| {
        assert_eq!(
            list.available_channels().unwrap(),
            &[ChannelDetails {
                id: channel_id.to_proto(),
                name: "test-channel".to_string()
            }]
        )
    });

    let channel_b = channels_b.update(&mut cx_b, |this, cx| {
        this.get_channel(channel_id.to_proto(), cx).unwrap()
    });
    channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
    channel_b
        .condition(&cx_b, |channel, _| {
            channel_messages(channel)
                == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
        })
        .await;

    // Send two messages as client A. Both show up in A's channel immediately
    // (with the flag set) before the second send is awaited.
    channel_a
        .update(&mut cx_a, |channel, cx| {
            channel
                .send_message("oh, hi B.".to_string(), cx)
                .unwrap()
                .detach();
            let task = channel.send_message("sup".to_string(), cx).unwrap();
            assert_eq!(
                channel_messages(channel),
                &[
                    ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                    ("user_a".to_string(), "oh, hi B.".to_string(), true),
                    ("user_a".to_string(), "sup".to_string(), true)
                ]
            );
            task
        })
        .await
        .unwrap();

    // Client B eventually sees both of A's messages, in order.
    channel_b
        .condition(&cx_b, |channel, _| {
            channel_messages(channel)
                == [
                    ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                    ("user_a".to_string(), "oh, hi B.".to_string(), false),
                    ("user_a".to_string(), "sup".to_string(), false),
                ]
        })
        .await;

    // The server tracks one connection per client on the channel; dropping a
    // client's channel removes its connection, and the channel entry itself
    // goes away once the last one is dropped.
    assert_eq!(
        server.state().await.channels[&channel_id]
            .connection_ids
            .len(),
        2
    );
    cx_b.update(|_| drop(channel_b));
    server
        .condition(|state| state.channels[&channel_id].connection_ids.len() == 1)
        .await;

    cx_a.update(|_| drop(channel_a));
    server
        .condition(|state| !state.channels.contains_key(&channel_id))
        .await;
}
1976
// Checks the validation applied to chat messages: overly long bodies and
// blank bodies are rejected, and surrounding whitespace is trimmed from the
// body that gets stored.
#[gpui::test]
async fn test_chat_message_validation(mut cx_a: TestAppContext) {
    cx_a.foreground().forbid_parking();

    let mut server = TestServer::start().await;
    let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;

    // Create an org and a channel, and make user A a member of both.
    let db = &server.app_state.db;
    let org_id = db.create_org("Test Org", "test-org").await.unwrap();
    let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
    db.add_org_member(org_id, current_user_id(&user_store_a, &cx_a), false)
        .await
        .unwrap();
    db.add_channel_member(channel_id, current_user_id(&user_store_a, &cx_a), false)
        .await
        .unwrap();

    let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
    channels_a
        .condition(&mut cx_a, |list, _| list.available_channels().is_some())
        .await;
    let channel_a = channels_a.update(&mut cx_a, |this, cx| {
        this.get_channel(channel_id.to_proto(), cx).unwrap()
    });

    // Messages aren't allowed to be too long.
    // Here `send_message` itself succeeds and the returned task fails.
    channel_a
        .update(&mut cx_a, |channel, cx| {
            let long_body = "this is long.\n".repeat(1024);
            channel.send_message(long_body, cx).unwrap()
        })
        .await
        .unwrap_err();

    // Messages aren't allowed to be blank.
    // Here `send_message` fails right away, before any task is returned.
    channel_a.update(&mut cx_a, |channel, cx| {
        channel.send_message(String::new(), cx).unwrap_err()
    });

    // Leading and trailing whitespace are trimmed.
    channel_a
        .update(&mut cx_a, |channel, cx| {
            channel
                .send_message("\n surrounded by whitespace  \n".to_string(), cx)
                .unwrap()
        })
        .await
        .unwrap();
    // Only the trimmed message was stored; the rejected ones left no trace.
    assert_eq!(
        db.get_channel_messages(channel_id, 10, None)
            .await
            .unwrap()
            .iter()
            .map(|m| &m.body)
            .collect::<Vec<_>>(),
        &["surrounded by whitespace"]
    );
}
2035
// Checks chat behavior across a disconnection of client B: cached channel
// data stays readable while offline, and after reconnecting B sees the
// messages sent in the meantime (including its own offline message) and can
// exchange messages with A again.
#[gpui::test]
async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
    cx_a.foreground().forbid_parking();

    // Connect to a server as 2 clients.
    let mut server = TestServer::start().await;
    let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
    let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
    // Status stream of client B, used below to detect the failed reconnection.
    let mut status_b = client_b.status();

    // Create an org that includes these 2 users.
    let db = &server.app_state.db;
    let org_id = db.create_org("Test Org", "test-org").await.unwrap();
    db.add_org_member(org_id, current_user_id(&user_store_a, &cx_a), false)
        .await
        .unwrap();
    db.add_org_member(org_id, current_user_id(&user_store_b, &cx_b), false)
        .await
        .unwrap();

    // Create a channel that includes all the users.
    let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
    db.add_channel_member(channel_id, current_user_id(&user_store_a, &cx_a), false)
        .await
        .unwrap();
    db.add_channel_member(channel_id, current_user_id(&user_store_b, &cx_b), false)
        .await
        .unwrap();
    // Seed the channel with one message from user B, written directly to the database.
    db.create_channel_message(
        channel_id,
        current_user_id(&user_store_b, &cx_b),
        "hello A, it's B.",
        OffsetDateTime::now_utc(),
        2,
    )
    .await
    .unwrap();

    // Client A: wait for the channel list, then open the channel and wait for
    // the seeded message to appear.
    let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
    channels_a
        .condition(&mut cx_a, |list, _| list.available_channels().is_some())
        .await;

    channels_a.read_with(&cx_a, |list, _| {
        assert_eq!(
            list.available_channels().unwrap(),
            &[ChannelDetails {
                id: channel_id.to_proto(),
                name: "test-channel".to_string()
            }]
        )
    });
    let channel_a = channels_a.update(&mut cx_a, |this, cx| {
        this.get_channel(channel_id.to_proto(), cx).unwrap()
    });
    // NOTE(review): the trailing bool in the `channel_messages` tuples
    // presumably marks a message that is still pending — confirm against the
    // helper's definition.
    channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
    channel_a
        .condition(&cx_a, |channel, _| {
            channel_messages(channel)
                == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
        })
        .await;

    // Client B: same sequence as client A.
    let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b.clone(), client_b, cx));
    channels_b
        .condition(&mut cx_b, |list, _| list.available_channels().is_some())
        .await;
    channels_b.read_with(&cx_b, |list, _| {
        assert_eq!(
            list.available_channels().unwrap(),
            &[ChannelDetails {
                id: channel_id.to_proto(),
                name: "test-channel".to_string()
            }]
        )
    });

    let channel_b = channels_b.update(&mut cx_b, |this, cx| {
        this.get_channel(channel_id.to_proto(), cx).unwrap()
    });
    channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
    channel_b
        .condition(&cx_b, |channel, _| {
            channel_messages(channel)
                == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
        })
        .await;

    // Disconnect client B, ensuring we can still access its cached channel data.
    // New connections are forbidden first, and the loop waits until client B
    // reports a reconnection error.
    server.forbid_connections();
    server.disconnect_client(current_user_id(&user_store_b, &cx_b));
    while !matches!(
        status_b.recv().await,
        Some(rpc::Status::ReconnectionError { .. })
    ) {}

    channels_b.read_with(&cx_b, |channels, _| {
        assert_eq!(
            channels.available_channels().unwrap(),
            [ChannelDetails {
                id: channel_id.to_proto(),
                name: "test-channel".to_string()
            }]
        )
    });
    channel_b.read_with(&cx_b, |channel, _| {
        assert_eq!(
            channel_messages(channel),
            [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
        )
    });

    // Send a message from client B while it is disconnected.
    // The message is listed locally right away, but the send task fails.
    channel_b
        .update(&mut cx_b, |channel, cx| {
            let task = channel
                .send_message("can you see this?".to_string(), cx)
                .unwrap();
            assert_eq!(
                channel_messages(channel),
                &[
                    ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                    ("user_b".to_string(), "can you see this?".to_string(), true)
                ]
            );
            task
        })
        .await
        .unwrap_err();

    // Send a message from client A while B is disconnected.
    channel_a
        .update(&mut cx_a, |channel, cx| {
            channel
                .send_message("oh, hi B.".to_string(), cx)
                .unwrap()
                .detach();
            let task = channel.send_message("sup".to_string(), cx).unwrap();
            assert_eq!(
                channel_messages(channel),
                &[
                    ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                    ("user_a".to_string(), "oh, hi B.".to_string(), true),
                    ("user_a".to_string(), "sup".to_string(), true)
                ]
            );
            task
        })
        .await
        .unwrap();

    // Give client B a chance to reconnect.
    server.allow_connections();
    cx_b.foreground().advance_clock(Duration::from_secs(10));

    // Verify that B sees the new messages upon reconnection, as well as the message client B
    // sent while offline.
    channel_b
        .condition(&cx_b, |channel, _| {
            channel_messages(channel)
                == [
                    ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                    ("user_a".to_string(), "oh, hi B.".to_string(), false),
                    ("user_a".to_string(), "sup".to_string(), false),
                    ("user_b".to_string(), "can you see this?".to_string(), false),
                ]
        })
        .await;

    // Ensure client A and B can communicate normally after reconnection.
    channel_a
        .update(&mut cx_a, |channel, cx| {
            channel.send_message("you online?".to_string(), cx).unwrap()
        })
        .await
        .unwrap();
    channel_b
        .condition(&cx_b, |channel, _| {
            channel_messages(channel)
                == [
                    ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                    ("user_a".to_string(), "oh, hi B.".to_string(), false),
                    ("user_a".to_string(), "sup".to_string(), false),
                    ("user_b".to_string(), "can you see this?".to_string(), false),
                    ("user_a".to_string(), "you online?".to_string(), false),
                ]
        })
        .await;

    channel_b
        .update(&mut cx_b, |channel, cx| {
            channel.send_message("yep".to_string(), cx).unwrap()
        })
        .await
        .unwrap();
    channel_a
        .condition(&cx_a, |channel, _| {
            channel_messages(channel)
                == [
                    ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                    ("user_a".to_string(), "oh, hi B.".to_string(), false),
                    ("user_a".to_string(), "sup".to_string(), false),
                    ("user_b".to_string(), "can you see this?".to_string(), false),
                    ("user_a".to_string(), "you online?".to_string(), false),
                    ("user_b".to_string(), "yep".to_string(), false),
                ]
        })
        .await;
}
2245
2246 #[gpui::test]
2247 async fn test_collaborators(
2248 mut cx_a: TestAppContext,
2249 mut cx_b: TestAppContext,
2250 mut cx_c: TestAppContext,
2251 ) {
2252 cx_a.foreground().forbid_parking();
2253 let lang_registry = Arc::new(LanguageRegistry::new());
2254
2255 // Connect to a server as 3 clients.
2256 let mut server = TestServer::start().await;
2257 let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
2258 let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
2259 let (_client_c, user_store_c) = server.create_client(&mut cx_c, "user_c").await;
2260
2261 let fs = Arc::new(FakeFs::new());
2262
2263 // Share a worktree as client A.
2264 fs.insert_tree(
2265 "/a",
2266 json!({
2267 ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
2268 }),
2269 )
2270 .await;
2271
2272 let worktree_a = Worktree::open_local(
2273 client_a.clone(),
2274 "/a".as_ref(),
2275 fs.clone(),
2276 lang_registry.clone(),
2277 &mut cx_a.to_async(),
2278 )
2279 .await
2280 .unwrap();
2281
2282 user_store_a
2283 .condition(&cx_a, |user_store, _| {
2284 collaborators(user_store) == vec![("user_a", vec![("a", vec![])])]
2285 })
2286 .await;
2287 user_store_b
2288 .condition(&cx_b, |user_store, _| {
2289 collaborators(user_store) == vec![("user_a", vec![("a", vec![])])]
2290 })
2291 .await;
2292 user_store_c
2293 .condition(&cx_c, |user_store, _| {
2294 collaborators(user_store) == vec![("user_a", vec![("a", vec![])])]
2295 })
2296 .await;
2297
2298 let worktree_id = worktree_a
2299 .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
2300 .await
2301 .unwrap();
2302
2303 let _worktree_b = Worktree::open_remote(
2304 client_b.clone(),
2305 worktree_id,
2306 lang_registry.clone(),
2307 &mut cx_b.to_async(),
2308 )
2309 .await
2310 .unwrap();
2311
2312 user_store_a
2313 .condition(&cx_a, |user_store, _| {
2314 collaborators(user_store) == vec![("user_a", vec![("a", vec!["user_b"])])]
2315 })
2316 .await;
2317 user_store_b
2318 .condition(&cx_b, |user_store, _| {
2319 collaborators(user_store) == vec![("user_a", vec![("a", vec!["user_b"])])]
2320 })
2321 .await;
2322 user_store_c
2323 .condition(&cx_c, |user_store, _| {
2324 collaborators(user_store) == vec![("user_a", vec![("a", vec!["user_b"])])]
2325 })
2326 .await;
2327
2328 cx_a.update(move |_| drop(worktree_a));
2329 user_store_a
2330 .condition(&cx_a, |user_store, _| collaborators(user_store) == vec![])
2331 .await;
2332 user_store_b
2333 .condition(&cx_b, |user_store, _| collaborators(user_store) == vec![])
2334 .await;
2335 user_store_c
2336 .condition(&cx_c, |user_store, _| collaborators(user_store) == vec![])
2337 .await;
2338
2339 fn collaborators(user_store: &UserStore) -> Vec<(&str, Vec<(&str, Vec<&str>)>)> {
2340 user_store
2341 .collaborators()
2342 .iter()
2343 .map(|collaborator| {
2344 let worktrees = collaborator
2345 .worktrees
2346 .iter()
2347 .map(|w| {
2348 (
2349 w.root_name.as_str(),
2350 w.participants
2351 .iter()
2352 .map(|p| p.github_login.as_str())
2353 .collect(),
2354 )
2355 })
2356 .collect();
2357 (collaborator.user.github_login.as_str(), worktrees)
2358 })
2359 .collect()
2360 }
2361 }
2362
2363 struct TestServer {
2364 peer: Arc<Peer>,
2365 app_state: Arc<AppState>,
2366 server: Arc<Server>,
2367 notifications: mpsc::Receiver<()>,
2368 connection_killers: Arc<Mutex<HashMap<UserId, watch::Sender<Option<()>>>>>,
2369 forbid_connections: Arc<AtomicBool>,
2370 _test_db: TestDb,
2371 }
2372
2373 impl TestServer {
2374 async fn start() -> Self {
2375 let test_db = TestDb::new();
2376 let app_state = Self::build_app_state(&test_db).await;
2377 let peer = Peer::new();
2378 let notifications = mpsc::channel(128);
2379 let server = Server::new(app_state.clone(), peer.clone(), Some(notifications.0));
2380 Self {
2381 peer,
2382 app_state,
2383 server,
2384 notifications: notifications.1,
2385 connection_killers: Default::default(),
2386 forbid_connections: Default::default(),
2387 _test_db: test_db,
2388 }
2389 }
2390
2391 async fn create_client(
2392 &mut self,
2393 cx: &mut TestAppContext,
2394 name: &str,
2395 ) -> (Arc<Client>, ModelHandle<UserStore>) {
2396 let user_id = self.app_state.db.create_user(name, false).await.unwrap();
2397 let client_name = name.to_string();
2398 let mut client = Client::new();
2399 let server = self.server.clone();
2400 let connection_killers = self.connection_killers.clone();
2401 let forbid_connections = self.forbid_connections.clone();
2402 Arc::get_mut(&mut client)
2403 .unwrap()
2404 .override_authenticate(move |cx| {
2405 cx.spawn(|_| async move {
2406 let access_token = "the-token".to_string();
2407 Ok(Credentials {
2408 user_id: user_id.0 as u64,
2409 access_token,
2410 })
2411 })
2412 })
2413 .override_establish_connection(move |credentials, cx| {
2414 assert_eq!(credentials.user_id, user_id.0 as u64);
2415 assert_eq!(credentials.access_token, "the-token");
2416
2417 let server = server.clone();
2418 let connection_killers = connection_killers.clone();
2419 let forbid_connections = forbid_connections.clone();
2420 let client_name = client_name.clone();
2421 cx.spawn(move |cx| async move {
2422 if forbid_connections.load(SeqCst) {
2423 Err(EstablishConnectionError::other(anyhow!(
2424 "server is forbidding connections"
2425 )))
2426 } else {
2427 let (client_conn, server_conn, kill_conn) = Connection::in_memory();
2428 connection_killers.lock().insert(user_id, kill_conn);
2429 cx.background()
2430 .spawn(server.handle_connection(server_conn, client_name, user_id))
2431 .detach();
2432 Ok(client_conn)
2433 }
2434 })
2435 });
2436
2437 let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) });
2438 client
2439 .authenticate_and_connect(&cx.to_async())
2440 .await
2441 .unwrap();
2442
2443 let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx));
2444 let mut authed_user =
2445 user_store.read_with(cx, |user_store, _| user_store.watch_current_user());
2446 while authed_user.recv().await.unwrap().is_none() {}
2447
2448 (client, user_store)
2449 }
2450
2451 fn disconnect_client(&self, user_id: UserId) {
2452 if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) {
2453 let _ = kill_conn.try_send(Some(()));
2454 }
2455 }
2456
2457 fn forbid_connections(&self) {
2458 self.forbid_connections.store(true, SeqCst);
2459 }
2460
2461 fn allow_connections(&self) {
2462 self.forbid_connections.store(false, SeqCst);
2463 }
2464
2465 async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
2466 let mut config = Config::default();
2467 config.session_secret = "a".repeat(32);
2468 config.database_url = test_db.url.clone();
2469 let github_client = github::AppClient::test();
2470 Arc::new(AppState {
2471 db: test_db.db().clone(),
2472 handlebars: Default::default(),
2473 auth_client: auth::build_client("", ""),
2474 repo_client: github::RepoClient::test(&github_client),
2475 github_client,
2476 config,
2477 })
2478 }
2479
2480 async fn state<'a>(&'a self) -> RwLockReadGuard<'a, ServerState> {
2481 self.server.state.read().await
2482 }
2483
2484 async fn condition<F>(&mut self, mut predicate: F)
2485 where
2486 F: FnMut(&ServerState) -> bool,
2487 {
2488 async_std::future::timeout(Duration::from_millis(500), async {
2489 while !(predicate)(&*self.server.state.read().await) {
2490 self.notifications.recv().await;
2491 }
2492 })
2493 .await
2494 .expect("condition timed out");
2495 }
2496 }
2497
2498 impl Drop for TestServer {
2499 fn drop(&mut self) {
2500 task::block_on(self.peer.reset());
2501 }
2502 }
2503
2504 fn current_user_id(user_store: &ModelHandle<UserStore>, cx: &TestAppContext) -> UserId {
2505 UserId::from_proto(
2506 user_store.read_with(cx, |user_store, _| user_store.current_user().unwrap().id),
2507 )
2508 }
2509
2510 fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> {
2511 channel
2512 .messages()
2513 .cursor::<(), ()>()
2514 .map(|m| {
2515 (
2516 m.sender.github_login.clone(),
2517 m.body.clone(),
2518 m.is_pending(),
2519 )
2520 })
2521 .collect()
2522 }
2523
2524 struct EmptyView;
2525
2526 impl gpui::Entity for EmptyView {
2527 type Event = ();
2528 }
2529
2530 impl gpui::View for EmptyView {
2531 fn ui_name() -> &'static str {
2532 "empty view"
2533 }
2534
2535 fn render(&mut self, _: &mut gpui::RenderContext<Self>) -> gpui::ElementBox {
2536 gpui::Element::boxed(gpui::elements::Empty)
2537 }
2538 }
2539}