1use super::{
2 auth,
3 db::{ChannelId, MessageId, UserId},
4 AppState,
5};
6use anyhow::anyhow;
7use async_std::{sync::RwLock, task};
8use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
9use futures::{future::BoxFuture, FutureExt};
10use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
11use sha1::{Digest as _, Sha1};
12use std::{
13 any::TypeId,
14 collections::{hash_map, HashMap, HashSet},
15 future::Future,
16 mem,
17 sync::Arc,
18 time::Instant,
19};
20use surf::StatusCode;
21use tide::log;
22use tide::{
23 http::headers::{HeaderName, CONNECTION, UPGRADE},
24 Request, Response,
25};
26use time::OffsetDateTime;
27use zrpc::{
28 proto::{self, AnyTypedEnvelope, EnvelopedMessage},
29 Connection, ConnectionId, Peer, TypedEnvelope,
30};
31
32type ReplicaId = u16;
33
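// A type-erased message handler stored in `Server::handlers`: it downcasts the
// incoming envelope back to the concrete message type it was registered for and
// returns a boxed future with the result of handling it.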
34type MessageHandler = Box<
35 dyn Send
36 + Sync
37 + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, tide::Result<()>>,
38>;
39
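// The collaboration server. It wraps a `zrpc::Peer` for the wire protocol, keeps
// all connection, worktree, and channel bookkeeping in `ServerState` behind an
// `RwLock`, and dispatches each incoming message to the handler registered for
// its payload `TypeId`. The optional `notifications` channel lets tests observe
// when a message has been handled.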
40pub struct Server {
41 peer: Arc<Peer>,
42 state: RwLock<ServerState>,
43 app_state: Arc<AppState>,
44 handlers: HashMap<TypeId, MessageHandler>,
45 notifications: Option<mpsc::Sender<()>>,
46}
47
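// Mutable state guarded by the `RwLock` in `Server`: live connections (indexed
// both by connection id and by user), worktrees keyed by the ids issued in
// `add_worktree`, a per-user index of visible worktrees used for collaborator
// updates, and chat channel membership.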
48#[derive(Default)]
49struct ServerState {
50 connections: HashMap<ConnectionId, ConnectionState>,
51 connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
52 pub worktrees: HashMap<u64, Worktree>,
53 visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
54 channels: HashMap<ChannelId, Channel>,
55 next_worktree_id: u64,
56}
57
58struct ConnectionState {
59 user_id: UserId,
60 worktrees: HashSet<u64>,
61 channels: HashSet<ChannelId>,
62}
63
64struct Worktree {
65 host_connection_id: ConnectionId,
66 collaborator_user_ids: Vec<UserId>,
67 root_name: String,
68 share: Option<WorktreeShare>,
69}
70
71struct WorktreeShare {
72 guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
73 active_replica_ids: HashSet<ReplicaId>,
74 entries: HashMap<u64, proto::Entry>,
75}
76
77#[derive(Default)]
78struct Channel {
79 connection_ids: HashSet<ConnectionId>,
80}
81
82const MESSAGE_COUNT_PER_PAGE: usize = 100;
83const MAX_MESSAGE_LEN: usize = 1024;
84
85impl Server {
86 pub fn new(
87 app_state: Arc<AppState>,
88 peer: Arc<Peer>,
89 notifications: Option<mpsc::Sender<()>>,
90 ) -> Arc<Self> {
91 let mut server = Self {
92 peer,
93 app_state,
94 state: Default::default(),
95 handlers: Default::default(),
96 notifications,
97 };
98
99 server
100 .add_handler(Server::ping)
101 .add_handler(Server::open_worktree)
102 .add_handler(Server::handle_close_worktree)
103 .add_handler(Server::share_worktree)
104 .add_handler(Server::unshare_worktree)
105 .add_handler(Server::join_worktree)
106 .add_handler(Server::update_worktree)
107 .add_handler(Server::open_buffer)
108 .add_handler(Server::close_buffer)
109 .add_handler(Server::update_buffer)
110 .add_handler(Server::buffer_saved)
111 .add_handler(Server::save_buffer)
112 .add_handler(Server::get_channels)
113 .add_handler(Server::get_users)
114 .add_handler(Server::join_channel)
115 .add_handler(Server::leave_channel)
116 .add_handler(Server::send_channel_message)
117 .add_handler(Server::get_channel_messages);
118
119 Arc::new(server)
120 }
121
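    // Register `handler` for a single message type `M`, wrapping it so it can be
    // stored type-erased (see `MessageHandler`). Used from `Server::new` above,
    // e.g. `.add_handler(Server::ping)`. Panics if two handlers are registered
    // for the same message type.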
122 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
123 where
124 F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
125 Fut: 'static + Send + Future<Output = tide::Result<()>>,
126 M: EnvelopedMessage,
127 {
128 let prev_handler = self.handlers.insert(
129 TypeId::of::<M>(),
130 Box::new(move |server, envelope| {
131 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
132 (handler)(server, *envelope).boxed()
133 }),
134 );
135 if prev_handler.is_some() {
136 panic!("registered a handler for the same message twice");
137 }
138 self
139 }
140
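    // Drive a single client connection: register it with the peer and the server
    // state, then loop, dispatching each incoming message to its registered
    // handler, until either the message stream ends or the connection's I/O
    // future completes. The connection is signed out on the way out.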
141 pub fn handle_connection(
142 self: &Arc<Self>,
143 connection: Connection,
144 addr: String,
145 user_id: UserId,
146 ) -> impl Future<Output = ()> {
147 let this = self.clone();
148 async move {
149 let (connection_id, handle_io, mut incoming_rx) =
150 this.peer.add_connection(connection).await;
151 this.add_connection(connection_id, user_id).await;
152
153 let handle_io = handle_io.fuse();
154 futures::pin_mut!(handle_io);
155 loop {
156 let next_message = incoming_rx.recv().fuse();
157 futures::pin_mut!(next_message);
158 futures::select_biased! {
159 message = next_message => {
160 if let Some(message) = message {
161 let start_time = Instant::now();
162 log::info!("RPC message received: {}", message.payload_type_name());
163 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
164 if let Err(err) = (handler)(this.clone(), message).await {
165 log::error!("error handling message: {:?}", err);
166 } else {
                                    log::info!("RPC message handled. duration: {:?}", start_time.elapsed());

168 }
169
170 if let Some(mut notifications) = this.notifications.clone() {
171 let _ = notifications.send(()).await;
172 }
173 } else {
174 log::warn!("unhandled message: {}", message.payload_type_name());
175 }
176 } else {
177 log::info!("rpc connection closed {:?}", addr);
178 break;
179 }
180 }
181 handle_io = handle_io => {
182 if let Err(err) = handle_io {
183 log::error!("error handling rpc connection {:?} - {:?}", addr, err);
184 }
185 break;
186 }
187 }
188 }
189
190 if let Err(err) = this.sign_out(connection_id).await {
191 log::error!("error signing out connection {:?} - {:?}", addr, err);
192 }
193 }
194 }
195
196 async fn sign_out(self: &Arc<Self>, connection_id: zrpc::ConnectionId) -> tide::Result<()> {
197 self.peer.disconnect(connection_id).await;
198 self.remove_connection(connection_id).await?;
199 Ok(())
200 }
201
202 // Add a new connection associated with a given user.
203 async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
204 let mut state = self.state.write().await;
205 state.connections.insert(
206 connection_id,
207 ConnectionState {
208 user_id,
209 worktrees: Default::default(),
210 channels: Default::default(),
211 },
212 );
213 state
214 .connections_by_user_id
215 .entry(user_id)
216 .or_default()
217 .insert(connection_id);
218 }
219
220 // Remove the given connection and its association with any worktrees.
221 async fn remove_connection(
222 self: &Arc<Server>,
223 connection_id: ConnectionId,
224 ) -> tide::Result<()> {
225 let mut worktree_ids = Vec::new();
226 let mut state = self.state.write().await;
227 if let Some(connection) = state.connections.remove(&connection_id) {
228 worktree_ids = connection.worktrees.into_iter().collect();
229
230 for channel_id in connection.channels {
231 if let Some(channel) = state.channels.get_mut(&channel_id) {
232 channel.connection_ids.remove(&connection_id);
233 }
234 }
235
236 let user_connections = state
237 .connections_by_user_id
238 .get_mut(&connection.user_id)
239 .unwrap();
240 user_connections.remove(&connection_id);
241 if user_connections.is_empty() {
242 state.connections_by_user_id.remove(&connection.user_id);
243 }
244 }
245
246 drop(state);
247 for worktree_id in worktree_ids {
248 self.close_worktree(worktree_id, connection_id).await?;
249 }
250
251 Ok(())
252 }
253
254 async fn ping(self: Arc<Server>, request: TypedEnvelope<proto::Ping>) -> tide::Result<()> {
255 self.peer.respond(request.receipt(), proto::Ack {}).await?;
256 Ok(())
257 }
258
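    // Create a new, unshared worktree hosted by the sender. Collaborator user ids
    // are resolved from the GitHub logins in the request, and both the host and
    // the collaborators then receive an updated collaborator list.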
259 async fn open_worktree(
260 self: Arc<Server>,
261 request: TypedEnvelope<proto::OpenWorktree>,
262 ) -> tide::Result<()> {
263 let receipt = request.receipt();
264 let host_user_id = self
265 .state
266 .read()
267 .await
268 .user_id_for_connection(request.sender_id)?;
269
270 let mut collaborator_user_ids = Vec::new();
271 for github_login in request.payload.collaborator_logins {
272 match self.app_state.db.create_user(&github_login, false).await {
273 Ok(collaborator_user_id) => {
274 if collaborator_user_id != host_user_id {
275 collaborator_user_ids.push(collaborator_user_id);
276 }
277 }
278 Err(err) => {
279 let message = err.to_string();
280 self.peer
281 .respond_with_error(receipt, proto::Error { message })
282 .await?;
283 return Ok(());
284 }
285 }
286 }
287
288 let worktree_id;
289 let mut user_ids;
290 {
291 let mut state = self.state.write().await;
292 worktree_id = state.add_worktree(Worktree {
293 host_connection_id: request.sender_id,
294 collaborator_user_ids: collaborator_user_ids.clone(),
295 root_name: request.payload.root_name,
296 share: None,
297 });
298 user_ids = collaborator_user_ids;
299 user_ids.push(host_user_id);
300 }
301
302 self.peer
303 .respond(receipt, proto::OpenWorktreeResponse { worktree_id })
304 .await?;
305 self.update_collaborators_for_users(&user_ids).await?;
306
307 Ok(())
308 }
309
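    // Start sharing an existing worktree: record its entries along with an
    // initially empty set of guest connections, then notify the host's and
    // collaborators' connections.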
310 async fn share_worktree(
311 self: Arc<Server>,
312 mut request: TypedEnvelope<proto::ShareWorktree>,
313 ) -> tide::Result<()> {
314 let host_user_id = self
315 .state
316 .read()
317 .await
318 .user_id_for_connection(request.sender_id)?;
319 let worktree = request
320 .payload
321 .worktree
322 .as_mut()
323 .ok_or_else(|| anyhow!("missing worktree"))?;
324 let entries = mem::take(&mut worktree.entries)
325 .into_iter()
326 .map(|entry| (entry.id, entry))
327 .collect();
328
329 let mut state = self.state.write().await;
330 if let Some(worktree) = state.worktrees.get_mut(&worktree.id) {
331 worktree.share = Some(WorktreeShare {
332 guest_connection_ids: Default::default(),
333 active_replica_ids: Default::default(),
334 entries,
335 });
336
337 let mut user_ids = worktree.collaborator_user_ids.clone();
338 user_ids.push(host_user_id);
339
340 drop(state);
341 self.peer
342 .respond(request.receipt(), proto::ShareWorktreeResponse {})
343 .await?;
344 self.update_collaborators_for_users(&user_ids).await?;
345 } else {
346 self.peer
347 .respond_with_error(
348 request.receipt(),
349 proto::Error {
350 message: "no such worktree".to_string(),
351 },
352 )
353 .await?;
354 }
355 Ok(())
356 }
357
358 async fn unshare_worktree(
359 self: Arc<Server>,
360 request: TypedEnvelope<proto::UnshareWorktree>,
361 ) -> tide::Result<()> {
362 let worktree_id = request.payload.worktree_id;
363 let host_user_id = self
364 .state
365 .read()
366 .await
367 .user_id_for_connection(request.sender_id)?;
368
369 let connection_ids;
370 let mut user_ids;
371 {
372 let mut state = self.state.write().await;
373 let worktree = state.write_worktree(worktree_id, request.sender_id)?;
374 if worktree.host_connection_id != request.sender_id {
375 return Err(anyhow!("no such worktree"))?;
376 }
377
378 connection_ids = worktree.connection_ids();
379 user_ids = worktree.collaborator_user_ids.clone();
380 user_ids.push(host_user_id);
381 worktree.share.take();
382 for connection_id in &connection_ids {
383 if let Some(connection) = state.connections.get_mut(connection_id) {
384 connection.worktrees.remove(&worktree_id);
385 }
386 }
387 }
388
389 broadcast(request.sender_id, connection_ids, |conn_id| {
390 self.peer
391 .send(conn_id, proto::UnshareWorktree { worktree_id })
392 })
393 .await?;
394 self.update_collaborators_for_users(&user_ids).await?;
395
396 Ok(())
397 }
398
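    // Add the sender as a guest of a shared worktree. The guest is assigned the
    // lowest free replica id and receives the worktree's entries and current peer
    // list in the response, while the other participants are sent an `AddPeer`
    // message.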
399 async fn join_worktree(
400 self: Arc<Server>,
401 request: TypedEnvelope<proto::JoinWorktree>,
402 ) -> tide::Result<()> {
403 let worktree_id = request.payload.worktree_id;
404 let user_id = self
405 .state
406 .read()
407 .await
408 .user_id_for_connection(request.sender_id)?;
409
410 let response;
411 let connection_ids;
412 let mut user_ids;
413 let mut state = self.state.write().await;
414 match state.join_worktree(request.sender_id, user_id, worktree_id) {
415 Ok((peer_replica_id, worktree)) => {
416 let share = worktree.share()?;
417 let peer_count = share.guest_connection_ids.len();
418 let mut peers = Vec::with_capacity(peer_count);
419 peers.push(proto::Peer {
420 peer_id: worktree.host_connection_id.0,
421 replica_id: 0,
422 });
423 for (peer_conn_id, peer_replica_id) in &share.guest_connection_ids {
424 if *peer_conn_id != request.sender_id {
425 peers.push(proto::Peer {
426 peer_id: peer_conn_id.0,
427 replica_id: *peer_replica_id as u32,
428 });
429 }
430 }
431 response = proto::JoinWorktreeResponse {
432 worktree: Some(proto::Worktree {
433 id: worktree_id,
434 root_name: worktree.root_name.clone(),
435 entries: share.entries.values().cloned().collect(),
436 }),
437 replica_id: peer_replica_id as u32,
438 peers,
439 };
440
441 let host_connection_id = worktree.host_connection_id;
442 connection_ids = worktree.connection_ids();
443 user_ids = worktree.collaborator_user_ids.clone();
444 user_ids.push(state.user_id_for_connection(host_connection_id)?);
445 }
446 Err(error) => {
447 self.peer
448 .respond_with_error(
449 request.receipt(),
450 proto::Error {
451 message: error.to_string(),
452 },
453 )
454 .await?;
455 return Ok(());
456 }
457 }
458
459 drop(state);
460 broadcast(request.sender_id, connection_ids, |conn_id| {
461 self.peer.send(
462 conn_id,
463 proto::AddPeer {
464 worktree_id,
465 peer: Some(proto::Peer {
466 peer_id: request.sender_id.0,
467 replica_id: response.replica_id,
468 }),
469 },
470 )
471 })
472 .await?;
473 self.peer.respond(request.receipt(), response).await?;
474 self.update_collaborators_for_users(&user_ids).await?;
475
476 Ok(())
477 }
478
479 async fn handle_close_worktree(
480 self: Arc<Server>,
481 request: TypedEnvelope<proto::CloseWorktree>,
482 ) -> tide::Result<()> {
483 self.close_worktree(request.payload.worktree_id, request.sender_id)
484 .await
485 }
486
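    // Remove a connection from a worktree. If the host is leaving, the worktree is
    // removed entirely and guests are told it was unshared; if a guest is leaving,
    // the remaining participants receive a `RemovePeer` message. Also called from
    // `remove_connection` for every worktree of a dropped connection.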
487 async fn close_worktree(
488 self: &Arc<Server>,
489 worktree_id: u64,
490 sender_conn_id: ConnectionId,
491 ) -> tide::Result<()> {
492 let connection_ids;
493 let mut user_ids;
494
495 let mut is_host = false;
496 let mut is_guest = false;
497 {
498 let mut state = self.state.write().await;
499 let worktree = state.write_worktree(worktree_id, sender_conn_id)?;
500 let host_connection_id = worktree.host_connection_id;
501 connection_ids = worktree.connection_ids();
502 user_ids = worktree.collaborator_user_ids.clone();
503
504 if worktree.host_connection_id == sender_conn_id {
505 is_host = true;
506 state.remove_worktree(worktree_id);
507 } else {
508 let share = worktree.share_mut()?;
509 if let Some(replica_id) = share.guest_connection_ids.remove(&sender_conn_id) {
510 is_guest = true;
511 share.active_replica_ids.remove(&replica_id);
512 }
513 }
514
515 user_ids.push(state.user_id_for_connection(host_connection_id)?);
516 }
517
518 if is_host {
519 broadcast(sender_conn_id, connection_ids, |conn_id| {
520 self.peer
521 .send(conn_id, proto::UnshareWorktree { worktree_id })
522 })
523 .await?;
524 } else if is_guest {
525 broadcast(sender_conn_id, connection_ids, |conn_id| {
526 self.peer.send(
527 conn_id,
528 proto::RemovePeer {
529 worktree_id,
530 peer_id: sender_conn_id.0,
531 },
532 )
533 })
534 .await?
535 }
536 self.update_collaborators_for_users(&user_ids).await?;
537 Ok(())
538 }
539
540 async fn update_worktree(
541 self: Arc<Server>,
542 request: TypedEnvelope<proto::UpdateWorktree>,
543 ) -> tide::Result<()> {
544 {
545 let mut state = self.state.write().await;
546 let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
547 let share = worktree.share_mut()?;
548
549 for entry_id in &request.payload.removed_entries {
                share.entries.remove(entry_id);
551 }
552
553 for entry in &request.payload.updated_entries {
554 share.entries.insert(entry.id, entry.clone());
555 }
556 }
557
558 self.broadcast_in_worktree(request.payload.worktree_id, &request)
559 .await?;
560 Ok(())
561 }
562
563 async fn open_buffer(
564 self: Arc<Server>,
565 request: TypedEnvelope<proto::OpenBuffer>,
566 ) -> tide::Result<()> {
567 let receipt = request.receipt();
568 let worktree_id = request.payload.worktree_id;
569 let host_connection_id = self
570 .state
571 .read()
572 .await
573 .read_worktree(worktree_id, request.sender_id)?
574 .host_connection_id;
575
576 let response = self
577 .peer
578 .forward_request(request.sender_id, host_connection_id, request.payload)
579 .await?;
580 self.peer.respond(receipt, response).await?;
581 Ok(())
582 }
583
584 async fn close_buffer(
585 self: Arc<Server>,
586 request: TypedEnvelope<proto::CloseBuffer>,
587 ) -> tide::Result<()> {
588 let host_connection_id = self
589 .state
590 .read()
591 .await
592 .read_worktree(request.payload.worktree_id, request.sender_id)?
593 .host_connection_id;
594
595 self.peer
596 .forward_send(request.sender_id, host_connection_id, request.payload)
597 .await?;
598
599 Ok(())
600 }
601
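    // Forward a save request to the worktree's host, then relay the host's
    // response back to the requesting guest and forward it to the remaining
    // guests.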
602 async fn save_buffer(
603 self: Arc<Server>,
604 request: TypedEnvelope<proto::SaveBuffer>,
605 ) -> tide::Result<()> {
606 let host;
607 let guests;
608 {
609 let state = self.state.read().await;
610 let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
611 host = worktree.host_connection_id;
612 guests = worktree
613 .share()?
614 .guest_connection_ids
615 .keys()
616 .copied()
617 .collect::<Vec<_>>();
618 }
619
620 let sender = request.sender_id;
621 let receipt = request.receipt();
622 let response = self
623 .peer
624 .forward_request(sender, host, request.payload.clone())
625 .await?;
626
627 broadcast(host, guests, |conn_id| {
628 let response = response.clone();
629 let peer = &self.peer;
630 async move {
631 if conn_id == sender {
632 peer.respond(receipt, response).await
633 } else {
634 peer.forward_send(host, conn_id, response).await
635 }
636 }
637 })
638 .await?;
639
640 Ok(())
641 }
642
643 async fn update_buffer(
644 self: Arc<Server>,
645 request: TypedEnvelope<proto::UpdateBuffer>,
646 ) -> tide::Result<()> {
647 self.broadcast_in_worktree(request.payload.worktree_id, &request)
648 .await?;
649 self.peer.respond(request.receipt(), proto::Ack {}).await?;
650 Ok(())
651 }
652
653 async fn buffer_saved(
654 self: Arc<Server>,
655 request: TypedEnvelope<proto::BufferSaved>,
656 ) -> tide::Result<()> {
657 self.broadcast_in_worktree(request.payload.worktree_id, &request)
658 .await
659 }
660
661 async fn get_channels(
662 self: Arc<Server>,
663 request: TypedEnvelope<proto::GetChannels>,
664 ) -> tide::Result<()> {
665 let user_id = self
666 .state
667 .read()
668 .await
669 .user_id_for_connection(request.sender_id)?;
670 let channels = self.app_state.db.get_accessible_channels(user_id).await?;
671 self.peer
672 .respond(
673 request.receipt(),
674 proto::GetChannelsResponse {
675 channels: channels
676 .into_iter()
677 .map(|chan| proto::Channel {
678 id: chan.id.to_proto(),
679 name: chan.name,
680 })
681 .collect(),
682 },
683 )
684 .await?;
685 Ok(())
686 }
687
688 async fn get_users(
689 self: Arc<Server>,
690 request: TypedEnvelope<proto::GetUsers>,
691 ) -> tide::Result<()> {
692 let user_id = self
693 .state
694 .read()
695 .await
696 .user_id_for_connection(request.sender_id)?;
697 let receipt = request.receipt();
698 let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
699 let users = self
700 .app_state
701 .db
702 .get_users_by_ids(user_id, user_ids)
703 .await?
704 .into_iter()
705 .map(|user| proto::User {
706 id: user.id.to_proto(),
707 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
708 github_login: user.github_login,
709 })
710 .collect();
711 self.peer
712 .respond(receipt, proto::GetUsersResponse { users })
713 .await?;
714 Ok(())
715 }
716
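    // Recompute and push an `UpdateCollaborators` message to every connection of
    // each given user, summarizing the worktrees visible to that user, whether
    // each one is shared, and who is currently participating, grouped by host.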
717 async fn update_collaborators_for_users<'a>(
718 self: &Arc<Server>,
719 user_ids: impl IntoIterator<Item = &'a UserId>,
720 ) -> tide::Result<()> {
721 let mut send_futures = Vec::new();
722
723 let state = self.state.read().await;
724 for user_id in user_ids {
725 let mut collaborators = HashMap::new();
726 for worktree_id in state
727 .visible_worktrees_by_user_id
                .get(user_id)
729 .unwrap_or(&HashSet::new())
730 {
731 let worktree = &state.worktrees[worktree_id];
732
733 let mut participants = HashSet::new();
734 if let Ok(share) = worktree.share() {
735 for guest_connection_id in share.guest_connection_ids.keys() {
736 let user_id = state.user_id_for_connection(*guest_connection_id)?;
737 participants.insert(user_id.to_proto());
738 }
739 }
740
741 let host_user_id = state.user_id_for_connection(worktree.host_connection_id)?;
742 let host =
743 collaborators
744 .entry(host_user_id)
745 .or_insert_with(|| proto::Collaborator {
746 user_id: host_user_id.to_proto(),
747 worktrees: Vec::new(),
748 });
749 host.worktrees.push(proto::WorktreeMetadata {
750 root_name: worktree.root_name.clone(),
751 is_shared: worktree.share().is_ok(),
752 participants: participants.into_iter().collect(),
753 });
754 }
755
756 let collaborators = collaborators.into_values().collect::<Vec<_>>();
757 for connection_id in state.user_connection_ids(*user_id) {
758 send_futures.push(self.peer.send(
759 connection_id,
760 proto::UpdateCollaborators {
761 collaborators: collaborators.clone(),
762 },
763 ));
764 }
765 }
766
767 drop(state);
768 futures::future::try_join_all(send_futures).await?;
769
770 Ok(())
771 }
772
773 async fn join_channel(
774 self: Arc<Self>,
775 request: TypedEnvelope<proto::JoinChannel>,
776 ) -> tide::Result<()> {
777 let user_id = self
778 .state
779 .read()
780 .await
781 .user_id_for_connection(request.sender_id)?;
782 let channel_id = ChannelId::from_proto(request.payload.channel_id);
783 if !self
784 .app_state
785 .db
786 .can_user_access_channel(user_id, channel_id)
787 .await?
788 {
789 Err(anyhow!("access denied"))?;
790 }
791
792 self.state
793 .write()
794 .await
795 .join_channel(request.sender_id, channel_id);
796 let messages = self
797 .app_state
798 .db
799 .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
800 .await?
801 .into_iter()
802 .map(|msg| proto::ChannelMessage {
803 id: msg.id.to_proto(),
804 body: msg.body,
805 timestamp: msg.sent_at.unix_timestamp() as u64,
806 sender_id: msg.sender_id.to_proto(),
807 nonce: Some(msg.nonce.as_u128().into()),
808 })
809 .collect::<Vec<_>>();
810 self.peer
811 .respond(
812 request.receipt(),
813 proto::JoinChannelResponse {
814 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
815 messages,
816 },
817 )
818 .await?;
819 Ok(())
820 }
821
822 async fn leave_channel(
823 self: Arc<Self>,
824 request: TypedEnvelope<proto::LeaveChannel>,
825 ) -> tide::Result<()> {
826 let user_id = self
827 .state
828 .read()
829 .await
830 .user_id_for_connection(request.sender_id)?;
831 let channel_id = ChannelId::from_proto(request.payload.channel_id);
832 if !self
833 .app_state
834 .db
835 .can_user_access_channel(user_id, channel_id)
836 .await?
837 {
838 Err(anyhow!("access denied"))?;
839 }
840
841 self.state
842 .write()
843 .await
844 .leave_channel(request.sender_id, channel_id);
845
846 Ok(())
847 }
848
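    // Validate an incoming chat message (length, non-empty body, nonce), persist
    // it, broadcast it to the channel's other connections, and acknowledge it to
    // the sender.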
849 async fn send_channel_message(
850 self: Arc<Self>,
851 request: TypedEnvelope<proto::SendChannelMessage>,
852 ) -> tide::Result<()> {
853 let receipt = request.receipt();
854 let channel_id = ChannelId::from_proto(request.payload.channel_id);
855 let user_id;
856 let connection_ids;
857 {
858 let state = self.state.read().await;
859 user_id = state.user_id_for_connection(request.sender_id)?;
860 if let Some(channel) = state.channels.get(&channel_id) {
861 connection_ids = channel.connection_ids();
862 } else {
863 return Ok(());
864 }
865 }
866
867 // Validate the message body.
868 let body = request.payload.body.trim().to_string();
869 if body.len() > MAX_MESSAGE_LEN {
870 self.peer
871 .respond_with_error(
872 receipt,
873 proto::Error {
874 message: "message is too long".to_string(),
875 },
876 )
877 .await?;
878 return Ok(());
879 }
880 if body.is_empty() {
881 self.peer
882 .respond_with_error(
883 receipt,
884 proto::Error {
885 message: "message can't be blank".to_string(),
886 },
887 )
888 .await?;
889 return Ok(());
890 }
891
892 let timestamp = OffsetDateTime::now_utc();
893 let nonce = if let Some(nonce) = request.payload.nonce {
894 nonce
895 } else {
896 self.peer
897 .respond_with_error(
898 receipt,
899 proto::Error {
900 message: "nonce can't be blank".to_string(),
901 },
902 )
903 .await?;
904 return Ok(());
905 };
906
907 let message_id = self
908 .app_state
909 .db
910 .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
911 .await?
912 .to_proto();
913 let message = proto::ChannelMessage {
914 sender_id: user_id.to_proto(),
915 id: message_id,
916 body,
917 timestamp: timestamp.unix_timestamp() as u64,
918 nonce: Some(nonce),
919 };
920 broadcast(request.sender_id, connection_ids, |conn_id| {
921 self.peer.send(
922 conn_id,
923 proto::ChannelMessageSent {
924 channel_id: channel_id.to_proto(),
925 message: Some(message.clone()),
926 },
927 )
928 })
929 .await?;
930 self.peer
931 .respond(
932 receipt,
933 proto::SendChannelMessageResponse {
934 message: Some(message),
935 },
936 )
937 .await?;
938 Ok(())
939 }
940
941 async fn get_channel_messages(
942 self: Arc<Self>,
943 request: TypedEnvelope<proto::GetChannelMessages>,
944 ) -> tide::Result<()> {
945 let user_id = self
946 .state
947 .read()
948 .await
949 .user_id_for_connection(request.sender_id)?;
950 let channel_id = ChannelId::from_proto(request.payload.channel_id);
951 if !self
952 .app_state
953 .db
954 .can_user_access_channel(user_id, channel_id)
955 .await?
956 {
957 Err(anyhow!("access denied"))?;
958 }
959
960 let messages = self
961 .app_state
962 .db
963 .get_channel_messages(
964 channel_id,
965 MESSAGE_COUNT_PER_PAGE,
966 Some(MessageId::from_proto(request.payload.before_message_id)),
967 )
968 .await?
969 .into_iter()
970 .map(|msg| proto::ChannelMessage {
971 id: msg.id.to_proto(),
972 body: msg.body,
973 timestamp: msg.sent_at.unix_timestamp() as u64,
974 sender_id: msg.sender_id.to_proto(),
975 nonce: Some(msg.nonce.as_u128().into()),
976 })
977 .collect::<Vec<_>>();
978 self.peer
979 .respond(
980 request.receipt(),
981 proto::GetChannelMessagesResponse {
982 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
983 messages,
984 },
985 )
986 .await?;
987 Ok(())
988 }
989
990 async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
991 &self,
992 worktree_id: u64,
993 message: &TypedEnvelope<T>,
994 ) -> tide::Result<()> {
995 let connection_ids = self
996 .state
997 .read()
998 .await
999 .read_worktree(worktree_id, message.sender_id)?
1000 .connection_ids();
1001
1002 broadcast(message.sender_id, connection_ids, |conn_id| {
1003 self.peer
1004 .forward_send(message.sender_id, conn_id, message.payload.clone())
1005 })
1006 .await?;
1007
1008 Ok(())
1009 }
1010}
1011
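// Invoke `f` once for every receiver except the sender and await all of the
// resulting futures concurrently, failing if any of them fail. Call sites above
// pass a closure that sends or forwards one message per connection, for example
// `broadcast(sender_id, connection_ids, |conn_id| peer.send(conn_id, msg.clone()))`.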
1012pub async fn broadcast<F, T>(
1013 sender_id: ConnectionId,
1014 receiver_ids: Vec<ConnectionId>,
1015 mut f: F,
1016) -> anyhow::Result<()>
1017where
1018 F: FnMut(ConnectionId) -> T,
1019 T: Future<Output = anyhow::Result<()>>,
1020{
1021 let futures = receiver_ids
1022 .into_iter()
1023 .filter(|id| *id != sender_id)
1024 .map(|id| f(id));
1025 futures::future::try_join_all(futures).await?;
1026 Ok(())
1027}
1028
1029impl ServerState {
1030 fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
1031 if let Some(connection) = self.connections.get_mut(&connection_id) {
1032 connection.channels.insert(channel_id);
1033 self.channels
1034 .entry(channel_id)
1035 .or_default()
1036 .connection_ids
1037 .insert(connection_id);
1038 }
1039 }
1040
1041 fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
1042 if let Some(connection) = self.connections.get_mut(&connection_id) {
1043 connection.channels.remove(&channel_id);
1044 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
1045 entry.get_mut().connection_ids.remove(&connection_id);
1046 if entry.get_mut().connection_ids.is_empty() {
1047 entry.remove();
1048 }
1049 }
1050 }
1051 }
1052
1053 fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
1054 Ok(self
1055 .connections
1056 .get(&connection_id)
1057 .ok_or_else(|| anyhow!("unknown connection"))?
1058 .user_id)
1059 }
1060
1061 fn user_connection_ids<'a>(
1062 &'a self,
1063 user_id: UserId,
1064 ) -> impl 'a + Iterator<Item = ConnectionId> {
1065 self.connections_by_user_id
1066 .get(&user_id)
1067 .into_iter()
1068 .flatten()
1069 .copied()
1070 }
1071
1072 // Add the given connection as a guest of the given worktree
1073 fn join_worktree(
1074 &mut self,
1075 connection_id: ConnectionId,
1076 user_id: UserId,
1077 worktree_id: u64,
1078 ) -> tide::Result<(ReplicaId, &Worktree)> {
1079 let connection = self
1080 .connections
1081 .get_mut(&connection_id)
1082 .ok_or_else(|| anyhow!("no such connection"))?;
1083 let worktree = self
1084 .worktrees
1085 .get_mut(&worktree_id)
1086 .ok_or_else(|| anyhow!("no such worktree"))?;
1087 if !worktree.collaborator_user_ids.contains(&user_id) {
1088 Err(anyhow!("no such worktree"))?;
1089 }
1090
1091 let share = worktree.share_mut()?;
1092 connection.worktrees.insert(worktree_id);
1093
1094 let mut replica_id = 1;
1095 while share.active_replica_ids.contains(&replica_id) {
1096 replica_id += 1;
1097 }
1098 share.active_replica_ids.insert(replica_id);
1099 share.guest_connection_ids.insert(connection_id, replica_id);
1100 return Ok((replica_id, worktree));
1101 }
1102
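    // Fetch a worktree and verify that `connection_id` participates in it, either
    // as the host or as a guest recorded in its share. `write_worktree` below
    // performs the same check and returns a mutable reference.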
1103 fn read_worktree(
1104 &self,
1105 worktree_id: u64,
1106 connection_id: ConnectionId,
1107 ) -> tide::Result<&Worktree> {
1108 let worktree = self
1109 .worktrees
1110 .get(&worktree_id)
1111 .ok_or_else(|| anyhow!("worktree not found"))?;
1112
1113 if worktree.host_connection_id == connection_id
1114 || worktree
1115 .share()?
1116 .guest_connection_ids
1117 .contains_key(&connection_id)
1118 {
1119 Ok(worktree)
1120 } else {
1121 Err(anyhow!(
1122 "{} is not a member of worktree {}",
1123 connection_id,
1124 worktree_id
1125 ))?
1126 }
1127 }
1128
1129 fn write_worktree(
1130 &mut self,
1131 worktree_id: u64,
1132 connection_id: ConnectionId,
1133 ) -> tide::Result<&mut Worktree> {
1134 let worktree = self
1135 .worktrees
1136 .get_mut(&worktree_id)
1137 .ok_or_else(|| anyhow!("worktree not found"))?;
1138
1139 if worktree.host_connection_id == connection_id
1140 || worktree.share.as_ref().map_or(false, |share| {
1141 share.guest_connection_ids.contains_key(&connection_id)
1142 })
1143 {
1144 Ok(worktree)
1145 } else {
1146 Err(anyhow!(
1147 "{} is not a member of worktree {}",
1148 connection_id,
1149 worktree_id
1150 ))?
1151 }
1152 }
1153
1154 fn add_worktree(&mut self, worktree: Worktree) -> u64 {
1155 let worktree_id = self.next_worktree_id;
1156 for collaborator_user_id in &worktree.collaborator_user_ids {
1157 self.visible_worktrees_by_user_id
1158 .entry(*collaborator_user_id)
1159 .or_default()
1160 .insert(worktree_id);
1161 }
1162 self.next_worktree_id += 1;
1163 self.worktrees.insert(worktree_id, worktree);
1164 worktree_id
1165 }
1166
1167 fn remove_worktree(&mut self, worktree_id: u64) {
1168 let worktree = self.worktrees.remove(&worktree_id).unwrap();
1169 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
1170 connection.worktrees.remove(&worktree_id);
1171 }
1172 if let Some(share) = worktree.share {
1173 for connection_id in share.guest_connection_ids.keys() {
1174 if let Some(connection) = self.connections.get_mut(connection_id) {
1175 connection.worktrees.remove(&worktree_id);
1176 }
1177 }
1178 }
1179 for collaborator_user_id in worktree.collaborator_user_ids {
1180 if let Some(visible_worktrees) = self
1181 .visible_worktrees_by_user_id
1182 .get_mut(&collaborator_user_id)
1183 {
1184 visible_worktrees.remove(&worktree_id);
1185 }
1186 }
1187 }
1188}
1189
1190impl Worktree {
1191 pub fn connection_ids(&self) -> Vec<ConnectionId> {
1192 if let Some(share) = &self.share {
1193 share
1194 .guest_connection_ids
1195 .keys()
1196 .copied()
1197 .chain(Some(self.host_connection_id))
1198 .collect()
1199 } else {
1200 vec![self.host_connection_id]
1201 }
1202 }
1203
1204 fn share(&self) -> tide::Result<&WorktreeShare> {
1205 Ok(self
1206 .share
1207 .as_ref()
1208 .ok_or_else(|| anyhow!("worktree is not shared"))?)
1209 }
1210
1211 fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
1212 Ok(self
1213 .share
1214 .as_mut()
1215 .ok_or_else(|| anyhow!("worktree is not shared"))?)
1216 }
1217}
1218
1219impl Channel {
1220 fn connection_ids(&self) -> Vec<ConnectionId> {
1221 self.connection_ids.iter().copied().collect()
1222 }
1223}
1224
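// Mount the `/rpc` endpoint. After `auth::VerifyToken` has attached a `UserId` to
// the request, this performs the WebSocket upgrade handshake (computing the
// `Sec-Websocket-Accept` value from the client's key and the protocol GUID) and
// spawns a task that hands the upgraded stream to `Server::handle_connection`.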
1225pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
1226 let server = Server::new(app.state().clone(), rpc.clone(), None);
1227 app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
1228 let user_id = request.ext::<UserId>().copied();
1229 let server = server.clone();
1230 async move {
1231 const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
1232
1233 let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
1234 let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
1235 let upgrade_requested = connection_upgrade && upgrade_to_websocket;
1236
1237 if !upgrade_requested {
1238 return Ok(Response::new(StatusCode::UpgradeRequired));
1239 }
1240
1241 let header = match request.header("Sec-Websocket-Key") {
1242 Some(h) => h.as_str(),
1243 None => return Err(anyhow!("expected sec-websocket-key"))?,
1244 };
1245
1246 let mut response = Response::new(StatusCode::SwitchingProtocols);
1247 response.insert_header(UPGRADE, "websocket");
1248 response.insert_header(CONNECTION, "Upgrade");
1249 let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
1250 response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
1251 response.insert_header("Sec-Websocket-Version", "13");
1252
1253 let http_res: &mut tide::http::Response = response.as_mut();
1254 let upgrade_receiver = http_res.recv_upgrade().await;
1255 let addr = request.remote().unwrap_or("unknown").to_string();
1256 let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
1257 task::spawn(async move {
1258 if let Some(stream) = upgrade_receiver.await {
1259 server.handle_connection(Connection::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
1260 }
1261 });
1262
1263 Ok(response)
1264 }
1265 });
1266}
1267
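// Return true if the given header contains `value` as one of its comma-separated
// entries, ignoring case (e.g. `Connection: keep-alive, Upgrade`).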
1268fn header_contains_ignore_case<T>(
1269 request: &tide::Request<T>,
1270 header_name: HeaderName,
1271 value: &str,
1272) -> bool {
1273 request
1274 .header(header_name)
1275 .map(|h| {
1276 h.as_str()
1277 .split(',')
1278 .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
1279 })
1280 .unwrap_or(false)
1281}
1282
1283#[cfg(test)]
1284mod tests {
1285 use super::*;
1286 use crate::{
1287 auth,
1288 db::{tests::TestDb, UserId},
1289 github, AppState, Config,
1290 };
1291 use async_std::{sync::RwLockReadGuard, task};
1292 use gpui::TestAppContext;
1293 use parking_lot::Mutex;
1294 use postage::{mpsc, watch};
1295 use serde_json::json;
1296 use sqlx::types::time::OffsetDateTime;
1297 use std::{
1298 path::Path,
1299 sync::{
1300 atomic::{AtomicBool, Ordering::SeqCst},
1301 Arc,
1302 },
1303 time::Duration,
1304 };
1305 use zed::{
1306 channel::{Channel, ChannelDetails, ChannelList},
1307 editor::{Editor, Insert},
1308 fs::{FakeFs, Fs as _},
1309 language::LanguageRegistry,
1310 rpc::{self, Client, Credentials, EstablishConnectionError},
1311 settings,
1312 test::FakeHttpClient,
1313 user::UserStore,
1314 worktree::Worktree,
1315 };
1316 use zrpc::Peer;
1317
1318 #[gpui::test]
1319 async fn test_share_worktree(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1320 let (window_b, _) = cx_b.add_window(|_| EmptyView);
1321 let settings = cx_b.read(settings::test).1;
1322 let lang_registry = Arc::new(LanguageRegistry::new());
1323
1324 // Connect to a server as 2 clients.
1325 let mut server = TestServer::start().await;
1326 let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1327 let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1328
1329 cx_a.foreground().forbid_parking();
1330
1331 // Share a local worktree as client A
1332 let fs = Arc::new(FakeFs::new());
1333 fs.insert_tree(
1334 "/a",
1335 json!({
1336 ".zed.toml": r#"collaborators = ["user_b"]"#,
1337 "a.txt": "a-contents",
1338 "b.txt": "b-contents",
1339 }),
1340 )
1341 .await;
1342 let worktree_a = Worktree::open_local(
1343 client_a.clone(),
1344 "/a".as_ref(),
1345 fs,
1346 lang_registry.clone(),
1347 &mut cx_a.to_async(),
1348 )
1349 .await
1350 .unwrap();
1351 worktree_a
1352 .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1353 .await;
1354 let worktree_id = worktree_a
1355 .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1356 .await
1357 .unwrap();
1358
1359 // Join that worktree as client B, and see that a guest has joined as client A.
1360 let worktree_b = Worktree::open_remote(
1361 client_b.clone(),
1362 worktree_id,
1363 lang_registry.clone(),
1364 &mut cx_b.to_async(),
1365 )
1366 .await
1367 .unwrap();
1368 let replica_id_b = worktree_b.read_with(&cx_b, |tree, _| tree.replica_id());
1369 worktree_a
1370 .condition(&cx_a, |tree, _| {
1371 tree.peers()
1372 .values()
1373 .any(|replica_id| *replica_id == replica_id_b)
1374 })
1375 .await;
1376
1377 // Open the same file as client B and client A.
1378 let buffer_b = worktree_b
1379 .update(&mut cx_b, |worktree, cx| worktree.open_buffer("b.txt", cx))
1380 .await
1381 .unwrap();
1382 buffer_b.read_with(&cx_b, |buf, _| assert_eq!(buf.text(), "b-contents"));
1383 worktree_a.read_with(&cx_a, |tree, cx| assert!(tree.has_open_buffer("b.txt", cx)));
1384 let buffer_a = worktree_a
1385 .update(&mut cx_a, |tree, cx| tree.open_buffer("b.txt", cx))
1386 .await
1387 .unwrap();
1388
1389 // Create a selection set as client B and see that selection set as client A.
1390 let editor_b = cx_b.add_view(window_b, |cx| Editor::for_buffer(buffer_b, settings, cx));
1391 buffer_a
1392 .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 1)
1393 .await;
1394
1395 // Edit the buffer as client B and see that edit as client A.
1396 editor_b.update(&mut cx_b, |editor, cx| {
1397 editor.insert(&Insert("ok, ".into()), cx)
1398 });
1399 buffer_a
1400 .condition(&cx_a, |buffer, _| buffer.text() == "ok, b-contents")
1401 .await;
1402
1403 // Remove the selection set as client B, see those selections disappear as client A.
1404 cx_b.update(move |_| drop(editor_b));
1405 buffer_a
1406 .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 0)
1407 .await;
1408
1409 // Close the buffer as client A, see that the buffer is closed.
1410 cx_a.update(move |_| drop(buffer_a));
1411 worktree_a
1412 .condition(&cx_a, |tree, cx| !tree.has_open_buffer("b.txt", cx))
1413 .await;
1414
1415 // Dropping the worktree removes client B from client A's peers.
1416 cx_b.update(move |_| drop(worktree_b));
1417 worktree_a
1418 .condition(&cx_a, |tree, _| tree.peers().is_empty())
1419 .await;
1420 }
1421
1422 #[gpui::test]
1423 async fn test_propagate_saves_and_fs_changes_in_shared_worktree(
1424 mut cx_a: TestAppContext,
1425 mut cx_b: TestAppContext,
1426 mut cx_c: TestAppContext,
1427 ) {
1428 cx_a.foreground().forbid_parking();
1429 let lang_registry = Arc::new(LanguageRegistry::new());
1430
1431 // Connect to a server as 3 clients.
1432 let mut server = TestServer::start().await;
1433 let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1434 let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1435 let (client_c, _) = server.create_client(&mut cx_c, "user_c").await;
1436
1437 let fs = Arc::new(FakeFs::new());
1438
1439 // Share a worktree as client A.
1440 fs.insert_tree(
1441 "/a",
1442 json!({
1443 ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
1444 "file1": "",
1445 "file2": ""
1446 }),
1447 )
1448 .await;
1449
1450 let worktree_a = Worktree::open_local(
1451 client_a.clone(),
1452 "/a".as_ref(),
1453 fs.clone(),
1454 lang_registry.clone(),
1455 &mut cx_a.to_async(),
1456 )
1457 .await
1458 .unwrap();
1459 worktree_a
1460 .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1461 .await;
1462 let worktree_id = worktree_a
1463 .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1464 .await
1465 .unwrap();
1466
1467 // Join that worktree as clients B and C.
1468 let worktree_b = Worktree::open_remote(
1469 client_b.clone(),
1470 worktree_id,
1471 lang_registry.clone(),
1472 &mut cx_b.to_async(),
1473 )
1474 .await
1475 .unwrap();
1476 let worktree_c = Worktree::open_remote(
1477 client_c.clone(),
1478 worktree_id,
1479 lang_registry.clone(),
1480 &mut cx_c.to_async(),
1481 )
1482 .await
1483 .unwrap();
1484
1485 // Open and edit a buffer as both guests B and C.
1486 let buffer_b = worktree_b
1487 .update(&mut cx_b, |tree, cx| tree.open_buffer("file1", cx))
1488 .await
1489 .unwrap();
1490 let buffer_c = worktree_c
1491 .update(&mut cx_c, |tree, cx| tree.open_buffer("file1", cx))
1492 .await
1493 .unwrap();
1494 buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "i-am-b, ", cx));
1495 buffer_c.update(&mut cx_c, |buf, cx| buf.edit([0..0], "i-am-c, ", cx));
1496
1497 // Open and edit that buffer as the host.
1498 let buffer_a = worktree_a
1499 .update(&mut cx_a, |tree, cx| tree.open_buffer("file1", cx))
1500 .await
1501 .unwrap();
1502
1503 buffer_a
1504 .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, ")
1505 .await;
1506 buffer_a.update(&mut cx_a, |buf, cx| {
1507 buf.edit([buf.len()..buf.len()], "i-am-a", cx)
1508 });
1509
1510 // Wait for edits to propagate
1511 buffer_a
1512 .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
1513 .await;
1514 buffer_b
1515 .condition(&mut cx_b, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
1516 .await;
1517 buffer_c
1518 .condition(&mut cx_c, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
1519 .await;
1520
1521 // Edit the buffer as the host and concurrently save as guest B.
1522 let save_b = buffer_b.update(&mut cx_b, |buf, cx| buf.save(cx).unwrap());
1523 buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "hi-a, ", cx));
1524 save_b.await.unwrap();
1525 assert_eq!(
1526 fs.load("/a/file1".as_ref()).await.unwrap(),
1527 "hi-a, i-am-c, i-am-b, i-am-a"
1528 );
1529 buffer_a.read_with(&cx_a, |buf, _| assert!(!buf.is_dirty()));
1530 buffer_b.read_with(&cx_b, |buf, _| assert!(!buf.is_dirty()));
1531 buffer_c.condition(&cx_c, |buf, _| !buf.is_dirty()).await;
1532
1533 // Make changes on host's file system, see those changes on the guests.
1534 fs.rename("/a/file2".as_ref(), "/a/file3".as_ref())
1535 .await
1536 .unwrap();
1537 fs.insert_file(Path::new("/a/file4"), "4".into())
1538 .await
1539 .unwrap();
1540
1541 worktree_b
1542 .condition(&cx_b, |tree, _| tree.file_count() == 4)
1543 .await;
1544 worktree_c
1545 .condition(&cx_c, |tree, _| tree.file_count() == 4)
1546 .await;
1547 worktree_b.read_with(&cx_b, |tree, _| {
1548 assert_eq!(
1549 tree.paths()
1550 .map(|p| p.to_string_lossy())
1551 .collect::<Vec<_>>(),
1552 &[".zed.toml", "file1", "file3", "file4"]
1553 )
1554 });
1555 worktree_c.read_with(&cx_c, |tree, _| {
1556 assert_eq!(
1557 tree.paths()
1558 .map(|p| p.to_string_lossy())
1559 .collect::<Vec<_>>(),
1560 &[".zed.toml", "file1", "file3", "file4"]
1561 )
1562 });
1563 }
1564
1565 #[gpui::test]
1566 async fn test_buffer_conflict_after_save(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1567 cx_a.foreground().forbid_parking();
1568 let lang_registry = Arc::new(LanguageRegistry::new());
1569
1570 // Connect to a server as 2 clients.
1571 let mut server = TestServer::start().await;
1572 let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1573 let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1574
1575 // Share a local worktree as client A
1576 let fs = Arc::new(FakeFs::new());
1577 fs.insert_tree(
1578 "/dir",
1579 json!({
1580 ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
1581 "a.txt": "a-contents",
1582 }),
1583 )
1584 .await;
1585
1586 let worktree_a = Worktree::open_local(
1587 client_a.clone(),
1588 "/dir".as_ref(),
1589 fs,
1590 lang_registry.clone(),
1591 &mut cx_a.to_async(),
1592 )
1593 .await
1594 .unwrap();
1595 worktree_a
1596 .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1597 .await;
1598 let worktree_id = worktree_a
1599 .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1600 .await
1601 .unwrap();
1602
1603 // Join that worktree as client B, and see that a guest has joined as client A.
1604 let worktree_b = Worktree::open_remote(
1605 client_b.clone(),
1606 worktree_id,
1607 lang_registry.clone(),
1608 &mut cx_b.to_async(),
1609 )
1610 .await
1611 .unwrap();
1612
1613 let buffer_b = worktree_b
1614 .update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx))
1615 .await
1616 .unwrap();
1617 let mtime = buffer_b.read_with(&cx_b, |buf, _| buf.file().unwrap().mtime);
1618
1619 buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "world ", cx));
1620 buffer_b.read_with(&cx_b, |buf, _| {
1621 assert!(buf.is_dirty());
1622 assert!(!buf.has_conflict());
1623 });
1624
1625 buffer_b
1626 .update(&mut cx_b, |buf, cx| buf.save(cx))
1627 .unwrap()
1628 .await
1629 .unwrap();
1630 worktree_b
1631 .condition(&cx_b, |_, cx| {
1632 buffer_b.read(cx).file().unwrap().mtime != mtime
1633 })
1634 .await;
1635 buffer_b.read_with(&cx_b, |buf, _| {
1636 assert!(!buf.is_dirty());
1637 assert!(!buf.has_conflict());
1638 });
1639
1640 buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "hello ", cx));
1641 buffer_b.read_with(&cx_b, |buf, _| {
1642 assert!(buf.is_dirty());
1643 assert!(!buf.has_conflict());
1644 });
1645 }
1646
1647 #[gpui::test]
1648 async fn test_editing_while_guest_opens_buffer(
1649 mut cx_a: TestAppContext,
1650 mut cx_b: TestAppContext,
1651 ) {
1652 cx_a.foreground().forbid_parking();
1653 let lang_registry = Arc::new(LanguageRegistry::new());
1654
1655 // Connect to a server as 2 clients.
1656 let mut server = TestServer::start().await;
1657 let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1658 let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1659
1660 // Share a local worktree as client A
1661 let fs = Arc::new(FakeFs::new());
1662 fs.insert_tree(
1663 "/dir",
1664 json!({
1665 ".zed.toml": r#"collaborators = ["user_b"]"#,
1666 "a.txt": "a-contents",
1667 }),
1668 )
1669 .await;
1670 let worktree_a = Worktree::open_local(
1671 client_a.clone(),
1672 "/dir".as_ref(),
1673 fs,
1674 lang_registry.clone(),
1675 &mut cx_a.to_async(),
1676 )
1677 .await
1678 .unwrap();
1679 worktree_a
1680 .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1681 .await;
1682 let worktree_id = worktree_a
1683 .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1684 .await
1685 .unwrap();
1686
1687 // Join that worktree as client B, and see that a guest has joined as client A.
1688 let worktree_b = Worktree::open_remote(
1689 client_b.clone(),
1690 worktree_id,
1691 lang_registry.clone(),
1692 &mut cx_b.to_async(),
1693 )
1694 .await
1695 .unwrap();
1696
1697 let buffer_a = worktree_a
1698 .update(&mut cx_a, |tree, cx| tree.open_buffer("a.txt", cx))
1699 .await
1700 .unwrap();
1701 let buffer_b = cx_b
1702 .background()
1703 .spawn(worktree_b.update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx)));
1704
1705 task::yield_now().await;
1706 buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "z", cx));
1707
1708 let text = buffer_a.read_with(&cx_a, |buf, _| buf.text());
1709 let buffer_b = buffer_b.await.unwrap();
1710 buffer_b.condition(&cx_b, |buf, _| buf.text() == text).await;
1711 }
1712
1713 #[gpui::test]
1714 async fn test_peer_disconnection(mut cx_a: TestAppContext, cx_b: TestAppContext) {
1715 cx_a.foreground().forbid_parking();
1716 let lang_registry = Arc::new(LanguageRegistry::new());
1717
1718 // Connect to a server as 2 clients.
1719 let mut server = TestServer::start().await;
1720 let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1721 let (client_b, _) = server.create_client(&mut cx_a, "user_b").await;
1722
1723 // Share a local worktree as client A
1724 let fs = Arc::new(FakeFs::new());
1725 fs.insert_tree(
1726 "/a",
1727 json!({
1728 ".zed.toml": r#"collaborators = ["user_b"]"#,
1729 "a.txt": "a-contents",
1730 "b.txt": "b-contents",
1731 }),
1732 )
1733 .await;
1734 let worktree_a = Worktree::open_local(
1735 client_a.clone(),
1736 "/a".as_ref(),
1737 fs,
1738 lang_registry.clone(),
1739 &mut cx_a.to_async(),
1740 )
1741 .await
1742 .unwrap();
1743 worktree_a
1744 .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1745 .await;
1746 let worktree_id = worktree_a
1747 .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1748 .await
1749 .unwrap();
1750
1751 // Join that worktree as client B, and see that a guest has joined as client A.
1752 let _worktree_b = Worktree::open_remote(
1753 client_b.clone(),
1754 worktree_id,
1755 lang_registry.clone(),
1756 &mut cx_b.to_async(),
1757 )
1758 .await
1759 .unwrap();
1760 worktree_a
1761 .condition(&cx_a, |tree, _| tree.peers().len() == 1)
1762 .await;
1763
1764 // Drop client B's connection and ensure client A observes client B leaving the worktree.
1765 client_b.disconnect(&cx_b.to_async()).await.unwrap();
1766 worktree_a
1767 .condition(&cx_a, |tree, _| tree.peers().len() == 0)
1768 .await;
1769 }
1770
1771 #[gpui::test]
1772 async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1773 cx_a.foreground().forbid_parking();
1774
1775 // Connect to a server as 2 clients.
1776 let mut server = TestServer::start().await;
1777 let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1778 let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
1779
1780 // Create an org that includes these 2 users.
1781 let db = &server.app_state.db;
1782 let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1783 db.add_org_member(org_id, current_user_id(&user_store_a), false)
1784 .await
1785 .unwrap();
1786 db.add_org_member(org_id, current_user_id(&user_store_b), false)
1787 .await
1788 .unwrap();
1789
1790 // Create a channel that includes all the users.
1791 let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1792 db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
1793 .await
1794 .unwrap();
1795 db.add_channel_member(channel_id, current_user_id(&user_store_b), false)
1796 .await
1797 .unwrap();
1798 db.create_channel_message(
1799 channel_id,
1800 current_user_id(&user_store_b),
1801 "hello A, it's B.",
1802 OffsetDateTime::now_utc(),
1803 1,
1804 )
1805 .await
1806 .unwrap();
1807
1808 let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1809 channels_a
1810 .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1811 .await;
1812 channels_a.read_with(&cx_a, |list, _| {
1813 assert_eq!(
1814 list.available_channels().unwrap(),
1815 &[ChannelDetails {
1816 id: channel_id.to_proto(),
1817 name: "test-channel".to_string()
1818 }]
1819 )
1820 });
1821 let channel_a = channels_a.update(&mut cx_a, |this, cx| {
1822 this.get_channel(channel_id.to_proto(), cx).unwrap()
1823 });
1824 channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
1825 channel_a
1826 .condition(&cx_a, |channel, _| {
1827 channel_messages(channel)
1828 == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
1829 })
1830 .await;
1831
1832 let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
1833 channels_b
1834 .condition(&mut cx_b, |list, _| list.available_channels().is_some())
1835 .await;
1836 channels_b.read_with(&cx_b, |list, _| {
1837 assert_eq!(
1838 list.available_channels().unwrap(),
1839 &[ChannelDetails {
1840 id: channel_id.to_proto(),
1841 name: "test-channel".to_string()
1842 }]
1843 )
1844 });
1845
1846 let channel_b = channels_b.update(&mut cx_b, |this, cx| {
1847 this.get_channel(channel_id.to_proto(), cx).unwrap()
1848 });
1849 channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
1850 channel_b
1851 .condition(&cx_b, |channel, _| {
1852 channel_messages(channel)
1853 == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
1854 })
1855 .await;
1856
1857 channel_a
1858 .update(&mut cx_a, |channel, cx| {
1859 channel
1860 .send_message("oh, hi B.".to_string(), cx)
1861 .unwrap()
1862 .detach();
1863 let task = channel.send_message("sup".to_string(), cx).unwrap();
1864 assert_eq!(
1865 channel_messages(channel),
1866 &[
1867 ("user_b".to_string(), "hello A, it's B.".to_string(), false),
1868 ("user_a".to_string(), "oh, hi B.".to_string(), true),
1869 ("user_a".to_string(), "sup".to_string(), true)
1870 ]
1871 );
1872 task
1873 })
1874 .await
1875 .unwrap();
1876
1877 channel_b
1878 .condition(&cx_b, |channel, _| {
1879 channel_messages(channel)
1880 == [
1881 ("user_b".to_string(), "hello A, it's B.".to_string(), false),
1882 ("user_a".to_string(), "oh, hi B.".to_string(), false),
1883 ("user_a".to_string(), "sup".to_string(), false),
1884 ]
1885 })
1886 .await;
1887
1888 assert_eq!(
1889 server.state().await.channels[&channel_id]
1890 .connection_ids
1891 .len(),
1892 2
1893 );
1894 cx_b.update(|_| drop(channel_b));
1895 server
1896 .condition(|state| state.channels[&channel_id].connection_ids.len() == 1)
1897 .await;
1898
1899 cx_a.update(|_| drop(channel_a));
1900 server
1901 .condition(|state| !state.channels.contains_key(&channel_id))
1902 .await;
1903 }
1904
1905 #[gpui::test]
1906 async fn test_chat_message_validation(mut cx_a: TestAppContext) {
1907 cx_a.foreground().forbid_parking();
1908
1909 let mut server = TestServer::start().await;
1910 let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1911
1912 let db = &server.app_state.db;
1913 let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1914 let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1915 db.add_org_member(org_id, current_user_id(&user_store_a), false)
1916 .await
1917 .unwrap();
1918 db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
1919 .await
1920 .unwrap();
1921
1922 let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1923 channels_a
1924 .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1925 .await;
1926 let channel_a = channels_a.update(&mut cx_a, |this, cx| {
1927 this.get_channel(channel_id.to_proto(), cx).unwrap()
1928 });
1929
1930 // Messages aren't allowed to be too long.
1931 channel_a
1932 .update(&mut cx_a, |channel, cx| {
1933 let long_body = "this is long.\n".repeat(1024);
1934 channel.send_message(long_body, cx).unwrap()
1935 })
1936 .await
1937 .unwrap_err();
1938
1939 // Messages aren't allowed to be blank.
1940 channel_a.update(&mut cx_a, |channel, cx| {
1941 channel.send_message(String::new(), cx).unwrap_err()
1942 });
1943
1944 // Leading and trailing whitespace are trimmed.
1945 channel_a
1946 .update(&mut cx_a, |channel, cx| {
1947 channel
1948 .send_message("\n surrounded by whitespace \n".to_string(), cx)
1949 .unwrap()
1950 })
1951 .await
1952 .unwrap();
1953 assert_eq!(
1954 db.get_channel_messages(channel_id, 10, None)
1955 .await
1956 .unwrap()
1957 .iter()
1958 .map(|m| &m.body)
1959 .collect::<Vec<_>>(),
1960 &["surrounded by whitespace"]
1961 );
1962 }
1963
1964 #[gpui::test]
1965 async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1966 cx_a.foreground().forbid_parking();
1967 let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) });
1968
1969 // Connect to a server as 2 clients.
1970 let mut server = TestServer::start().await;
1971 let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1972 let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
1973 let mut status_b = client_b.status();
1974
1975 // Create an org that includes these 2 users.
1976 let db = &server.app_state.db;
1977 let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1978 db.add_org_member(org_id, current_user_id(&user_store_a), false)
1979 .await
1980 .unwrap();
1981 db.add_org_member(org_id, current_user_id(&user_store_b), false)
1982 .await
1983 .unwrap();
1984
1985 // Create a channel that includes all the users.
1986 let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1987 db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
1988 .await
1989 .unwrap();
1990 db.add_channel_member(channel_id, current_user_id(&user_store_b), false)
1991 .await
1992 .unwrap();
1993 db.create_channel_message(
1994 channel_id,
1995 current_user_id(&user_store_b),
1996 "hello A, it's B.",
1997 OffsetDateTime::now_utc(),
1998 2,
1999 )
2000 .await
2001 .unwrap();
2002
2003 let user_store_a =
2004 UserStore::new(client_a.clone(), http.clone(), cx_a.background().as_ref());
        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
        channels_a
            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
            .await;

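        // Client A sees the channel in its list and, once loaded, the seeded message from B.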
        channels_a.read_with(&cx_a, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });
        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });
        channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
        channel_a
            .condition(&cx_a, |channel, _| {
                channel_messages(channel)
                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
            })
            .await;

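        // Client B sees the same channel and message.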
        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b.clone(), client_b, cx));
        channels_b
            .condition(&mut cx_b, |list, _| list.available_channels().is_some())
            .await;
        channels_b.read_with(&cx_b, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });

        let channel_b = channels_b.update(&mut cx_b, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });
        channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
            })
            .await;

        // Disconnect client B, ensuring we can still access its cached channel data.
        server.forbid_connections();
        server.disconnect_client(current_user_id(&user_store_b));
        while !matches!(
            status_b.recv().await,
            Some(rpc::Status::ReconnectionError { .. })
        ) {}

        channels_b.read_with(&cx_b, |channels, _| {
            assert_eq!(
                channels.available_channels().unwrap(),
                [ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });
        channel_b.read_with(&cx_b, |channel, _| {
            assert_eq!(
                channel_messages(channel),
                [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
            )
        });

        // Send a message from client B while it is disconnected.
        channel_b
            .update(&mut cx_b, |channel, cx| {
                let task = channel
                    .send_message("can you see this?".to_string(), cx)
                    .unwrap();
                assert_eq!(
                    channel_messages(channel),
                    &[
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_b".to_string(), "can you see this?".to_string(), true)
                    ]
                );
                task
            })
            .await
            .unwrap_err();

        // Send a message from client A while B is disconnected.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                channel
                    .send_message("oh, hi B.".to_string(), cx)
                    .unwrap()
                    .detach();
                let task = channel.send_message("sup".to_string(), cx).unwrap();
                assert_eq!(
                    channel_messages(channel),
                    &[
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), true),
                        ("user_a".to_string(), "sup".to_string(), true)
                    ]
                );
                task
            })
            .await
            .unwrap();

        // Give client B a chance to reconnect.
        server.allow_connections();
        cx_b.foreground().advance_clock(Duration::from_secs(10));

        // Verify that B sees the new messages upon reconnection, as well as the message client B
        // sent while offline.
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
                        ("user_a".to_string(), "sup".to_string(), false),
                        ("user_b".to_string(), "can you see this?".to_string(), false),
                    ]
            })
            .await;

        // Ensure client A and B can communicate normally after reconnection.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                channel.send_message("you online?".to_string(), cx).unwrap()
            })
            .await
            .unwrap();
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
                        ("user_a".to_string(), "sup".to_string(), false),
                        ("user_b".to_string(), "can you see this?".to_string(), false),
                        ("user_a".to_string(), "you online?".to_string(), false),
                    ]
            })
            .await;

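        // B replies, and A receives the reply.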
        channel_b
            .update(&mut cx_b, |channel, cx| {
                channel.send_message("yep".to_string(), cx).unwrap()
            })
            .await
            .unwrap();
        channel_a
            .condition(&cx_a, |channel, _| {
                channel_messages(channel)
                    == [
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
                        ("user_a".to_string(), "sup".to_string(), false),
                        ("user_b".to_string(), "can you see this?".to_string(), false),
                        ("user_a".to_string(), "you online?".to_string(), false),
                        ("user_b".to_string(), "yep".to_string(), false),
                    ]
            })
            .await;
    }

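    /// In-process test harness: runs a real `Server` against a throwaway database and
    /// in-memory RPC connections. Each connection is paired with a "killer" channel so a
    /// test can sever it, and `forbid_connections` blocks reconnection attempts.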
    struct TestServer {
        peer: Arc<Peer>,
        app_state: Arc<AppState>,
        server: Arc<Server>,
        notifications: mpsc::Receiver<()>,
        connection_killers: Arc<Mutex<HashMap<UserId, watch::Sender<Option<()>>>>>,
        forbid_connections: Arc<AtomicBool>,
        _test_db: TestDb,
    }

    impl TestServer {
        async fn start() -> Self {
            let test_db = TestDb::new();
            let app_state = Self::build_app_state(&test_db).await;
            let peer = Peer::new();
            let notifications = mpsc::channel(128);
            let server = Server::new(app_state.clone(), peer.clone(), Some(notifications.0));
            Self {
                peer,
                app_state,
                server,
                notifications: notifications.1,
                connection_killers: Default::default(),
                forbid_connections: Default::default(),
                _test_db: test_db,
            }
        }

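        /// Registers a user in the test database and connects a `Client` for them over an
        /// in-memory connection, returning the client along with an authenticated
        /// `UserStore`.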
        async fn create_client(
            &mut self,
            cx: &mut TestAppContext,
            name: &str,
        ) -> (Arc<Client>, Arc<UserStore>) {
            let user_id = self.app_state.db.create_user(name, false).await.unwrap();
            let client_name = name.to_string();
            let mut client = Client::new();
            let server = self.server.clone();
            let connection_killers = self.connection_killers.clone();
            let forbid_connections = self.forbid_connections.clone();
            Arc::get_mut(&mut client)
                .unwrap()
                .override_authenticate(move |cx| {
                    cx.spawn(|_| async move {
                        let access_token = "the-token".to_string();
                        Ok(Credentials {
                            user_id: user_id.0 as u64,
                            access_token,
                        })
                    })
                })
                .override_establish_connection(move |credentials, cx| {
                    assert_eq!(credentials.user_id, user_id.0 as u64);
                    assert_eq!(credentials.access_token, "the-token");

                    let server = server.clone();
                    let connection_killers = connection_killers.clone();
                    let forbid_connections = forbid_connections.clone();
                    let client_name = client_name.clone();
                    cx.spawn(move |cx| async move {
                        if forbid_connections.load(SeqCst) {
                            Err(EstablishConnectionError::other(anyhow!(
                                "server is forbidding connections"
                            )))
                        } else {
                            let (client_conn, server_conn, kill_conn) = Connection::in_memory();
                            connection_killers.lock().insert(user_id, kill_conn);
                            cx.background()
                                .spawn(server.handle_connection(server_conn, client_name, user_id))
                                .detach();
                            Ok(client_conn)
                        }
                    })
                });

            let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) });
            client
                .authenticate_and_connect(&cx.to_async())
                .await
                .unwrap();

            let user_store = UserStore::new(client.clone(), http, &cx.background());
            let mut authed_user = user_store.watch_current_user();
            while authed_user.recv().await.unwrap().is_none() {}

            (client, user_store)
        }

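        /// Severs the in-memory connection for the given user, simulating a dropped
        /// network connection.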
        fn disconnect_client(&self, user_id: UserId) {
            if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) {
                let _ = kill_conn.try_send(Some(()));
            }
        }

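        /// While connections are forbidden, `override_establish_connection` fails, so
        /// clients cannot reconnect until `allow_connections` is called.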
        fn forbid_connections(&self) {
            self.forbid_connections.store(true, SeqCst);
        }

        fn allow_connections(&self) {
            self.forbid_connections.store(false, SeqCst);
        }

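        /// Builds an `AppState` backed by the test database, with empty auth credentials
        /// and test GitHub clients.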
        async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
            let mut config = Config::default();
            config.session_secret = "a".repeat(32);
            config.database_url = test_db.url.clone();
            let github_client = github::AppClient::test();
            Arc::new(AppState {
                db: test_db.db().clone(),
                handlebars: Default::default(),
                auth_client: auth::build_client("", ""),
                repo_client: github::RepoClient::test(&github_client),
                github_client,
                config,
            })
        }

        async fn state<'a>(&'a self) -> RwLockReadGuard<'a, ServerState> {
            self.server.state.read().await
        }

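        /// Waits (up to 500ms) until `predicate` holds for the server's state, re-checking
        /// whenever the server sends a notification.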
        async fn condition<F>(&mut self, mut predicate: F)
        where
            F: FnMut(&ServerState) -> bool,
        {
            async_std::future::timeout(Duration::from_millis(500), async {
                while !(predicate)(&*self.server.state.read().await) {
                    self.notifications.recv().await;
                }
            })
            .await
            .expect("condition timed out");
        }
    }

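    // Reset the peer when the test server is dropped so connections don't outlive the test.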
    impl Drop for TestServer {
        fn drop(&mut self) {
            task::block_on(self.peer.reset());
        }
    }

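    /// Returns the id of the user currently signed in to the given `UserStore`.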
    fn current_user_id(user_store: &Arc<UserStore>) -> UserId {
        UserId::from_proto(user_store.current_user().unwrap().id)
    }

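    /// Flattens a channel's messages into (sender login, body, is_pending) tuples for
    /// easy assertions.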
    fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> {
        channel
            .messages()
            .cursor::<(), ()>()
            .map(|m| {
                (
                    m.sender.github_login.clone(),
                    m.body.clone(),
                    m.is_pending(),
                )
            })
            .collect()
    }

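    /// A view that renders nothing (its `render` returns an empty element).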
    struct EmptyView;

    impl gpui::Entity for EmptyView {
        type Event = ();
    }

    impl gpui::View for EmptyView {
        fn ui_name() -> &'static str {
            "empty view"
        }

        fn render(&mut self, _: &mut gpui::RenderContext<Self>) -> gpui::ElementBox {
            gpui::Element::boxed(gpui::elements::Empty)
        }
    }
}