1use super::{
2 auth,
3 db::{ChannelId, MessageId, UserId},
4 AppState,
5};
6use anyhow::anyhow;
7use async_std::{sync::RwLock, task};
8use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
9use futures::{future::BoxFuture, FutureExt};
10use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
11use sha1::{Digest as _, Sha1};
12use std::{
13 any::TypeId,
14 collections::{hash_map, HashMap, HashSet},
15 future::Future,
16 mem,
17 sync::Arc,
18 time::Instant,
19};
20use surf::StatusCode;
21use tide::log;
22use tide::{
23 http::headers::{HeaderName, CONNECTION, UPGRADE},
24 Request, Response,
25};
26use time::OffsetDateTime;
27use zrpc::{
28 proto::{self, AnyTypedEnvelope, EnvelopedMessage},
29 Connection, ConnectionId, Peer, TypedEnvelope,
30};
31
/// Replica ids distinguish collaborators editing the same worktree; the host
/// is always replica 0 and guests get the lowest free id
/// (see `ServerState::join_worktree`).
type ReplicaId = u16;

/// Type-erased handler for one RPC message type, stored in `Server::handlers`
/// keyed by the message's `TypeId`. The boxed envelope is downcast back to its
/// concrete `TypedEnvelope<M>` inside the closure built by `Server::add_handler`.
type MessageHandler = Box<
    dyn Send
        + Sync
        + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, tide::Result<()>>,
>;
39
/// The collaboration RPC server.
///
/// Owns the websocket peer, the mutable in-memory session state, and a table
/// mapping each RPC message type to its handler.
pub struct Server {
    /// Low-level connection/message transport shared with the I/O tasks.
    peer: Arc<Peer>,
    /// All mutable session state, guarded by a single async RwLock.
    state: RwLock<ServerState>,
    /// Shared application services (database access, etc.).
    app_state: Arc<AppState>,
    /// Message `TypeId` -> type-erased handler, populated in `Server::new`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// When present, receives a unit after each handled message — presumably
    /// so tests can await message processing; confirm at call sites.
    notifications: Option<mpsc::Sender<()>>,
}
47
/// All mutable server state, kept behind `Server::state`.
#[derive(Default)]
struct ServerState {
    /// Per-connection bookkeeping for every live connection.
    connections: HashMap<ConnectionId, ConnectionState>,
    /// Reverse index: all live connections belonging to each user.
    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
    /// All open worktrees, keyed by server-assigned id.
    /// NOTE(review): `pub` field on a private struct — presumably accessed by
    /// tests elsewhere in this file; verify before changing visibility.
    pub worktrees: HashMap<u64, Worktree>,
    /// Worktree ids visible to each user (as a collaborator).
    visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
    /// Chat channels that currently have at least one joined connection.
    channels: HashMap<ChannelId, Channel>,
    /// Monotonic source of worktree ids (see `add_worktree`).
    next_worktree_id: u64,
}
57
/// Server-side bookkeeping for one client connection.
struct ConnectionState {
    /// The authenticated user this connection belongs to.
    user_id: UserId,
    /// Worktrees this connection participates in.
    worktrees: HashSet<u64>,
    /// Channels this connection has joined.
    channels: HashSet<ChannelId>,
}
63
/// A worktree registered with the server.
struct Worktree {
    /// Connection of the user hosting (owning) the worktree.
    host_connection_id: ConnectionId,
    /// Users invited to collaborate (the host is excluded; see `open_worktree`).
    collaborator_user_ids: Vec<UserId>,
    /// Name of the worktree's root entry, for display to collaborators.
    root_name: String,
    /// Present while the worktree is shared; `None` once unshared.
    share: Option<WorktreeShare>,
}
70
/// Sharing state of a worktree: guests, their replica ids, and the entry index.
struct WorktreeShare {
    /// Replica id assigned to each guest connection (the host is replica 0).
    guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
    /// Replica ids currently in use, so freed ids can be reassigned.
    active_replica_ids: HashSet<ReplicaId>,
    /// Latest known worktree entries, keyed by entry id.
    entries: HashMap<u64, proto::Entry>,
}
76
/// A chat channel's live membership (persistent data lives in the database).
#[derive(Default)]
struct Channel {
    /// Connections currently joined to the channel.
    connection_ids: HashSet<ConnectionId>,
}
81
/// Number of chat messages returned per page by join/history requests.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum accepted chat message length in bytes (checked after trimming).
const MAX_MESSAGE_LEN: usize = 1024;
84
85impl Server {
86 pub fn new(
87 app_state: Arc<AppState>,
88 peer: Arc<Peer>,
89 notifications: Option<mpsc::Sender<()>>,
90 ) -> Arc<Self> {
91 let mut server = Self {
92 peer,
93 app_state,
94 state: Default::default(),
95 handlers: Default::default(),
96 notifications,
97 };
98
99 server
100 .add_handler(Server::ping)
101 .add_handler(Server::open_worktree)
102 .add_handler(Server::handle_close_worktree)
103 .add_handler(Server::share_worktree)
104 .add_handler(Server::unshare_worktree)
105 .add_handler(Server::join_worktree)
106 .add_handler(Server::update_worktree)
107 .add_handler(Server::open_buffer)
108 .add_handler(Server::close_buffer)
109 .add_handler(Server::update_buffer)
110 .add_handler(Server::buffer_saved)
111 .add_handler(Server::save_buffer)
112 .add_handler(Server::get_channels)
113 .add_handler(Server::get_users)
114 .add_handler(Server::join_channel)
115 .add_handler(Server::leave_channel)
116 .add_handler(Server::send_channel_message)
117 .add_handler(Server::get_channel_messages);
118
119 Arc::new(server)
120 }
121
122 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
123 where
124 F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
125 Fut: 'static + Send + Future<Output = tide::Result<()>>,
126 M: EnvelopedMessage,
127 {
128 let prev_handler = self.handlers.insert(
129 TypeId::of::<M>(),
130 Box::new(move |server, envelope| {
131 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
132 (handler)(server, *envelope).boxed()
133 }),
134 );
135 if prev_handler.is_some() {
136 panic!("registered a handler for the same message twice");
137 }
138 self
139 }
140
    /// Drive a single client connection: register it with the peer and the
    /// server state, then loop dispatching incoming messages to the handler
    /// table until the connection closes or its I/O task finishes, and
    /// finally sign the connection out.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        addr: String,
        user_id: UserId,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        async move {
            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection).await;
            this.add_connection(connection_id, user_id).await;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);
            loop {
                // A fresh recv future per turn; select_biased polls it before
                // the I/O task, favoring message processing.
                let next_message = incoming_rx.recv().fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    message = next_message => {
                        if let Some(message) = message {
                            let start_time = Instant::now();
                            log::info!("RPC message received: {}", message.payload_type_name());
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                // Handler errors are logged, not fatal to the connection.
                                if let Err(err) = (handler)(this.clone(), message).await {
                                    log::error!("error handling message: {:?}", err);
                                } else {
                                    log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
                                }

                                // Presumably lets tests observe that a message
                                // finished processing — confirm at call sites.
                                if let Some(mut notifications) = this.notifications.clone() {
                                    let _ = notifications.send(()).await;
                                }
                            } else {
                                log::warn!("unhandled message: {}", message.payload_type_name());
                            }
                        } else {
                            // The incoming channel closed: connection is gone.
                            log::info!("rpc connection closed {:?}", addr);
                            break;
                        }
                    }
                    handle_io = handle_io => {
                        if let Err(err) = handle_io {
                            log::error!("error handling rpc connection {:?} - {:?}", addr, err);
                        }
                        break;
                    }
                }
            }

            // Always clean up state for this connection, even on I/O error.
            if let Err(err) = this.sign_out(connection_id).await {
                log::error!("error signing out connection {:?} - {:?}", addr, err);
            }
        }
    }
195
196 async fn sign_out(self: &Arc<Self>, connection_id: zrpc::ConnectionId) -> tide::Result<()> {
197 self.peer.disconnect(connection_id).await;
198 self.remove_connection(connection_id).await?;
199 Ok(())
200 }
201
202 // Add a new connection associated with a given user.
203 async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
204 let mut state = self.state.write().await;
205 state.connections.insert(
206 connection_id,
207 ConnectionState {
208 user_id,
209 worktrees: Default::default(),
210 channels: Default::default(),
211 },
212 );
213 state
214 .connections_by_user_id
215 .entry(user_id)
216 .or_default()
217 .insert(connection_id);
218 }
219
220 // Remove the given connection and its association with any worktrees.
221 async fn remove_connection(
222 self: &Arc<Server>,
223 connection_id: ConnectionId,
224 ) -> tide::Result<()> {
225 let mut worktree_ids = Vec::new();
226 let mut state = self.state.write().await;
227 if let Some(connection) = state.connections.remove(&connection_id) {
228 worktree_ids = connection.worktrees.into_iter().collect();
229
230 for channel_id in connection.channels {
231 if let Some(channel) = state.channels.get_mut(&channel_id) {
232 channel.connection_ids.remove(&connection_id);
233 }
234 }
235
236 let user_connections = state
237 .connections_by_user_id
238 .get_mut(&connection.user_id)
239 .unwrap();
240 user_connections.remove(&connection_id);
241 if user_connections.is_empty() {
242 state.connections_by_user_id.remove(&connection.user_id);
243 }
244 }
245
246 for worktree_id in worktree_ids {
247 self.close_worktree(worktree_id, connection_id).await?;
248 }
249
250 Ok(())
251 }
252
253 async fn ping(self: Arc<Server>, request: TypedEnvelope<proto::Ping>) -> tide::Result<()> {
254 self.peer.respond(request.receipt(), proto::Ack {}).await?;
255 Ok(())
256 }
257
    /// Create a new, initially-unshared worktree hosted by the sender and
    /// reply with its server-assigned id.
    ///
    /// Collaborator GitHub logins are resolved to user ids up front, creating
    /// user rows as needed; the host is excluded from the collaborator list
    /// even if named. Any db error is reported back to the requester and
    /// aborts the whole request without creating the worktree.
    async fn open_worktree(
        self: Arc<Server>,
        request: TypedEnvelope<proto::OpenWorktree>,
    ) -> tide::Result<()> {
        let receipt = request.receipt();
        let host_user_id = self
            .state
            .read()
            .await
            .user_id_for_connection(request.sender_id)?;

        let mut collaborator_user_ids = Vec::new();
        for github_login in request.payload.collaborator_logins {
            match self.app_state.db.create_user(&github_login, false).await {
                Ok(collaborator_user_id) => {
                    // Don't list the host as a collaborator of their own worktree.
                    if collaborator_user_id != host_user_id {
                        collaborator_user_ids.push(collaborator_user_id);
                    }
                }
                Err(err) => {
                    let message = err.to_string();
                    self.peer
                        .respond_with_error(receipt, proto::Error { message })
                        .await?;
                    return Ok(());
                }
            }
        }

        let worktree_id;
        let mut user_ids;
        {
            // Scope the write lock so it is not held across the sends below.
            let mut state = self.state.write().await;
            worktree_id = state.add_worktree(Worktree {
                host_connection_id: request.sender_id,
                collaborator_user_ids: collaborator_user_ids.clone(),
                root_name: request.payload.root_name,
                share: None,
            });
            user_ids = collaborator_user_ids;
            user_ids.push(host_user_id);
        }

        self.peer
            .respond(receipt, proto::OpenWorktreeResponse { worktree_id })
            .await?;
        self.update_collaborators_for_users(&user_ids).await?;

        Ok(())
    }
308
309 async fn share_worktree(
310 self: Arc<Server>,
311 mut request: TypedEnvelope<proto::ShareWorktree>,
312 ) -> tide::Result<()> {
313 let host_user_id = self
314 .state
315 .read()
316 .await
317 .user_id_for_connection(request.sender_id)?;
318 let worktree = request
319 .payload
320 .worktree
321 .as_mut()
322 .ok_or_else(|| anyhow!("missing worktree"))?;
323 let entries = mem::take(&mut worktree.entries)
324 .into_iter()
325 .map(|entry| (entry.id, entry))
326 .collect();
327
328 let mut state = self.state.write().await;
329 if let Some(worktree) = state.worktrees.get_mut(&worktree.id) {
330 worktree.share = Some(WorktreeShare {
331 guest_connection_ids: Default::default(),
332 active_replica_ids: Default::default(),
333 entries,
334 });
335
336 let mut user_ids = worktree.collaborator_user_ids.clone();
337 user_ids.push(host_user_id);
338
339 drop(state);
340 self.peer
341 .respond(request.receipt(), proto::ShareWorktreeResponse {})
342 .await?;
343 self.update_collaborators_for_users(&user_ids).await?;
344 } else {
345 self.peer
346 .respond_with_error(
347 request.receipt(),
348 proto::Error {
349 message: "no such worktree".to_string(),
350 },
351 )
352 .await?;
353 }
354 Ok(())
355 }
356
    /// Stop sharing a worktree. Only the host may unshare; guests are
    /// detached and notified with `UnshareWorktree`, then collaborator lists
    /// are refreshed.
    async fn unshare_worktree(
        self: Arc<Server>,
        request: TypedEnvelope<proto::UnshareWorktree>,
    ) -> tide::Result<()> {
        let worktree_id = request.payload.worktree_id;
        let host_user_id = self
            .state
            .read()
            .await
            .user_id_for_connection(request.sender_id)?;

        let connection_ids;
        let mut user_ids;
        {
            // Lock scope: collect everything needed, then release before sending.
            let mut state = self.state.write().await;
            let worktree = state.write_worktree(worktree_id, request.sender_id)?;
            // write_worktree admits guests too, so explicitly require the host.
            if worktree.host_connection_id != request.sender_id {
                return Err(anyhow!("no such worktree"))?;
            }

            connection_ids = worktree.connection_ids();
            user_ids = worktree.collaborator_user_ids.clone();
            user_ids.push(host_user_id);
            worktree.share.take();
            // NOTE(review): `connection_ids` includes the host, so this also
            // removes the worktree from the host's connection state; a later
            // host disconnect then won't close the now-unshared worktree —
            // confirm this is intended.
            for connection_id in &connection_ids {
                if let Some(connection) = state.connections.get_mut(connection_id) {
                    connection.worktrees.remove(&worktree_id);
                }
            }
        }

        broadcast(request.sender_id, connection_ids, |conn_id| {
            self.peer
                .send(conn_id, proto::UnshareWorktree { worktree_id })
        })
        .await?;
        self.update_collaborators_for_users(&user_ids).await?;

        Ok(())
    }
397
398 async fn join_worktree(
399 self: Arc<Server>,
400 request: TypedEnvelope<proto::JoinWorktree>,
401 ) -> tide::Result<()> {
402 let worktree_id = request.payload.worktree_id;
403 let user_id = self
404 .state
405 .read()
406 .await
407 .user_id_for_connection(request.sender_id)?;
408
409 let response;
410 let connection_ids;
411 let mut user_ids;
412 let mut state = self.state.write().await;
413 match state.join_worktree(request.sender_id, user_id, worktree_id) {
414 Ok((peer_replica_id, worktree)) => {
415 let share = worktree.share()?;
416 let peer_count = share.guest_connection_ids.len();
417 let mut peers = Vec::with_capacity(peer_count);
418 peers.push(proto::Peer {
419 peer_id: worktree.host_connection_id.0,
420 replica_id: 0,
421 });
422 for (peer_conn_id, peer_replica_id) in &share.guest_connection_ids {
423 if *peer_conn_id != request.sender_id {
424 peers.push(proto::Peer {
425 peer_id: peer_conn_id.0,
426 replica_id: *peer_replica_id as u32,
427 });
428 }
429 }
430 response = proto::JoinWorktreeResponse {
431 worktree: Some(proto::Worktree {
432 id: worktree_id,
433 root_name: worktree.root_name.clone(),
434 entries: share.entries.values().cloned().collect(),
435 }),
436 replica_id: peer_replica_id as u32,
437 peers,
438 };
439
440 let host_connection_id = worktree.host_connection_id;
441 connection_ids = worktree.connection_ids();
442 user_ids = worktree.collaborator_user_ids.clone();
443 user_ids.push(state.user_id_for_connection(host_connection_id)?);
444 }
445 Err(error) => {
446 self.peer
447 .respond_with_error(
448 request.receipt(),
449 proto::Error {
450 message: error.to_string(),
451 },
452 )
453 .await?;
454 return Ok(());
455 }
456 }
457
458 broadcast(request.sender_id, connection_ids, |conn_id| {
459 self.peer.send(
460 conn_id,
461 proto::AddPeer {
462 worktree_id,
463 peer: Some(proto::Peer {
464 peer_id: request.sender_id.0,
465 replica_id: response.replica_id,
466 }),
467 },
468 )
469 })
470 .await?;
471 self.peer.respond(request.receipt(), response).await?;
472 self.update_collaborators_for_users(&user_ids).await?;
473
474 Ok(())
475 }
476
477 async fn handle_close_worktree(
478 self: Arc<Server>,
479 request: TypedEnvelope<proto::CloseWorktree>,
480 ) -> tide::Result<()> {
481 self.close_worktree(request.payload.worktree_id, request.sender_id)
482 .await
483 }
484
485 async fn close_worktree(
486 self: &Arc<Server>,
487 worktree_id: u64,
488 conn_id: ConnectionId,
489 ) -> tide::Result<()> {
490 let connection_ids;
491 let mut user_ids;
492
493 let mut is_host = false;
494 let mut is_guest = false;
495 {
496 let mut state = self.state.write().await;
497 let worktree = state.write_worktree(worktree_id, conn_id)?;
498 let host_connection_id = worktree.host_connection_id;
499 connection_ids = worktree.connection_ids();
500 user_ids = worktree.collaborator_user_ids.clone();
501
502 if worktree.host_connection_id == conn_id {
503 is_host = true;
504 state.remove_worktree(worktree_id);
505 } else {
506 let share = worktree.share_mut()?;
507 if let Some(replica_id) = share.guest_connection_ids.remove(&conn_id) {
508 is_guest = true;
509 share.active_replica_ids.remove(&replica_id);
510 }
511 }
512
513 user_ids.push(state.user_id_for_connection(host_connection_id)?);
514 }
515
516 if is_host {
517 broadcast(conn_id, connection_ids, |conn_id| {
518 self.peer
519 .send(conn_id, proto::UnshareWorktree { worktree_id })
520 })
521 .await?;
522 } else if is_guest {
523 broadcast(conn_id, connection_ids, |conn_id| {
524 self.peer.send(
525 conn_id,
526 proto::RemovePeer {
527 worktree_id,
528 peer_id: conn_id.0,
529 },
530 )
531 })
532 .await?
533 }
534 self.update_collaborators_for_users(&user_ids).await?;
535 Ok(())
536 }
537
538 async fn update_worktree(
539 self: Arc<Server>,
540 request: TypedEnvelope<proto::UpdateWorktree>,
541 ) -> tide::Result<()> {
542 {
543 let mut state = self.state.write().await;
544 let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
545 let share = worktree.share_mut()?;
546
547 for entry_id in &request.payload.removed_entries {
548 share.entries.remove(&entry_id);
549 }
550
551 for entry in &request.payload.updated_entries {
552 share.entries.insert(entry.id, entry.clone());
553 }
554 }
555
556 self.broadcast_in_worktree(request.payload.worktree_id, &request)
557 .await?;
558 Ok(())
559 }
560
561 async fn open_buffer(
562 self: Arc<Server>,
563 request: TypedEnvelope<proto::OpenBuffer>,
564 ) -> tide::Result<()> {
565 let receipt = request.receipt();
566 let worktree_id = request.payload.worktree_id;
567 let host_connection_id = self
568 .state
569 .read()
570 .await
571 .read_worktree(worktree_id, request.sender_id)?
572 .host_connection_id;
573
574 let response = self
575 .peer
576 .forward_request(request.sender_id, host_connection_id, request.payload)
577 .await?;
578 self.peer.respond(receipt, response).await?;
579 Ok(())
580 }
581
582 async fn close_buffer(
583 self: Arc<Server>,
584 request: TypedEnvelope<proto::CloseBuffer>,
585 ) -> tide::Result<()> {
586 let host_connection_id = self
587 .state
588 .read()
589 .await
590 .read_worktree(request.payload.worktree_id, request.sender_id)?
591 .host_connection_id;
592
593 self.peer
594 .forward_send(request.sender_id, host_connection_id, request.payload)
595 .await?;
596
597 Ok(())
598 }
599
    /// Forward a save request to the worktree host, then fan the host's
    /// response out: the original requester receives it as the RPC response,
    /// every other guest receives it as a message forwarded from the host.
    async fn save_buffer(
        self: Arc<Server>,
        request: TypedEnvelope<proto::SaveBuffer>,
    ) -> tide::Result<()> {
        let host;
        let guests;
        {
            // Read-lock scope: just capture the host and guest connection ids.
            let state = self.state.read().await;
            let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
            host = worktree.host_connection_id;
            guests = worktree
                .share()?
                .guest_connection_ids
                .keys()
                .copied()
                .collect::<Vec<_>>();
        }

        let sender = request.sender_id;
        let receipt = request.receipt();
        let response = self
            .peer
            .forward_request(sender, host, request.payload.clone())
            .await?;

        // `broadcast` skips `host` (passed as the sender), so every guest —
        // including the requester — is reached exactly once.
        broadcast(host, guests, |conn_id| {
            let response = response.clone();
            let peer = &self.peer;
            async move {
                if conn_id == sender {
                    // The requester gets the host's answer as its RPC reply.
                    peer.respond(receipt, response).await
                } else {
                    // Everyone else sees it as a message from the host.
                    peer.forward_send(host, conn_id, response).await
                }
            }
        })
        .await?;

        Ok(())
    }
640
641 async fn update_buffer(
642 self: Arc<Server>,
643 request: TypedEnvelope<proto::UpdateBuffer>,
644 ) -> tide::Result<()> {
645 self.broadcast_in_worktree(request.payload.worktree_id, &request)
646 .await?;
647 self.peer.respond(request.receipt(), proto::Ack {}).await?;
648 Ok(())
649 }
650
651 async fn buffer_saved(
652 self: Arc<Server>,
653 request: TypedEnvelope<proto::BufferSaved>,
654 ) -> tide::Result<()> {
655 self.broadcast_in_worktree(request.payload.worktree_id, &request)
656 .await
657 }
658
659 async fn get_channels(
660 self: Arc<Server>,
661 request: TypedEnvelope<proto::GetChannels>,
662 ) -> tide::Result<()> {
663 let user_id = self
664 .state
665 .read()
666 .await
667 .user_id_for_connection(request.sender_id)?;
668 let channels = self.app_state.db.get_accessible_channels(user_id).await?;
669 self.peer
670 .respond(
671 request.receipt(),
672 proto::GetChannelsResponse {
673 channels: channels
674 .into_iter()
675 .map(|chan| proto::Channel {
676 id: chan.id.to_proto(),
677 name: chan.name,
678 })
679 .collect(),
680 },
681 )
682 .await?;
683 Ok(())
684 }
685
686 async fn get_users(
687 self: Arc<Server>,
688 request: TypedEnvelope<proto::GetUsers>,
689 ) -> tide::Result<()> {
690 let user_id = self
691 .state
692 .read()
693 .await
694 .user_id_for_connection(request.sender_id)?;
695 let receipt = request.receipt();
696 let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
697 let users = self
698 .app_state
699 .db
700 .get_users_by_ids(user_id, user_ids)
701 .await?
702 .into_iter()
703 .map(|user| proto::User {
704 id: user.id.to_proto(),
705 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
706 github_login: user.github_login,
707 })
708 .collect();
709 self.peer
710 .respond(receipt, proto::GetUsersResponse { users })
711 .await?;
712 Ok(())
713 }
714
715 async fn update_collaborators_for_users<'a>(
716 self: &Arc<Server>,
717 user_ids: impl IntoIterator<Item = &'a UserId>,
718 ) -> tide::Result<()> {
719 let mut send_futures = Vec::new();
720
721 let state = self.state.read().await;
722 for user_id in user_ids {
723 let mut collaborators = HashMap::new();
724 for worktree_id in state
725 .visible_worktrees_by_user_id
726 .get(&user_id)
727 .unwrap_or(&HashSet::new())
728 {
729 let worktree = &state.worktrees[worktree_id];
730
731 let mut participants = HashSet::new();
732 if let Ok(share) = worktree.share() {
733 for guest_connection_id in share.guest_connection_ids.keys() {
734 let user_id = state.user_id_for_connection(*guest_connection_id)?;
735 participants.insert(user_id.to_proto());
736 }
737 }
738
739 let host_user_id = state.user_id_for_connection(worktree.host_connection_id)?;
740 let host =
741 collaborators
742 .entry(host_user_id)
743 .or_insert_with(|| proto::Collaborator {
744 user_id: host_user_id.to_proto(),
745 worktrees: Vec::new(),
746 });
747 host.worktrees.push(proto::WorktreeMetadata {
748 root_name: worktree.root_name.clone(),
749 is_shared: worktree.share().is_ok(),
750 participants: participants.into_iter().collect(),
751 });
752 }
753
754 let collaborators = collaborators.into_values().collect::<Vec<_>>();
755 for connection_id in state.user_connection_ids(*user_id) {
756 send_futures.push(self.peer.send(
757 connection_id,
758 proto::UpdateCollaborators {
759 collaborators: collaborators.clone(),
760 },
761 ));
762 }
763 }
764
765 drop(state);
766 futures::future::try_join_all(send_futures).await?;
767
768 Ok(())
769 }
770
    /// Join a chat channel after a db-level access check, replying with the
    /// most recent `MESSAGE_COUNT_PER_PAGE` messages of its history.
    async fn join_channel(
        self: Arc<Self>,
        request: TypedEnvelope<proto::JoinChannel>,
    ) -> tide::Result<()> {
        let user_id = self
            .state
            .read()
            .await
            .user_id_for_connection(request.sender_id)?;
        let channel_id = ChannelId::from_proto(request.payload.channel_id);
        if !self
            .app_state
            .db
            .can_user_access_channel(user_id, channel_id)
            .await?
        {
            Err(anyhow!("access denied"))?;
        }

        self.state
            .write()
            .await
            .join_channel(request.sender_id, channel_id);
        let messages = self
            .app_state
            .db
            .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
            .await?
            .into_iter()
            .map(|msg| proto::ChannelMessage {
                id: msg.id.to_proto(),
                body: msg.body,
                timestamp: msg.sent_at.unix_timestamp() as u64,
                sender_id: msg.sender_id.to_proto(),
                nonce: Some(msg.nonce.as_u128().into()),
            })
            .collect::<Vec<_>>();
        self.peer
            .respond(
                request.receipt(),
                proto::JoinChannelResponse {
                    // A short page means there is no older history to fetch.
                    done: messages.len() < MESSAGE_COUNT_PER_PAGE,
                    messages,
                },
            )
            .await?;
        Ok(())
    }
819
820 async fn leave_channel(
821 self: Arc<Self>,
822 request: TypedEnvelope<proto::LeaveChannel>,
823 ) -> tide::Result<()> {
824 let user_id = self
825 .state
826 .read()
827 .await
828 .user_id_for_connection(request.sender_id)?;
829 let channel_id = ChannelId::from_proto(request.payload.channel_id);
830 if !self
831 .app_state
832 .db
833 .can_user_access_channel(user_id, channel_id)
834 .await?
835 {
836 Err(anyhow!("access denied"))?;
837 }
838
839 self.state
840 .write()
841 .await
842 .leave_channel(request.sender_id, channel_id);
843
844 Ok(())
845 }
846
    /// Validate, persist, and fan out a chat message.
    ///
    /// The body is trimmed and must be non-empty and at most
    /// `MAX_MESSAGE_LEN` bytes, and a client-supplied nonce is required
    /// (presumably for de-duplication — confirm against the client). The
    /// stored message is broadcast to the channel's other connections and
    /// echoed back to the sender in the RPC response.
    ///
    /// NOTE(review): unlike join/leave there is no db access check here;
    /// the only gate is the channel existing in the in-memory map — confirm
    /// the sender is guaranteed to have joined first.
    async fn send_channel_message(
        self: Arc<Self>,
        request: TypedEnvelope<proto::SendChannelMessage>,
    ) -> tide::Result<()> {
        let receipt = request.receipt();
        let channel_id = ChannelId::from_proto(request.payload.channel_id);
        let user_id;
        let connection_ids;
        {
            // Read-lock scope: resolve the sender and snapshot the audience.
            let state = self.state.read().await;
            user_id = state.user_id_for_connection(request.sender_id)?;
            if let Some(channel) = state.channels.get(&channel_id) {
                connection_ids = channel.connection_ids();
            } else {
                // Nobody is joined to this channel; silently drop the message.
                return Ok(());
            }
        }

        // Validate the message body.
        let body = request.payload.body.trim().to_string();
        if body.len() > MAX_MESSAGE_LEN {
            self.peer
                .respond_with_error(
                    receipt,
                    proto::Error {
                        message: "message is too long".to_string(),
                    },
                )
                .await?;
            return Ok(());
        }
        if body.is_empty() {
            self.peer
                .respond_with_error(
                    receipt,
                    proto::Error {
                        message: "message can't be blank".to_string(),
                    },
                )
                .await?;
            return Ok(());
        }

        let timestamp = OffsetDateTime::now_utc();
        let nonce = if let Some(nonce) = request.payload.nonce {
            nonce
        } else {
            self.peer
                .respond_with_error(
                    receipt,
                    proto::Error {
                        message: "nonce can't be blank".to_string(),
                    },
                )
                .await?;
            return Ok(());
        };

        // Persist first so the broadcast and the response carry the stored id.
        let message_id = self
            .app_state
            .db
            .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
            .await?
            .to_proto();
        let message = proto::ChannelMessage {
            sender_id: user_id.to_proto(),
            id: message_id,
            body,
            timestamp: timestamp.unix_timestamp() as u64,
            nonce: Some(nonce),
        };
        broadcast(request.sender_id, connection_ids, |conn_id| {
            self.peer.send(
                conn_id,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        })
        .await?;
        self.peer
            .respond(
                receipt,
                proto::SendChannelMessageResponse {
                    message: Some(message),
                },
            )
            .await?;
        Ok(())
    }
938
    /// Return the page of channel history strictly before
    /// `before_message_id`, after the same db-level access check as joining.
    async fn get_channel_messages(
        self: Arc<Self>,
        request: TypedEnvelope<proto::GetChannelMessages>,
    ) -> tide::Result<()> {
        let user_id = self
            .state
            .read()
            .await
            .user_id_for_connection(request.sender_id)?;
        let channel_id = ChannelId::from_proto(request.payload.channel_id);
        if !self
            .app_state
            .db
            .can_user_access_channel(user_id, channel_id)
            .await?
        {
            Err(anyhow!("access denied"))?;
        }

        let messages = self
            .app_state
            .db
            .get_channel_messages(
                channel_id,
                MESSAGE_COUNT_PER_PAGE,
                Some(MessageId::from_proto(request.payload.before_message_id)),
            )
            .await?
            .into_iter()
            .map(|msg| proto::ChannelMessage {
                id: msg.id.to_proto(),
                body: msg.body,
                timestamp: msg.sent_at.unix_timestamp() as u64,
                sender_id: msg.sender_id.to_proto(),
                nonce: Some(msg.nonce.as_u128().into()),
            })
            .collect::<Vec<_>>();
        self.peer
            .respond(
                request.receipt(),
                proto::GetChannelMessagesResponse {
                    // A short page means there is no older history to fetch.
                    done: messages.len() < MESSAGE_COUNT_PER_PAGE,
                    messages,
                },
            )
            .await?;
        Ok(())
    }
987
988 async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
989 &self,
990 worktree_id: u64,
991 message: &TypedEnvelope<T>,
992 ) -> tide::Result<()> {
993 let connection_ids = self
994 .state
995 .read()
996 .await
997 .read_worktree(worktree_id, message.sender_id)?
998 .connection_ids();
999
1000 broadcast(message.sender_id, connection_ids, |conn_id| {
1001 self.peer
1002 .forward_send(message.sender_id, conn_id, message.payload.clone())
1003 })
1004 .await?;
1005
1006 Ok(())
1007 }
1008}
1009
1010pub async fn broadcast<F, T>(
1011 sender_id: ConnectionId,
1012 receiver_ids: Vec<ConnectionId>,
1013 mut f: F,
1014) -> anyhow::Result<()>
1015where
1016 F: FnMut(ConnectionId) -> T,
1017 T: Future<Output = anyhow::Result<()>>,
1018{
1019 let futures = receiver_ids
1020 .into_iter()
1021 .filter(|id| *id != sender_id)
1022 .map(|id| f(id));
1023 futures::future::try_join_all(futures).await?;
1024 Ok(())
1025}
1026
1027impl ServerState {
1028 fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
1029 if let Some(connection) = self.connections.get_mut(&connection_id) {
1030 connection.channels.insert(channel_id);
1031 self.channels
1032 .entry(channel_id)
1033 .or_default()
1034 .connection_ids
1035 .insert(connection_id);
1036 }
1037 }
1038
1039 fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
1040 if let Some(connection) = self.connections.get_mut(&connection_id) {
1041 connection.channels.remove(&channel_id);
1042 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
1043 entry.get_mut().connection_ids.remove(&connection_id);
1044 if entry.get_mut().connection_ids.is_empty() {
1045 entry.remove();
1046 }
1047 }
1048 }
1049 }
1050
1051 fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
1052 Ok(self
1053 .connections
1054 .get(&connection_id)
1055 .ok_or_else(|| anyhow!("unknown connection"))?
1056 .user_id)
1057 }
1058
1059 fn user_connection_ids<'a>(
1060 &'a self,
1061 user_id: UserId,
1062 ) -> impl 'a + Iterator<Item = ConnectionId> {
1063 self.connections_by_user_id
1064 .get(&user_id)
1065 .into_iter()
1066 .flatten()
1067 .copied()
1068 }
1069
1070 // Add the given connection as a guest of the given worktree
1071 fn join_worktree(
1072 &mut self,
1073 connection_id: ConnectionId,
1074 user_id: UserId,
1075 worktree_id: u64,
1076 ) -> tide::Result<(ReplicaId, &Worktree)> {
1077 let connection = self
1078 .connections
1079 .get_mut(&connection_id)
1080 .ok_or_else(|| anyhow!("no such connection"))?;
1081 let worktree = self
1082 .worktrees
1083 .get_mut(&worktree_id)
1084 .ok_or_else(|| anyhow!("no such worktree"))?;
1085 if !worktree.collaborator_user_ids.contains(&user_id) {
1086 Err(anyhow!("no such worktree"))?;
1087 }
1088
1089 let share = worktree.share_mut()?;
1090 connection.worktrees.insert(worktree_id);
1091
1092 let mut replica_id = 1;
1093 while share.active_replica_ids.contains(&replica_id) {
1094 replica_id += 1;
1095 }
1096 share.active_replica_ids.insert(replica_id);
1097 share.guest_connection_ids.insert(connection_id, replica_id);
1098 return Ok((replica_id, worktree));
1099 }
1100
1101 fn read_worktree(
1102 &self,
1103 worktree_id: u64,
1104 connection_id: ConnectionId,
1105 ) -> tide::Result<&Worktree> {
1106 let worktree = self
1107 .worktrees
1108 .get(&worktree_id)
1109 .ok_or_else(|| anyhow!("worktree not found"))?;
1110
1111 if worktree.host_connection_id == connection_id
1112 || worktree
1113 .share()?
1114 .guest_connection_ids
1115 .contains_key(&connection_id)
1116 {
1117 Ok(worktree)
1118 } else {
1119 Err(anyhow!(
1120 "{} is not a member of worktree {}",
1121 connection_id,
1122 worktree_id
1123 ))?
1124 }
1125 }
1126
1127 fn write_worktree(
1128 &mut self,
1129 worktree_id: u64,
1130 connection_id: ConnectionId,
1131 ) -> tide::Result<&mut Worktree> {
1132 let worktree = self
1133 .worktrees
1134 .get_mut(&worktree_id)
1135 .ok_or_else(|| anyhow!("worktree not found"))?;
1136
1137 if worktree.host_connection_id == connection_id
1138 || worktree.share.as_ref().map_or(false, |share| {
1139 share.guest_connection_ids.contains_key(&connection_id)
1140 })
1141 {
1142 Ok(worktree)
1143 } else {
1144 Err(anyhow!(
1145 "{} is not a member of worktree {}",
1146 connection_id,
1147 worktree_id
1148 ))?
1149 }
1150 }
1151
1152 fn add_worktree(&mut self, worktree: Worktree) -> u64 {
1153 let worktree_id = self.next_worktree_id;
1154 for collaborator_user_id in &worktree.collaborator_user_ids {
1155 self.visible_worktrees_by_user_id
1156 .entry(*collaborator_user_id)
1157 .or_default()
1158 .insert(worktree_id);
1159 }
1160 self.next_worktree_id += 1;
1161 self.worktrees.insert(worktree_id, worktree);
1162 worktree_id
1163 }
1164
1165 fn remove_worktree(&mut self, worktree_id: u64) {
1166 let worktree = self.worktrees.remove(&worktree_id).unwrap();
1167 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
1168 connection.worktrees.remove(&worktree_id);
1169 }
1170 if let Some(share) = worktree.share {
1171 for connection_id in share.guest_connection_ids.keys() {
1172 if let Some(connection) = self.connections.get_mut(connection_id) {
1173 connection.worktrees.remove(&worktree_id);
1174 }
1175 }
1176 }
1177 for collaborator_user_id in worktree.collaborator_user_ids {
1178 if let Some(visible_worktrees) = self
1179 .visible_worktrees_by_user_id
1180 .get_mut(&collaborator_user_id)
1181 {
1182 visible_worktrees.remove(&worktree_id);
1183 }
1184 }
1185 }
1186}
1187
1188impl Worktree {
1189 pub fn connection_ids(&self) -> Vec<ConnectionId> {
1190 if let Some(share) = &self.share {
1191 share
1192 .guest_connection_ids
1193 .keys()
1194 .copied()
1195 .chain(Some(self.host_connection_id))
1196 .collect()
1197 } else {
1198 vec![self.host_connection_id]
1199 }
1200 }
1201
1202 fn share(&self) -> tide::Result<&WorktreeShare> {
1203 Ok(self
1204 .share
1205 .as_ref()
1206 .ok_or_else(|| anyhow!("worktree is not shared"))?)
1207 }
1208
1209 fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
1210 Ok(self
1211 .share
1212 .as_mut()
1213 .ok_or_else(|| anyhow!("worktree is not shared"))?)
1214 }
1215}
1216
1217impl Channel {
1218 fn connection_ids(&self) -> Vec<ConnectionId> {
1219 self.connection_ids.iter().copied().collect()
1220 }
1221}
1222
/// Mounts the `/rpc` endpoint, which performs the WebSocket opening
/// handshake (behind `auth::VerifyToken`) and hands the upgraded stream to
/// `Server::handle_connection` on a background task.
pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
    let server = Server::new(app.state().clone(), rpc.clone(), None);
    app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
        // `auth::VerifyToken` stores the authenticated user id as a request
        // extension; grab it before the request is consumed below.
        let user_id = request.ext::<UserId>().copied();
        let server = server.clone();
        async move {
            // WebSocket handshake GUID, appended to the client's key when
            // computing Sec-Websocket-Accept.
            const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

            // Only proceed for requests that ask for a websocket upgrade via
            // both the Connection and Upgrade headers.
            let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
            let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
            let upgrade_requested = connection_upgrade && upgrade_to_websocket;

            if !upgrade_requested {
                return Ok(Response::new(StatusCode::UpgradeRequired));
            }

            let header = match request.header("Sec-Websocket-Key") {
                Some(h) => h.as_str(),
                None => return Err(anyhow!("expected sec-websocket-key"))?,
            };

            // Build the 101 Switching Protocols response, echoing back the
            // SHA-1 of key + GUID as Sec-Websocket-Accept.
            let mut response = Response::new(StatusCode::SwitchingProtocols);
            response.insert_header(UPGRADE, "websocket");
            response.insert_header(CONNECTION, "Upgrade");
            let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
            response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
            response.insert_header("Sec-Websocket-Version", "13");

            // Take the raw-socket receiver now; the socket itself is only
            // delivered after this response has been sent.
            let http_res: &mut tide::http::Response = response.as_mut();
            let upgrade_receiver = http_res.recv_upgrade().await;
            let addr = request.remote().unwrap_or("unknown").to_string();
            let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
            // Serve the connection in the background so we can return the
            // handshake response immediately.
            task::spawn(async move {
                if let Some(stream) = upgrade_receiver.await {
                    server.handle_connection(Connection::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
                }
            });

            Ok(response)
        }
    });
}
1265
1266fn header_contains_ignore_case<T>(
1267 request: &tide::Request<T>,
1268 header_name: HeaderName,
1269 value: &str,
1270) -> bool {
1271 request
1272 .header(header_name)
1273 .map(|h| {
1274 h.as_str()
1275 .split(',')
1276 .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
1277 })
1278 .unwrap_or(false)
1279}
1280
1281#[cfg(test)]
1282mod tests {
1283 use super::*;
1284 use crate::{
1285 auth,
1286 db::{tests::TestDb, UserId},
1287 github, AppState, Config,
1288 };
1289 use async_std::{sync::RwLockReadGuard, task};
1290 use gpui::TestAppContext;
1291 use parking_lot::Mutex;
1292 use postage::{mpsc, watch};
1293 use serde_json::json;
1294 use sqlx::types::time::OffsetDateTime;
1295 use std::{
1296 path::Path,
1297 sync::{
1298 atomic::{AtomicBool, Ordering::SeqCst},
1299 Arc,
1300 },
1301 time::Duration,
1302 };
1303 use zed::{
1304 channel::{Channel, ChannelDetails, ChannelList},
1305 editor::{Editor, Insert},
1306 fs::{FakeFs, Fs as _},
1307 language::LanguageRegistry,
1308 rpc::{self, Client, Credentials, EstablishConnectionError},
1309 settings,
1310 test::FakeHttpClient,
1311 user::UserStore,
1312 worktree::Worktree,
1313 };
1314 use zrpc::Peer;
1315
    // End-to-end collaboration test: share a worktree as A, join it as guest
    // B, and verify that open buffers, selection sets, and edits propagate in
    // both directions, and that dropped buffers/worktrees are cleaned up.
    #[gpui::test]
    async fn test_share_worktree(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        let (window_b, _) = cx_b.add_window(|_| EmptyView);
        let settings = cx_b.read(settings::test).1;
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;

        cx_a.foreground().forbid_parking();

        // Share a local worktree as client A
        let fs = Arc::new(FakeFs::new());
        fs.insert_tree(
            "/a",
            json!({
                ".zed.toml": r#"collaborators = ["user_b"]"#,
                "a.txt": "a-contents",
                "b.txt": "b-contents",
            }),
        )
        .await;
        let worktree_a = Worktree::open_local(
            client_a.clone(),
            "/a".as_ref(),
            fs,
            lang_registry.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        // Join that worktree as client B, and see that a guest has joined as client A.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();
        let replica_id_b = worktree_b.read_with(&cx_b, |tree, _| tree.replica_id());
        worktree_a
            .condition(&cx_a, |tree, _| {
                tree.peers()
                    .values()
                    .any(|replica_id| *replica_id == replica_id_b)
            })
            .await;

        // Open the same file as client B and client A.
        let buffer_b = worktree_b
            .update(&mut cx_b, |worktree, cx| worktree.open_buffer("b.txt", cx))
            .await
            .unwrap();
        buffer_b.read_with(&cx_b, |buf, _| assert_eq!(buf.text(), "b-contents"));
        worktree_a.read_with(&cx_a, |tree, cx| assert!(tree.has_open_buffer("b.txt", cx)));
        let buffer_a = worktree_a
            .update(&mut cx_a, |tree, cx| tree.open_buffer("b.txt", cx))
            .await
            .unwrap();

        // Create a selection set as client B and see that selection set as client A.
        let editor_b = cx_b.add_view(window_b, |cx| Editor::for_buffer(buffer_b, settings, cx));
        buffer_a
            .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 1)
            .await;

        // Edit the buffer as client B and see that edit as client A.
        editor_b.update(&mut cx_b, |editor, cx| {
            editor.insert(&Insert("ok, ".into()), cx)
        });
        buffer_a
            .condition(&cx_a, |buffer, _| buffer.text() == "ok, b-contents")
            .await;

        // Remove the selection set as client B, see those selections disappear as client A.
        cx_b.update(move |_| drop(editor_b));
        buffer_a
            .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 0)
            .await;

        // Close the buffer as client A, see that the buffer is closed.
        cx_a.update(move |_| drop(buffer_a));
        worktree_a
            .condition(&cx_a, |tree, cx| !tree.has_open_buffer("b.txt", cx))
            .await;

        // Dropping the worktree removes client B from client A's peers.
        cx_b.update(move |_| drop(worktree_b));
        worktree_a
            .condition(&cx_a, |tree, _| tree.peers().is_empty())
            .await;
    }
1419
    // Verifies that concurrent edits from two guests and the host converge to
    // the same text, that a guest-initiated save persists the merged contents
    // and clears dirty state everywhere, and that host-side filesystem
    // changes (rename, create) reach all guests.
    #[gpui::test]
    async fn test_propagate_saves_and_fs_changes_in_shared_worktree(
        mut cx_a: TestAppContext,
        mut cx_b: TestAppContext,
        mut cx_c: TestAppContext,
    ) {
        cx_a.foreground().forbid_parking();
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 3 clients.
        let mut server = TestServer::start().await;
        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
        let (client_c, _) = server.create_client(&mut cx_c, "user_c").await;

        let fs = Arc::new(FakeFs::new());

        // Share a worktree as client A.
        fs.insert_tree(
            "/a",
            json!({
                ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
                "file1": "",
                "file2": ""
            }),
        )
        .await;

        let worktree_a = Worktree::open_local(
            client_a.clone(),
            "/a".as_ref(),
            fs.clone(),
            lang_registry.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        // Join that worktree as clients B and C.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();
        let worktree_c = Worktree::open_remote(
            client_c.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_c.to_async(),
        )
        .await
        .unwrap();

        // Open and edit a buffer as both guests B and C.
        let buffer_b = worktree_b
            .update(&mut cx_b, |tree, cx| tree.open_buffer("file1", cx))
            .await
            .unwrap();
        let buffer_c = worktree_c
            .update(&mut cx_c, |tree, cx| tree.open_buffer("file1", cx))
            .await
            .unwrap();
        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "i-am-b, ", cx));
        buffer_c.update(&mut cx_c, |buf, cx| buf.edit([0..0], "i-am-c, ", cx));

        // Open and edit that buffer as the host.
        let buffer_a = worktree_a
            .update(&mut cx_a, |tree, cx| tree.open_buffer("file1", cx))
            .await
            .unwrap();

        buffer_a
            .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, ")
            .await;
        buffer_a.update(&mut cx_a, |buf, cx| {
            buf.edit([buf.len()..buf.len()], "i-am-a", cx)
        });

        // Wait for edits to propagate
        buffer_a
            .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
            .await;
        buffer_b
            .condition(&mut cx_b, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
            .await;
        buffer_c
            .condition(&mut cx_c, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
            .await;

        // Edit the buffer as the host and concurrently save as guest B.
        let save_b = buffer_b.update(&mut cx_b, |buf, cx| buf.save(cx).unwrap());
        buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "hi-a, ", cx));
        save_b.await.unwrap();
        // The saved file must contain the host's concurrent edit as well.
        assert_eq!(
            fs.load("/a/file1".as_ref()).await.unwrap(),
            "hi-a, i-am-c, i-am-b, i-am-a"
        );
        buffer_a.read_with(&cx_a, |buf, _| assert!(!buf.is_dirty()));
        buffer_b.read_with(&cx_b, |buf, _| assert!(!buf.is_dirty()));
        buffer_c.condition(&cx_c, |buf, _| !buf.is_dirty()).await;

        // Make changes on host's file system, see those changes on the guests.
        fs.rename("/a/file2".as_ref(), "/a/file3".as_ref())
            .await
            .unwrap();
        fs.insert_file(Path::new("/a/file4"), "4".into())
            .await
            .unwrap();

        worktree_b
            .condition(&cx_b, |tree, _| tree.file_count() == 4)
            .await;
        worktree_c
            .condition(&cx_c, |tree, _| tree.file_count() == 4)
            .await;
        worktree_b.read_with(&cx_b, |tree, _| {
            assert_eq!(
                tree.paths()
                    .map(|p| p.to_string_lossy())
                    .collect::<Vec<_>>(),
                &[".zed.toml", "file1", "file3", "file4"]
            )
        });
        worktree_c.read_with(&cx_c, |tree, _| {
            assert_eq!(
                tree.paths()
                    .map(|p| p.to_string_lossy())
                    .collect::<Vec<_>>(),
                &[".zed.toml", "file1", "file3", "file4"]
            )
        });
    }
1562
    // Checks a guest buffer's dirty/conflict flags across a save: editing
    // marks it dirty without conflict, saving clears dirty once the new mtime
    // is observed, and a later edit dirties it again without a false
    // conflict.
    #[gpui::test]
    async fn test_buffer_conflict_after_save(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        cx_a.foreground().forbid_parking();
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;

        // Share a local worktree as client A
        let fs = Arc::new(FakeFs::new());
        fs.insert_tree(
            "/dir",
            json!({
                ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
                "a.txt": "a-contents",
            }),
        )
        .await;

        let worktree_a = Worktree::open_local(
            client_a.clone(),
            "/dir".as_ref(),
            fs,
            lang_registry.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        // Join that worktree as client B, and see that a guest has joined as client A.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();

        let buffer_b = worktree_b
            .update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx))
            .await
            .unwrap();
        // Remember the pre-save mtime so we can wait for the save to land.
        let mtime = buffer_b.read_with(&cx_b, |buf, _| buf.file().unwrap().mtime);

        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "world ", cx));
        buffer_b.read_with(&cx_b, |buf, _| {
            assert!(buf.is_dirty());
            assert!(!buf.has_conflict());
        });

        buffer_b
            .update(&mut cx_b, |buf, cx| buf.save(cx))
            .unwrap()
            .await
            .unwrap();
        worktree_b
            .condition(&cx_b, |_, cx| {
                buffer_b.read(cx).file().unwrap().mtime != mtime
            })
            .await;
        buffer_b.read_with(&cx_b, |buf, _| {
            assert!(!buf.is_dirty());
            assert!(!buf.has_conflict());
        });

        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "hello ", cx));
        buffer_b.read_with(&cx_b, |buf, _| {
            assert!(buf.is_dirty());
            assert!(!buf.has_conflict());
        });
    }
1644
    // Races a host edit against a guest's in-flight open-buffer request and
    // verifies the guest's buffer still converges to the host's text.
    #[gpui::test]
    async fn test_editing_while_guest_opens_buffer(
        mut cx_a: TestAppContext,
        mut cx_b: TestAppContext,
    ) {
        cx_a.foreground().forbid_parking();
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;

        // Share a local worktree as client A
        let fs = Arc::new(FakeFs::new());
        fs.insert_tree(
            "/dir",
            json!({
                ".zed.toml": r#"collaborators = ["user_b"]"#,
                "a.txt": "a-contents",
            }),
        )
        .await;
        let worktree_a = Worktree::open_local(
            client_a.clone(),
            "/dir".as_ref(),
            fs,
            lang_registry.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        // Join that worktree as client B, and see that a guest has joined as client A.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();

        let buffer_a = worktree_a
            .update(&mut cx_a, |tree, cx| tree.open_buffer("a.txt", cx))
            .await
            .unwrap();
        // Kick off the guest's open-buffer request without awaiting it, so
        // the host edit below can race with it.
        let buffer_b = cx_b
            .background()
            .spawn(worktree_b.update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx)));

        // Yield once so the open request is in flight before the host edits.
        task::yield_now().await;
        buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "z", cx));

        let text = buffer_a.read_with(&cx_a, |buf, _| buf.text());
        let buffer_b = buffer_b.await.unwrap();
        buffer_b.condition(&cx_b, |buf, _| buf.text() == text).await;
    }
1710
1711 #[gpui::test]
1712 async fn test_peer_disconnection(mut cx_a: TestAppContext, cx_b: TestAppContext) {
1713 cx_a.foreground().forbid_parking();
1714 let lang_registry = Arc::new(LanguageRegistry::new());
1715
1716 // Connect to a server as 2 clients.
1717 let mut server = TestServer::start().await;
1718 let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1719 let (client_b, _) = server.create_client(&mut cx_a, "user_b").await;
1720
1721 // Share a local worktree as client A
1722 let fs = Arc::new(FakeFs::new());
1723 fs.insert_tree(
1724 "/a",
1725 json!({
1726 ".zed.toml": r#"collaborators = ["user_b"]"#,
1727 "a.txt": "a-contents",
1728 "b.txt": "b-contents",
1729 }),
1730 )
1731 .await;
1732 let worktree_a = Worktree::open_local(
1733 client_a.clone(),
1734 "/a".as_ref(),
1735 fs,
1736 lang_registry.clone(),
1737 &mut cx_a.to_async(),
1738 )
1739 .await
1740 .unwrap();
1741 worktree_a
1742 .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1743 .await;
1744 let worktree_id = worktree_a
1745 .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1746 .await
1747 .unwrap();
1748
1749 // Join that worktree as client B, and see that a guest has joined as client A.
1750 let _worktree_b = Worktree::open_remote(
1751 client_b.clone(),
1752 worktree_id,
1753 lang_registry.clone(),
1754 &mut cx_b.to_async(),
1755 )
1756 .await
1757 .unwrap();
1758 worktree_a
1759 .condition(&cx_a, |tree, _| tree.peers().len() == 1)
1760 .await;
1761
1762 // Drop client B's connection and ensure client A observes client B leaving the worktree.
1763 client_b.disconnect(&cx_b.to_async()).await.unwrap();
1764 worktree_a
1765 .condition(&cx_a, |tree, _| tree.peers().len() == 0)
1766 .await;
1767 }
1768
    // Exercises the chat flow: both clients discover the channel, see the
    // pre-existing message, observe optimistic "pending" state for their own
    // sends, and the server drops channel subscriptions as clients drop
    // their channel handles. The bool in each `channel_messages` tuple is
    // the message's pending flag.
    #[gpui::test]
    async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        cx_a.foreground().forbid_parking();

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;

        // Create an org that includes these 2 users.
        let db = &server.app_state.db;
        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
        db.add_org_member(org_id, current_user_id(&user_store_a), false)
            .await
            .unwrap();
        db.add_org_member(org_id, current_user_id(&user_store_b), false)
            .await
            .unwrap();

        // Create a channel that includes all the users.
        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
        db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
            .await
            .unwrap();
        db.add_channel_member(channel_id, current_user_id(&user_store_b), false)
            .await
            .unwrap();
        // NOTE(review): the trailing integer argument looks like a
        // client-supplied nonce/id — confirm against the
        // `create_channel_message` signature.
        db.create_channel_message(
            channel_id,
            current_user_id(&user_store_b),
            "hello A, it's B.",
            OffsetDateTime::now_utc(),
            1,
        )
        .await
        .unwrap();

        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
        channels_a
            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
            .await;
        channels_a.read_with(&cx_a, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });
        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });
        channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
        channel_a
            .condition(&cx_a, |channel, _| {
                channel_messages(channel)
                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
            })
            .await;

        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
        channels_b
            .condition(&mut cx_b, |list, _| list.available_channels().is_some())
            .await;
        channels_b.read_with(&cx_b, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });

        let channel_b = channels_b.update(&mut cx_b, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });
        channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
            })
            .await;

        // Messages A sends appear locally right away, flagged as pending
        // (the `true` third element) until acknowledged.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                channel
                    .send_message("oh, hi B.".to_string(), cx)
                    .unwrap()
                    .detach();
                let task = channel.send_message("sup".to_string(), cx).unwrap();
                assert_eq!(
                    channel_messages(channel),
                    &[
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), true),
                        ("user_a".to_string(), "sup".to_string(), true)
                    ]
                );
                task
            })
            .await
            .unwrap();

        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
                        ("user_a".to_string(), "sup".to_string(), false),
                    ]
            })
            .await;

        // Dropping a client's channel handle removes its connection from the
        // server-side channel; the channel itself disappears with the last
        // subscriber.
        assert_eq!(
            server.state().await.channels[&channel_id]
                .connection_ids
                .len(),
            2
        );
        cx_b.update(|_| drop(channel_b));
        server
            .condition(|state| state.channels[&channel_id].connection_ids.len() == 1)
            .await;

        cx_a.update(|_| drop(channel_a));
        server
            .condition(|state| !state.channels.contains_key(&channel_id))
            .await;
    }
1902
    // Server-side message validation: over-long messages are rejected, blank
    // messages fail client-side, and surrounding whitespace is trimmed
    // before persisting.
    #[gpui::test]
    async fn test_chat_message_validation(mut cx_a: TestAppContext) {
        cx_a.foreground().forbid_parking();

        let mut server = TestServer::start().await;
        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;

        let db = &server.app_state.db;
        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
        db.add_org_member(org_id, current_user_id(&user_store_a), false)
            .await
            .unwrap();
        db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
            .await
            .unwrap();

        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
        channels_a
            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
            .await;
        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });

        // Messages aren't allowed to be too long.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                let long_body = "this is long.\n".repeat(1024);
                channel.send_message(long_body, cx).unwrap()
            })
            .await
            .unwrap_err();

        // Messages aren't allowed to be blank.
        channel_a.update(&mut cx_a, |channel, cx| {
            channel.send_message(String::new(), cx).unwrap_err()
        });

        // Leading and trailing whitespace are trimmed.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                channel
                    .send_message("\n surrounded by whitespace \n".to_string(), cx)
                    .unwrap()
            })
            .await
            .unwrap();
        assert_eq!(
            db.get_channel_messages(channel_id, 10, None)
                .await
                .unwrap()
                .iter()
                .map(|m| &m.body)
                .collect::<Vec<_>>(),
            &["surrounded by whitespace"]
        );
    }
1961
    // Chat behavior across a disconnect/reconnect cycle: cached channel data
    // stays readable while offline, sends fail while disconnected, and on
    // reconnection the client catches up on everything it missed (including
    // its own failed-then-retried message) and resumes normal two-way chat.
    #[gpui::test]
    async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        cx_a.foreground().forbid_parking();
        let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) });

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
        let mut status_b = client_b.status();

        // Create an org that includes these 2 users.
        let db = &server.app_state.db;
        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
        db.add_org_member(org_id, current_user_id(&user_store_a), false)
            .await
            .unwrap();
        db.add_org_member(org_id, current_user_id(&user_store_b), false)
            .await
            .unwrap();

        // Create a channel that includes all the users.
        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
        db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
            .await
            .unwrap();
        db.add_channel_member(channel_id, current_user_id(&user_store_b), false)
            .await
            .unwrap();
        db.create_channel_message(
            channel_id,
            current_user_id(&user_store_b),
            "hello A, it's B.",
            OffsetDateTime::now_utc(),
            2,
        )
        .await
        .unwrap();

        // NOTE(review): this shadows the `user_store_a` returned by
        // `create_client` with a fresh store built on a stub HTTP client —
        // presumably intentional, but confirm it isn't a leftover.
        let user_store_a =
            UserStore::new(client_a.clone(), http.clone(), cx_a.background().as_ref());
        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
        channels_a
            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
            .await;

        channels_a.read_with(&cx_a, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });
        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });
        channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
        channel_a
            .condition(&cx_a, |channel, _| {
                channel_messages(channel)
                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
            })
            .await;

        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b.clone(), client_b, cx));
        channels_b
            .condition(&mut cx_b, |list, _| list.available_channels().is_some())
            .await;
        channels_b.read_with(&cx_b, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });

        let channel_b = channels_b.update(&mut cx_b, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });
        channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
            })
            .await;

        // Disconnect client B, ensuring we can still access its cached channel data.
        server.forbid_connections();
        server.disconnect_client(current_user_id(&user_store_b));
        // Wait until the client has noticed the drop and failed to reconnect
        // (reconnects are forbidden above).
        while !matches!(
            status_b.recv().await,
            Some(rpc::Status::ReconnectionError { .. })
        ) {}

        channels_b.read_with(&cx_b, |channels, _| {
            assert_eq!(
                channels.available_channels().unwrap(),
                [ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });
        channel_b.read_with(&cx_b, |channel, _| {
            assert_eq!(
                channel_messages(channel),
                [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
            )
        });

        // Send a message from client B while it is disconnected.
        channel_b
            .update(&mut cx_b, |channel, cx| {
                let task = channel
                    .send_message("can you see this?".to_string(), cx)
                    .unwrap();
                assert_eq!(
                    channel_messages(channel),
                    &[
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_b".to_string(), "can you see this?".to_string(), true)
                    ]
                );
                task
            })
            .await
            .unwrap_err();

        // Send a message from client A while B is disconnected.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                channel
                    .send_message("oh, hi B.".to_string(), cx)
                    .unwrap()
                    .detach();
                let task = channel.send_message("sup".to_string(), cx).unwrap();
                assert_eq!(
                    channel_messages(channel),
                    &[
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), true),
                        ("user_a".to_string(), "sup".to_string(), true)
                    ]
                );
                task
            })
            .await
            .unwrap();

        // Give client B a chance to reconnect.
        server.allow_connections();
        cx_b.foreground().advance_clock(Duration::from_secs(10));

        // Verify that B sees the new messages upon reconnection, as well as the message client B
        // sent while offline.
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
                        ("user_a".to_string(), "sup".to_string(), false),
                        ("user_b".to_string(), "can you see this?".to_string(), false),
                    ]
            })
            .await;

        // Ensure client A and B can communicate normally after reconnection.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                channel.send_message("you online?".to_string(), cx).unwrap()
            })
            .await
            .unwrap();
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
                        ("user_a".to_string(), "sup".to_string(), false),
                        ("user_b".to_string(), "can you see this?".to_string(), false),
                        ("user_a".to_string(), "you online?".to_string(), false),
                    ]
            })
            .await;

        channel_b
            .update(&mut cx_b, |channel, cx| {
                channel.send_message("yep".to_string(), cx).unwrap()
            })
            .await
            .unwrap();
        channel_a
            .condition(&cx_a, |channel, _| {
                channel_messages(channel)
                    == [
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
                        ("user_a".to_string(), "sup".to_string(), false),
                        ("user_b".to_string(), "can you see this?".to_string(), false),
                        ("user_a".to_string(), "you online?".to_string(), false),
                        ("user_b".to_string(), "yep".to_string(), false),
                    ]
            })
            .await;
    }
2174
    /// An in-process server plus everything tests need to poke at it.
    struct TestServer {
        peer: Arc<Peer>,
        app_state: Arc<AppState>,
        server: Arc<Server>,
        // Receives a `()` for each notification the server emits; drained by
        // `condition` to know when state may have changed.
        notifications: mpsc::Receiver<()>,
        // One kill-switch per connected user, used by `disconnect_client` to
        // sever that user's in-memory connection.
        connection_killers: Arc<Mutex<HashMap<UserId, watch::Sender<Option<()>>>>>,
        // When set, `create_client`'s connection override refuses to connect,
        // simulating an unreachable server.
        forbid_connections: Arc<AtomicBool>,
        // Declared last so the database outlives everything that uses it.
        _test_db: TestDb,
    }
2184
2185 impl TestServer {
2186 async fn start() -> Self {
2187 let test_db = TestDb::new();
2188 let app_state = Self::build_app_state(&test_db).await;
2189 let peer = Peer::new();
2190 let notifications = mpsc::channel(128);
2191 let server = Server::new(app_state.clone(), peer.clone(), Some(notifications.0));
2192 Self {
2193 peer,
2194 app_state,
2195 server,
2196 notifications: notifications.1,
2197 connection_killers: Default::default(),
2198 forbid_connections: Default::default(),
2199 _test_db: test_db,
2200 }
2201 }
2202
        /// Registers a new user named `name` in the test database, connects a
        /// `Client` for them over an in-memory transport, and returns it
        /// along with a `UserStore` whose current user has finished loading.
        async fn create_client(
            &mut self,
            cx: &mut TestAppContext,
            name: &str,
        ) -> (Arc<Client>, Arc<UserStore>) {
            let user_id = self.app_state.db.create_user(name, false).await.unwrap();
            let client_name = name.to_string();
            let mut client = Client::new();
            let server = self.server.clone();
            let connection_killers = self.connection_killers.clone();
            let forbid_connections = self.forbid_connections.clone();
            // Stub out authentication: always succeed with a fixed token.
            Arc::get_mut(&mut client)
                .unwrap()
                .override_authenticate(move |cx| {
                    cx.spawn(|_| async move {
                        let access_token = "the-token".to_string();
                        Ok(Credentials {
                            user_id: user_id.0 as u64,
                            access_token,
                        })
                    })
                })
                // Stub out the transport: hand the server an in-memory
                // connection instead of dialing a socket, and remember the
                // kill-switch so `disconnect_client` can sever it later.
                .override_establish_connection(move |credentials, cx| {
                    assert_eq!(credentials.user_id, user_id.0 as u64);
                    assert_eq!(credentials.access_token, "the-token");

                    let server = server.clone();
                    let connection_killers = connection_killers.clone();
                    let forbid_connections = forbid_connections.clone();
                    let client_name = client_name.clone();
                    cx.spawn(move |cx| async move {
                        if forbid_connections.load(SeqCst) {
                            // Simulates an unreachable server; see
                            // `forbid_connections`.
                            Err(EstablishConnectionError::other(anyhow!(
                                "server is forbidding connections"
                            )))
                        } else {
                            let (client_conn, server_conn, kill_conn) = Connection::in_memory();
                            connection_killers.lock().insert(user_id, kill_conn);
                            // Drive the server side of the connection in the
                            // background for the lifetime of the test.
                            cx.background()
                                .spawn(server.handle_connection(server_conn, client_name, user_id))
                                .detach();
                            Ok(client_conn)
                        }
                    })
                });

            let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) });
            client
                .authenticate_and_connect(&cx.to_async())
                .await
                .unwrap();

            let user_store = UserStore::new(client.clone(), http, &cx.background());
            let mut authed_user = user_store.watch_current_user();
            // Block until the store has resolved the current user, so tests
            // can rely on it being populated.
            while authed_user.recv().await.unwrap().is_none() {}

            (client, user_store)
        }
2261
2262 fn disconnect_client(&self, user_id: UserId) {
2263 if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) {
2264 let _ = kill_conn.try_send(Some(()));
2265 }
2266 }
2267
2268 fn forbid_connections(&self) {
2269 self.forbid_connections.store(true, SeqCst);
2270 }
2271
2272 fn allow_connections(&self) {
2273 self.forbid_connections.store(false, SeqCst);
2274 }
2275
2276 async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
2277 let mut config = Config::default();
2278 config.session_secret = "a".repeat(32);
2279 config.database_url = test_db.url.clone();
2280 let github_client = github::AppClient::test();
2281 Arc::new(AppState {
2282 db: test_db.db().clone(),
2283 handlebars: Default::default(),
2284 auth_client: auth::build_client("", ""),
2285 repo_client: github::RepoClient::test(&github_client),
2286 github_client,
2287 config,
2288 })
2289 }
2290
2291 async fn state<'a>(&'a self) -> RwLockReadGuard<'a, ServerState> {
2292 self.server.state.read().await
2293 }
2294
2295 async fn condition<F>(&mut self, mut predicate: F)
2296 where
2297 F: FnMut(&ServerState) -> bool,
2298 {
2299 async_std::future::timeout(Duration::from_millis(500), async {
2300 while !(predicate)(&*self.server.state.read().await) {
2301 self.notifications.recv().await;
2302 }
2303 })
2304 .await
2305 .expect("condition timed out");
2306 }
2307 }
2308
2309 impl Drop for TestServer {
2310 fn drop(&mut self) {
2311 task::block_on(self.peer.reset());
2312 }
2313 }
2314
2315 fn current_user_id(user_store: &Arc<UserStore>) -> UserId {
2316 UserId::from_proto(user_store.current_user().unwrap().id)
2317 }
2318
2319 fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> {
2320 channel
2321 .messages()
2322 .cursor::<(), ()>()
2323 .map(|m| {
2324 (
2325 m.sender.github_login.clone(),
2326 m.body.clone(),
2327 m.is_pending(),
2328 )
2329 })
2330 .collect()
2331 }
2332
    /// Minimal view used by tests that need a window but no real UI.
    struct EmptyView;

    impl gpui::Entity for EmptyView {
        // Emits no events.
        type Event = ();
    }
2338
    impl gpui::View for EmptyView {
        fn ui_name() -> &'static str {
            "empty view"
        }

        // Renders a boxed empty element — no visible content.
        fn render(&mut self, _: &mut gpui::RenderContext<Self>) -> gpui::ElementBox {
            gpui::Element::boxed(gpui::elements::Empty)
        }
    }
2348}