use super::{
    auth,
    db::{ChannelId, MessageId, UserId},
    AppState,
};
use anyhow::anyhow;
use async_std::{sync::RwLock, task};
use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
use futures::{future::BoxFuture, FutureExt};
use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
use sha1::{Digest as _, Sha1};
use std::{
    any::TypeId,
    collections::{hash_map, HashMap, HashSet},
    future::Future,
    mem,
    sync::Arc,
    time::Instant,
};
use surf::StatusCode;
use tide::log;
use tide::{
    http::headers::{HeaderName, CONNECTION, UPGRADE},
    Request, Response,
};
use time::OffsetDateTime;
use zrpc::{
    proto::{self, AnyTypedEnvelope, EnvelopedMessage},
    Connection, ConnectionId, Peer, TypedEnvelope,
};
31
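/// Identifies a peer within a shared worktree. Replica 0 is always the
/// worktree's host; guests are assigned the lowest free id starting at 1
/// (see `ServerState::join_worktree`).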
32type ReplicaId = u16;
33
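/// A type-erased RPC message handler, stored in `Server::handlers` keyed by
/// the message's `TypeId` and invoked from `handle_connection`.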
34type MessageHandler = Box<
35 dyn Send
36 + Sync
37 + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, tide::Result<()>>,
38>;
39
40pub struct Server {
41 peer: Arc<Peer>,
42 state: RwLock<ServerState>,
43 app_state: Arc<AppState>,
44 handlers: HashMap<TypeId, MessageHandler>,
45 notifications: Option<mpsc::Sender<()>>,
46}
47
48#[derive(Default)]
49struct ServerState {
50 connections: HashMap<ConnectionId, ConnectionState>,
51 connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
52 pub worktrees: HashMap<u64, Worktree>,
53 visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
54 channels: HashMap<ChannelId, Channel>,
55 next_worktree_id: u64,
56}
57
58struct ConnectionState {
59 user_id: UserId,
60 worktrees: HashSet<u64>,
61 channels: HashSet<ChannelId>,
62}
63
64struct Worktree {
65 host_connection_id: ConnectionId,
66 collaborator_user_ids: Vec<UserId>,
67 root_name: String,
68 share: Option<WorktreeShare>,
69}
70
71struct WorktreeShare {
72 guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
73 active_replica_ids: HashSet<ReplicaId>,
74 entries: HashMap<u64, proto::Entry>,
75}
76
77#[derive(Default)]
78struct Channel {
79 connection_ids: HashSet<ConnectionId>,
80}
81
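// Limits for channel chat: message history is paged in batches of
// MESSAGE_COUNT_PER_PAGE, and message bodies are rejected when they exceed
// MAX_MESSAGE_LEN bytes (the check uses `str::len`, i.e. UTF-8 bytes).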
82const MESSAGE_COUNT_PER_PAGE: usize = 100;
83const MAX_MESSAGE_LEN: usize = 1024;
84
85impl Server {
86 pub fn new(
87 app_state: Arc<AppState>,
88 peer: Arc<Peer>,
89 notifications: Option<mpsc::Sender<()>>,
90 ) -> Arc<Self> {
91 let mut server = Self {
92 peer,
93 app_state,
94 state: Default::default(),
95 handlers: Default::default(),
96 notifications,
97 };
98
99 server
100 .add_handler(Server::ping)
101 .add_handler(Server::open_worktree)
102 .add_handler(Server::handle_close_worktree)
103 .add_handler(Server::share_worktree)
104 .add_handler(Server::unshare_worktree)
105 .add_handler(Server::join_worktree)
106 .add_handler(Server::update_worktree)
107 .add_handler(Server::open_buffer)
108 .add_handler(Server::close_buffer)
109 .add_handler(Server::update_buffer)
110 .add_handler(Server::buffer_saved)
111 .add_handler(Server::save_buffer)
112 .add_handler(Server::get_channels)
113 .add_handler(Server::get_users)
114 .add_handler(Server::join_channel)
115 .add_handler(Server::leave_channel)
116 .add_handler(Server::send_channel_message)
117 .add_handler(Server::get_channel_messages);
118
119 Arc::new(server)
120 }
121
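    /// Registers `handler` for incoming messages of type `M`, keyed by the
    /// message's `TypeId`. Registering two handlers for the same message type
    /// panics. Handlers are wired up in `Server::new`, e.g.:
    ///
    /// ```ignore
    /// server.add_handler(Server::ping);
    /// ```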
122 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
123 where
124 F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
125 Fut: 'static + Send + Future<Output = tide::Result<()>>,
126 M: EnvelopedMessage,
127 {
128 let prev_handler = self.handlers.insert(
129 TypeId::of::<M>(),
130 Box::new(move |server, envelope| {
131 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
132 (handler)(server, *envelope).boxed()
133 }),
134 );
135 if prev_handler.is_some() {
136 panic!("registered a handler for the same message twice");
137 }
138 self
139 }
140
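    /// Drives a single client connection: registers it with the peer and the
    /// server state, then loops, racing the connection's IO future against the
    /// next incoming message. Each message is dispatched to the handler
    /// registered for its payload type; when the connection closes or IO
    /// fails, the connection is signed out and its state cleaned up.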
141 pub fn handle_connection(
142 self: &Arc<Self>,
143 connection: Connection,
144 addr: String,
145 user_id: UserId,
146 ) -> impl Future<Output = ()> {
147 let this = self.clone();
148 async move {
149 let (connection_id, handle_io, mut incoming_rx) =
150 this.peer.add_connection(connection).await;
151 this.add_connection(connection_id, user_id).await;
152
153 let handle_io = handle_io.fuse();
154 futures::pin_mut!(handle_io);
155 loop {
156 let next_message = incoming_rx.recv().fuse();
157 futures::pin_mut!(next_message);
158 futures::select_biased! {
159 message = next_message => {
160 if let Some(message) = message {
161 let start_time = Instant::now();
162 log::info!("RPC message received: {}", message.payload_type_name());
163 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
164 if let Err(err) = (handler)(this.clone(), message).await {
165 log::error!("error handling message: {:?}", err);
166 } else {
167 log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
168 }
169
170 if let Some(mut notifications) = this.notifications.clone() {
171 let _ = notifications.send(()).await;
172 }
173 } else {
174 log::warn!("unhandled message: {}", message.payload_type_name());
175 }
176 } else {
177 log::info!("rpc connection closed {:?}", addr);
178 break;
179 }
180 }
181 handle_io = handle_io => {
182 if let Err(err) = handle_io {
183 log::error!("error handling rpc connection {:?} - {:?}", addr, err);
184 }
185 break;
186 }
187 }
188 }
189
190 if let Err(err) = this.sign_out(connection_id).await {
191 log::error!("error signing out connection {:?} - {:?}", addr, err);
192 }
193 }
194 }
195
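    /// Disconnects the peer and removes all server-side state associated with
    /// the connection (worktree membership, channel membership, and the
    /// per-user connection index).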
196 async fn sign_out(self: &Arc<Self>, connection_id: zrpc::ConnectionId) -> tide::Result<()> {
197 self.peer.disconnect(connection_id).await;
198 self.remove_connection(connection_id).await?;
199 Ok(())
200 }
201
202 // Add a new connection associated with a given user.
203 async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
204 let mut state = self.state.write().await;
205 state.connections.insert(
206 connection_id,
207 ConnectionState {
208 user_id,
209 worktrees: Default::default(),
210 channels: Default::default(),
211 },
212 );
213 state
214 .connections_by_user_id
215 .entry(user_id)
216 .or_default()
217 .insert(connection_id);
218 }
219
    // Remove the given connection and its associations with any worktrees,
    // channels, and the per-user connection index.
221 async fn remove_connection(
222 self: &Arc<Server>,
223 connection_id: ConnectionId,
224 ) -> tide::Result<()> {
225 let mut worktree_ids = Vec::new();
226 let mut state = self.state.write().await;
227 if let Some(connection) = state.connections.remove(&connection_id) {
228 worktree_ids = connection.worktrees.into_iter().collect();
229
230 for channel_id in connection.channels {
231 if let Some(channel) = state.channels.get_mut(&channel_id) {
232 channel.connection_ids.remove(&connection_id);
233 }
234 }
235
236 let user_connections = state
237 .connections_by_user_id
238 .get_mut(&connection.user_id)
239 .unwrap();
240 user_connections.remove(&connection_id);
241 if user_connections.is_empty() {
242 state.connections_by_user_id.remove(&connection.user_id);
243 }
244 }
245
246 drop(state);
247 for worktree_id in worktree_ids {
248 self.close_worktree(worktree_id, connection_id).await?;
249 }
250
251 Ok(())
252 }
253
254 async fn ping(self: Arc<Server>, request: TypedEnvelope<proto::Ping>) -> tide::Result<()> {
255 self.peer.respond(request.receipt(), proto::Ack {}).await?;
256 Ok(())
257 }
258
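    /// Registers a new worktree hosted by the sender. Collaborator GitHub
    /// logins are resolved to user ids via `db.create_user`; on success the
    /// sender receives the new worktree id and every collaborator gets an
    /// updated collaborator list. If any login cannot be resolved, the request
    /// fails with an error response instead.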
259 async fn open_worktree(
260 self: Arc<Server>,
261 request: TypedEnvelope<proto::OpenWorktree>,
262 ) -> tide::Result<()> {
263 let receipt = request.receipt();
264 let host_user_id = self
265 .state
266 .read()
267 .await
268 .user_id_for_connection(request.sender_id)?;
269
270 let mut collaborator_user_ids = HashSet::new();
271 collaborator_user_ids.insert(host_user_id);
272 for github_login in request.payload.collaborator_logins {
273 match self.app_state.db.create_user(&github_login, false).await {
274 Ok(collaborator_user_id) => {
275 collaborator_user_ids.insert(collaborator_user_id);
276 }
277 Err(err) => {
278 let message = err.to_string();
279 self.peer
280 .respond_with_error(receipt, proto::Error { message })
281 .await?;
282 return Ok(());
283 }
284 }
285 }
286
287 let collaborator_user_ids = collaborator_user_ids.into_iter().collect::<Vec<_>>();
288 let worktree_id = self.state.write().await.add_worktree(Worktree {
289 host_connection_id: request.sender_id,
290 collaborator_user_ids: collaborator_user_ids.clone(),
291 root_name: request.payload.root_name,
292 share: None,
293 });
294
295 self.peer
296 .respond(receipt, proto::OpenWorktreeResponse { worktree_id })
297 .await?;
298 self.update_collaborators_for_users(&collaborator_user_ids)
299 .await?;
300
301 Ok(())
302 }
303
304 async fn share_worktree(
305 self: Arc<Server>,
306 mut request: TypedEnvelope<proto::ShareWorktree>,
307 ) -> tide::Result<()> {
308 let worktree = request
309 .payload
310 .worktree
311 .as_mut()
312 .ok_or_else(|| anyhow!("missing worktree"))?;
313 let entries = mem::take(&mut worktree.entries)
314 .into_iter()
315 .map(|entry| (entry.id, entry))
316 .collect();
317
318 let mut state = self.state.write().await;
319 if let Some(worktree) = state.worktrees.get_mut(&worktree.id) {
320 worktree.share = Some(WorktreeShare {
321 guest_connection_ids: Default::default(),
322 active_replica_ids: Default::default(),
323 entries,
324 });
325 let collaborator_user_ids = worktree.collaborator_user_ids.clone();
326
327 drop(state);
328 self.peer
329 .respond(request.receipt(), proto::ShareWorktreeResponse {})
330 .await?;
331 self.update_collaborators_for_users(&collaborator_user_ids)
332 .await?;
333 } else {
334 self.peer
335 .respond_with_error(
336 request.receipt(),
337 proto::Error {
338 message: "no such worktree".to_string(),
339 },
340 )
341 .await?;
342 }
343 Ok(())
344 }
345
346 async fn unshare_worktree(
347 self: Arc<Server>,
348 request: TypedEnvelope<proto::UnshareWorktree>,
349 ) -> tide::Result<()> {
350 let worktree_id = request.payload.worktree_id;
351
352 let connection_ids;
353 let collaborator_user_ids;
354 {
355 let mut state = self.state.write().await;
356 let worktree = state.write_worktree(worktree_id, request.sender_id)?;
            if worktree.host_connection_id != request.sender_id {
                return Err(anyhow!("only the host can unshare a worktree"))?;
            }
360
361 connection_ids = worktree.connection_ids();
362 collaborator_user_ids = worktree.collaborator_user_ids.clone();
363
364 worktree.share.take();
365 for connection_id in &connection_ids {
366 if let Some(connection) = state.connections.get_mut(connection_id) {
367 connection.worktrees.remove(&worktree_id);
368 }
369 }
370 }
371
372 broadcast(request.sender_id, connection_ids, |conn_id| {
373 self.peer
374 .send(conn_id, proto::UnshareWorktree { worktree_id })
375 })
376 .await?;
377 self.update_collaborators_for_users(&collaborator_user_ids)
378 .await?;
379
380 Ok(())
381 }
382
383 async fn join_worktree(
384 self: Arc<Server>,
385 request: TypedEnvelope<proto::JoinWorktree>,
386 ) -> tide::Result<()> {
387 let worktree_id = request.payload.worktree_id;
388 let user_id = self
389 .state
390 .read()
391 .await
392 .user_id_for_connection(request.sender_id)?;
393
394 let response;
395 let connection_ids;
396 let collaborator_user_ids;
397 let mut state = self.state.write().await;
398 match state.join_worktree(request.sender_id, user_id, worktree_id) {
399 Ok((peer_replica_id, worktree)) => {
400 let share = worktree.share()?;
401 let peer_count = share.guest_connection_ids.len();
402 let mut peers = Vec::with_capacity(peer_count);
403 peers.push(proto::Peer {
404 peer_id: worktree.host_connection_id.0,
405 replica_id: 0,
406 });
407 for (peer_conn_id, peer_replica_id) in &share.guest_connection_ids {
408 if *peer_conn_id != request.sender_id {
409 peers.push(proto::Peer {
410 peer_id: peer_conn_id.0,
411 replica_id: *peer_replica_id as u32,
412 });
413 }
414 }
415 response = proto::JoinWorktreeResponse {
416 worktree: Some(proto::Worktree {
417 id: worktree_id,
418 root_name: worktree.root_name.clone(),
419 entries: share.entries.values().cloned().collect(),
420 }),
421 replica_id: peer_replica_id as u32,
422 peers,
423 };
424 connection_ids = worktree.connection_ids();
425 collaborator_user_ids = worktree.collaborator_user_ids.clone();
426 }
427 Err(error) => {
428 self.peer
429 .respond_with_error(
430 request.receipt(),
431 proto::Error {
432 message: error.to_string(),
433 },
434 )
435 .await?;
436 return Ok(());
437 }
438 }
439
440 drop(state);
441 broadcast(request.sender_id, connection_ids, |conn_id| {
442 self.peer.send(
443 conn_id,
444 proto::AddPeer {
445 worktree_id,
446 peer: Some(proto::Peer {
447 peer_id: request.sender_id.0,
448 replica_id: response.replica_id,
449 }),
450 },
451 )
452 })
453 .await?;
454 self.peer.respond(request.receipt(), response).await?;
455 self.update_collaborators_for_users(&collaborator_user_ids)
456 .await?;
457
458 Ok(())
459 }
460
461 async fn handle_close_worktree(
462 self: Arc<Server>,
463 request: TypedEnvelope<proto::CloseWorktree>,
464 ) -> tide::Result<()> {
465 self.close_worktree(request.payload.worktree_id, request.sender_id)
466 .await
467 }
468
469 async fn close_worktree(
470 self: &Arc<Server>,
471 worktree_id: u64,
472 sender_conn_id: ConnectionId,
473 ) -> tide::Result<()> {
474 let connection_ids;
475 let collaborator_user_ids;
476 let mut is_host = false;
477 let mut is_guest = false;
478 {
479 let mut state = self.state.write().await;
480 let worktree = state.write_worktree(worktree_id, sender_conn_id)?;
481 connection_ids = worktree.connection_ids();
482 collaborator_user_ids = worktree.collaborator_user_ids.clone();
483
484 if worktree.host_connection_id == sender_conn_id {
485 is_host = true;
486 state.remove_worktree(worktree_id);
487 } else {
488 let share = worktree.share_mut()?;
489 if let Some(replica_id) = share.guest_connection_ids.remove(&sender_conn_id) {
490 is_guest = true;
491 share.active_replica_ids.remove(&replica_id);
492 }
493 }
494 }
495
496 if is_host {
497 broadcast(sender_conn_id, connection_ids, |conn_id| {
498 self.peer
499 .send(conn_id, proto::UnshareWorktree { worktree_id })
500 })
501 .await?;
502 } else if is_guest {
503 broadcast(sender_conn_id, connection_ids, |conn_id| {
504 self.peer.send(
505 conn_id,
506 proto::RemovePeer {
507 worktree_id,
508 peer_id: sender_conn_id.0,
509 },
510 )
511 })
512 .await?
513 }
514 self.update_collaborators_for_users(&collaborator_user_ids)
515 .await?;
516 Ok(())
517 }
518
519 async fn update_worktree(
520 self: Arc<Server>,
521 request: TypedEnvelope<proto::UpdateWorktree>,
522 ) -> tide::Result<()> {
523 {
524 let mut state = self.state.write().await;
525 let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
526 let share = worktree.share_mut()?;
527
528 for entry_id in &request.payload.removed_entries {
529 share.entries.remove(&entry_id);
530 }
531
532 for entry in &request.payload.updated_entries {
533 share.entries.insert(entry.id, entry.clone());
534 }
535 }
536
537 self.broadcast_in_worktree(request.payload.worktree_id, &request)
538 .await?;
539 Ok(())
540 }
541
542 async fn open_buffer(
543 self: Arc<Server>,
544 request: TypedEnvelope<proto::OpenBuffer>,
545 ) -> tide::Result<()> {
546 let receipt = request.receipt();
547 let worktree_id = request.payload.worktree_id;
548 let host_connection_id = self
549 .state
550 .read()
551 .await
552 .read_worktree(worktree_id, request.sender_id)?
553 .host_connection_id;
554
555 let response = self
556 .peer
557 .forward_request(request.sender_id, host_connection_id, request.payload)
558 .await?;
559 self.peer.respond(receipt, response).await?;
560 Ok(())
561 }
562
563 async fn close_buffer(
564 self: Arc<Server>,
565 request: TypedEnvelope<proto::CloseBuffer>,
566 ) -> tide::Result<()> {
567 let host_connection_id = self
568 .state
569 .read()
570 .await
571 .read_worktree(request.payload.worktree_id, request.sender_id)?
572 .host_connection_id;
573
574 self.peer
575 .forward_send(request.sender_id, host_connection_id, request.payload)
576 .await?;
577
578 Ok(())
579 }
580
581 async fn save_buffer(
582 self: Arc<Server>,
583 request: TypedEnvelope<proto::SaveBuffer>,
584 ) -> tide::Result<()> {
585 let host;
586 let guests;
587 {
588 let state = self.state.read().await;
589 let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
590 host = worktree.host_connection_id;
591 guests = worktree
592 .share()?
593 .guest_connection_ids
594 .keys()
595 .copied()
596 .collect::<Vec<_>>();
597 }
598
599 let sender = request.sender_id;
600 let receipt = request.receipt();
601 let response = self
602 .peer
603 .forward_request(sender, host, request.payload.clone())
604 .await?;
605
606 broadcast(host, guests, |conn_id| {
607 let response = response.clone();
608 let peer = &self.peer;
609 async move {
610 if conn_id == sender {
611 peer.respond(receipt, response).await
612 } else {
613 peer.forward_send(host, conn_id, response).await
614 }
615 }
616 })
617 .await?;
618
619 Ok(())
620 }
621
622 async fn update_buffer(
623 self: Arc<Server>,
624 request: TypedEnvelope<proto::UpdateBuffer>,
625 ) -> tide::Result<()> {
626 self.broadcast_in_worktree(request.payload.worktree_id, &request)
627 .await?;
628 self.peer.respond(request.receipt(), proto::Ack {}).await?;
629 Ok(())
630 }
631
632 async fn buffer_saved(
633 self: Arc<Server>,
634 request: TypedEnvelope<proto::BufferSaved>,
635 ) -> tide::Result<()> {
636 self.broadcast_in_worktree(request.payload.worktree_id, &request)
637 .await
638 }
639
640 async fn get_channels(
641 self: Arc<Server>,
642 request: TypedEnvelope<proto::GetChannels>,
643 ) -> tide::Result<()> {
644 let user_id = self
645 .state
646 .read()
647 .await
648 .user_id_for_connection(request.sender_id)?;
649 let channels = self.app_state.db.get_accessible_channels(user_id).await?;
650 self.peer
651 .respond(
652 request.receipt(),
653 proto::GetChannelsResponse {
654 channels: channels
655 .into_iter()
656 .map(|chan| proto::Channel {
657 id: chan.id.to_proto(),
658 name: chan.name,
659 })
660 .collect(),
661 },
662 )
663 .await?;
664 Ok(())
665 }
666
667 async fn get_users(
668 self: Arc<Server>,
669 request: TypedEnvelope<proto::GetUsers>,
670 ) -> tide::Result<()> {
671 let user_id = self
672 .state
673 .read()
674 .await
675 .user_id_for_connection(request.sender_id)?;
676 let receipt = request.receipt();
677 let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
678 let users = self
679 .app_state
680 .db
681 .get_users_by_ids(user_id, user_ids)
682 .await?
683 .into_iter()
684 .map(|user| proto::User {
685 id: user.id.to_proto(),
686 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
687 github_login: user.github_login,
688 })
689 .collect();
690 self.peer
691 .respond(receipt, proto::GetUsersResponse { users })
692 .await?;
693 Ok(())
694 }
695
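    /// Recomputes the collaborator list for each of the given users from the
    /// worktrees visible to them (host, share status, and current guests), and
    /// pushes an `UpdateCollaborators` message to every connection belonging
    /// to those users.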
696 async fn update_collaborators_for_users<'a>(
697 self: &Arc<Server>,
698 user_ids: impl IntoIterator<Item = &'a UserId>,
699 ) -> tide::Result<()> {
700 let mut send_futures = Vec::new();
701
702 let state = self.state.read().await;
703 for user_id in user_ids {
704 let mut collaborators = HashMap::new();
705 for worktree_id in state
706 .visible_worktrees_by_user_id
707 .get(&user_id)
708 .unwrap_or(&HashSet::new())
709 {
710 let worktree = &state.worktrees[worktree_id];
711
712 let mut participants = HashSet::new();
713 if let Ok(share) = worktree.share() {
714 for guest_connection_id in share.guest_connection_ids.keys() {
715 let user_id = state.user_id_for_connection(*guest_connection_id)?;
716 participants.insert(user_id.to_proto());
717 }
718 }
719
720 let host_user_id = state.user_id_for_connection(worktree.host_connection_id)?;
721 let host =
722 collaborators
723 .entry(host_user_id)
724 .or_insert_with(|| proto::Collaborator {
725 user_id: host_user_id.to_proto(),
726 worktrees: Vec::new(),
727 });
728 host.worktrees.push(proto::WorktreeMetadata {
729 root_name: worktree.root_name.clone(),
730 is_shared: worktree.share().is_ok(),
731 participants: participants.into_iter().collect(),
732 });
733 }
734
735 let collaborators = collaborators.into_values().collect::<Vec<_>>();
736 for connection_id in state.user_connection_ids(*user_id) {
737 send_futures.push(self.peer.send(
738 connection_id,
739 proto::UpdateCollaborators {
740 collaborators: collaborators.clone(),
741 },
742 ));
743 }
744 }
745
746 drop(state);
747 futures::future::try_join_all(send_futures).await?;
748
749 Ok(())
750 }
751
752 async fn join_channel(
753 self: Arc<Self>,
754 request: TypedEnvelope<proto::JoinChannel>,
755 ) -> tide::Result<()> {
756 let user_id = self
757 .state
758 .read()
759 .await
760 .user_id_for_connection(request.sender_id)?;
761 let channel_id = ChannelId::from_proto(request.payload.channel_id);
762 if !self
763 .app_state
764 .db
765 .can_user_access_channel(user_id, channel_id)
766 .await?
767 {
768 Err(anyhow!("access denied"))?;
769 }
770
771 self.state
772 .write()
773 .await
774 .join_channel(request.sender_id, channel_id);
775 let messages = self
776 .app_state
777 .db
778 .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
779 .await?
780 .into_iter()
781 .map(|msg| proto::ChannelMessage {
782 id: msg.id.to_proto(),
783 body: msg.body,
784 timestamp: msg.sent_at.unix_timestamp() as u64,
785 sender_id: msg.sender_id.to_proto(),
786 nonce: Some(msg.nonce.as_u128().into()),
787 })
788 .collect::<Vec<_>>();
789 self.peer
790 .respond(
791 request.receipt(),
792 proto::JoinChannelResponse {
793 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
794 messages,
795 },
796 )
797 .await?;
798 Ok(())
799 }
800
801 async fn leave_channel(
802 self: Arc<Self>,
803 request: TypedEnvelope<proto::LeaveChannel>,
804 ) -> tide::Result<()> {
805 let user_id = self
806 .state
807 .read()
808 .await
809 .user_id_for_connection(request.sender_id)?;
810 let channel_id = ChannelId::from_proto(request.payload.channel_id);
811 if !self
812 .app_state
813 .db
814 .can_user_access_channel(user_id, channel_id)
815 .await?
816 {
817 Err(anyhow!("access denied"))?;
818 }
819
820 self.state
821 .write()
822 .await
823 .leave_channel(request.sender_id, channel_id);
824
825 Ok(())
826 }
827
828 async fn send_channel_message(
829 self: Arc<Self>,
830 request: TypedEnvelope<proto::SendChannelMessage>,
831 ) -> tide::Result<()> {
832 let receipt = request.receipt();
833 let channel_id = ChannelId::from_proto(request.payload.channel_id);
834 let user_id;
835 let connection_ids;
836 {
837 let state = self.state.read().await;
838 user_id = state.user_id_for_connection(request.sender_id)?;
839 if let Some(channel) = state.channels.get(&channel_id) {
840 connection_ids = channel.connection_ids();
841 } else {
842 return Ok(());
843 }
844 }
845
846 // Validate the message body.
847 let body = request.payload.body.trim().to_string();
848 if body.len() > MAX_MESSAGE_LEN {
849 self.peer
850 .respond_with_error(
851 receipt,
852 proto::Error {
853 message: "message is too long".to_string(),
854 },
855 )
856 .await?;
857 return Ok(());
858 }
859 if body.is_empty() {
860 self.peer
861 .respond_with_error(
862 receipt,
863 proto::Error {
864 message: "message can't be blank".to_string(),
865 },
866 )
867 .await?;
868 return Ok(());
869 }
870
871 let timestamp = OffsetDateTime::now_utc();
872 let nonce = if let Some(nonce) = request.payload.nonce {
873 nonce
874 } else {
875 self.peer
876 .respond_with_error(
877 receipt,
878 proto::Error {
879 message: "nonce can't be blank".to_string(),
880 },
881 )
882 .await?;
883 return Ok(());
884 };
885
886 let message_id = self
887 .app_state
888 .db
889 .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
890 .await?
891 .to_proto();
892 let message = proto::ChannelMessage {
893 sender_id: user_id.to_proto(),
894 id: message_id,
895 body,
896 timestamp: timestamp.unix_timestamp() as u64,
897 nonce: Some(nonce),
898 };
899 broadcast(request.sender_id, connection_ids, |conn_id| {
900 self.peer.send(
901 conn_id,
902 proto::ChannelMessageSent {
903 channel_id: channel_id.to_proto(),
904 message: Some(message.clone()),
905 },
906 )
907 })
908 .await?;
909 self.peer
910 .respond(
911 receipt,
912 proto::SendChannelMessageResponse {
913 message: Some(message),
914 },
915 )
916 .await?;
917 Ok(())
918 }
919
920 async fn get_channel_messages(
921 self: Arc<Self>,
922 request: TypedEnvelope<proto::GetChannelMessages>,
923 ) -> tide::Result<()> {
924 let user_id = self
925 .state
926 .read()
927 .await
928 .user_id_for_connection(request.sender_id)?;
929 let channel_id = ChannelId::from_proto(request.payload.channel_id);
930 if !self
931 .app_state
932 .db
933 .can_user_access_channel(user_id, channel_id)
934 .await?
935 {
936 Err(anyhow!("access denied"))?;
937 }
938
939 let messages = self
940 .app_state
941 .db
942 .get_channel_messages(
943 channel_id,
944 MESSAGE_COUNT_PER_PAGE,
945 Some(MessageId::from_proto(request.payload.before_message_id)),
946 )
947 .await?
948 .into_iter()
949 .map(|msg| proto::ChannelMessage {
950 id: msg.id.to_proto(),
951 body: msg.body,
952 timestamp: msg.sent_at.unix_timestamp() as u64,
953 sender_id: msg.sender_id.to_proto(),
954 nonce: Some(msg.nonce.as_u128().into()),
955 })
956 .collect::<Vec<_>>();
957 self.peer
958 .respond(
959 request.receipt(),
960 proto::GetChannelMessagesResponse {
961 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
962 messages,
963 },
964 )
965 .await?;
966 Ok(())
967 }
968
969 async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
970 &self,
971 worktree_id: u64,
972 message: &TypedEnvelope<T>,
973 ) -> tide::Result<()> {
974 let connection_ids = self
975 .state
976 .read()
977 .await
978 .read_worktree(worktree_id, message.sender_id)?
979 .connection_ids();
980
981 broadcast(message.sender_id, connection_ids, |conn_id| {
982 self.peer
983 .forward_send(message.sender_id, conn_id, message.payload.clone())
984 })
985 .await?;
986
987 Ok(())
988 }
989}
990
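/// Applies `f` to every receiver except `sender_id` and awaits the resulting
/// send futures concurrently, failing if any of them fail. A typical call site
/// (mirroring `unshare_worktree` above) looks like:
///
/// ```ignore
/// broadcast(request.sender_id, connection_ids, |conn_id| {
///     self.peer
///         .send(conn_id, proto::UnshareWorktree { worktree_id })
/// })
/// .await?;
/// ```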
991pub async fn broadcast<F, T>(
992 sender_id: ConnectionId,
993 receiver_ids: Vec<ConnectionId>,
994 mut f: F,
995) -> anyhow::Result<()>
996where
997 F: FnMut(ConnectionId) -> T,
998 T: Future<Output = anyhow::Result<()>>,
999{
1000 let futures = receiver_ids
1001 .into_iter()
1002 .filter(|id| *id != sender_id)
1003 .map(|id| f(id));
1004 futures::future::try_join_all(futures).await?;
1005 Ok(())
1006}
1007
1008impl ServerState {
1009 fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
1010 if let Some(connection) = self.connections.get_mut(&connection_id) {
1011 connection.channels.insert(channel_id);
1012 self.channels
1013 .entry(channel_id)
1014 .or_default()
1015 .connection_ids
1016 .insert(connection_id);
1017 }
1018 }
1019
1020 fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
1021 if let Some(connection) = self.connections.get_mut(&connection_id) {
1022 connection.channels.remove(&channel_id);
1023 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
1024 entry.get_mut().connection_ids.remove(&connection_id);
1025 if entry.get_mut().connection_ids.is_empty() {
1026 entry.remove();
1027 }
1028 }
1029 }
1030 }
1031
1032 fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
1033 Ok(self
1034 .connections
1035 .get(&connection_id)
1036 .ok_or_else(|| anyhow!("unknown connection"))?
1037 .user_id)
1038 }
1039
1040 fn user_connection_ids<'a>(
1041 &'a self,
1042 user_id: UserId,
1043 ) -> impl 'a + Iterator<Item = ConnectionId> {
1044 self.connections_by_user_id
1045 .get(&user_id)
1046 .into_iter()
1047 .flatten()
1048 .copied()
1049 }
1050
1051 // Add the given connection as a guest of the given worktree
1052 fn join_worktree(
1053 &mut self,
1054 connection_id: ConnectionId,
1055 user_id: UserId,
1056 worktree_id: u64,
1057 ) -> tide::Result<(ReplicaId, &Worktree)> {
1058 let connection = self
1059 .connections
1060 .get_mut(&connection_id)
1061 .ok_or_else(|| anyhow!("no such connection"))?;
1062 let worktree = self
1063 .worktrees
1064 .get_mut(&worktree_id)
1065 .ok_or_else(|| anyhow!("no such worktree"))?;
1066 if !worktree.collaborator_user_ids.contains(&user_id) {
1067 Err(anyhow!("no such worktree"))?;
1068 }
1069
1070 let share = worktree.share_mut()?;
1071 connection.worktrees.insert(worktree_id);
1072
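        // Assign the lowest replica id not currently in use. Replica 0 is
        // reserved for the host, so guests start at 1.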
1073 let mut replica_id = 1;
1074 while share.active_replica_ids.contains(&replica_id) {
1075 replica_id += 1;
1076 }
1077 share.active_replica_ids.insert(replica_id);
1078 share.guest_connection_ids.insert(connection_id, replica_id);
1079 return Ok((replica_id, worktree));
1080 }
1081
1082 fn read_worktree(
1083 &self,
1084 worktree_id: u64,
1085 connection_id: ConnectionId,
1086 ) -> tide::Result<&Worktree> {
1087 let worktree = self
1088 .worktrees
1089 .get(&worktree_id)
1090 .ok_or_else(|| anyhow!("worktree not found"))?;
1091
1092 if worktree.host_connection_id == connection_id
1093 || worktree
1094 .share()?
1095 .guest_connection_ids
1096 .contains_key(&connection_id)
1097 {
1098 Ok(worktree)
1099 } else {
1100 Err(anyhow!(
1101 "{} is not a member of worktree {}",
1102 connection_id,
1103 worktree_id
1104 ))?
1105 }
1106 }
1107
1108 fn write_worktree(
1109 &mut self,
1110 worktree_id: u64,
1111 connection_id: ConnectionId,
1112 ) -> tide::Result<&mut Worktree> {
1113 let worktree = self
1114 .worktrees
1115 .get_mut(&worktree_id)
1116 .ok_or_else(|| anyhow!("worktree not found"))?;
1117
1118 if worktree.host_connection_id == connection_id
1119 || worktree.share.as_ref().map_or(false, |share| {
1120 share.guest_connection_ids.contains_key(&connection_id)
1121 })
1122 {
1123 Ok(worktree)
1124 } else {
1125 Err(anyhow!(
1126 "{} is not a member of worktree {}",
1127 connection_id,
1128 worktree_id
1129 ))?
1130 }
1131 }
1132
1133 fn add_worktree(&mut self, worktree: Worktree) -> u64 {
1134 let worktree_id = self.next_worktree_id;
1135 for collaborator_user_id in &worktree.collaborator_user_ids {
1136 self.visible_worktrees_by_user_id
1137 .entry(*collaborator_user_id)
1138 .or_default()
1139 .insert(worktree_id);
1140 }
1141 self.next_worktree_id += 1;
1142 self.worktrees.insert(worktree_id, worktree);
1143 worktree_id
1144 }
1145
1146 fn remove_worktree(&mut self, worktree_id: u64) {
1147 let worktree = self.worktrees.remove(&worktree_id).unwrap();
1148 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
1149 connection.worktrees.remove(&worktree_id);
1150 }
1151 if let Some(share) = worktree.share {
1152 for connection_id in share.guest_connection_ids.keys() {
1153 if let Some(connection) = self.connections.get_mut(connection_id) {
1154 connection.worktrees.remove(&worktree_id);
1155 }
1156 }
1157 }
1158 for collaborator_user_id in worktree.collaborator_user_ids {
1159 if let Some(visible_worktrees) = self
1160 .visible_worktrees_by_user_id
1161 .get_mut(&collaborator_user_id)
1162 {
1163 visible_worktrees.remove(&worktree_id);
1164 }
1165 }
1166 }
1167}
1168
1169impl Worktree {
1170 pub fn connection_ids(&self) -> Vec<ConnectionId> {
1171 if let Some(share) = &self.share {
1172 share
1173 .guest_connection_ids
1174 .keys()
1175 .copied()
1176 .chain(Some(self.host_connection_id))
1177 .collect()
1178 } else {
1179 vec![self.host_connection_id]
1180 }
1181 }
1182
1183 fn share(&self) -> tide::Result<&WorktreeShare> {
1184 Ok(self
1185 .share
1186 .as_ref()
1187 .ok_or_else(|| anyhow!("worktree is not shared"))?)
1188 }
1189
1190 fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
1191 Ok(self
1192 .share
1193 .as_mut()
1194 .ok_or_else(|| anyhow!("worktree is not shared"))?)
1195 }
1196}
1197
1198impl Channel {
1199 fn connection_ids(&self) -> Vec<ConnectionId> {
1200 self.connection_ids.iter().copied().collect()
1201 }
1202}
1203
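/// Mounts the `/rpc` endpoint, which performs the HTTP to WebSocket upgrade
/// described in RFC 6455: the request must carry `Connection: Upgrade` and
/// `Upgrade: websocket`, and the `Sec-Websocket-Accept` response header is the
/// base64-encoded SHA-1 of the client's `Sec-Websocket-Key` concatenated with
/// the fixed WebSocket GUID. The upgraded stream is then handed to
/// `Server::handle_connection` on a background task.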
1204pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
1205 let server = Server::new(app.state().clone(), rpc.clone(), None);
1206 app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
1207 let user_id = request.ext::<UserId>().copied();
1208 let server = server.clone();
1209 async move {
1210 const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
1211
1212 let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
1213 let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
1214 let upgrade_requested = connection_upgrade && upgrade_to_websocket;
1215
1216 if !upgrade_requested {
1217 return Ok(Response::new(StatusCode::UpgradeRequired));
1218 }
1219
1220 let header = match request.header("Sec-Websocket-Key") {
1221 Some(h) => h.as_str(),
1222 None => return Err(anyhow!("expected sec-websocket-key"))?,
1223 };
1224
1225 let mut response = Response::new(StatusCode::SwitchingProtocols);
1226 response.insert_header(UPGRADE, "websocket");
1227 response.insert_header(CONNECTION, "Upgrade");
1228 let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
1229 response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
1230 response.insert_header("Sec-Websocket-Version", "13");
1231
1232 let http_res: &mut tide::http::Response = response.as_mut();
1233 let upgrade_receiver = http_res.recv_upgrade().await;
1234 let addr = request.remote().unwrap_or("unknown").to_string();
1235 let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
1236 task::spawn(async move {
1237 if let Some(stream) = upgrade_receiver.await {
1238 server.handle_connection(Connection::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
1239 }
1240 });
1241
1242 Ok(response)
1243 }
1244 });
1245}
1246
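/// Returns true if the given header contains `value` as one of its
/// comma-separated elements, compared case-insensitively (e.g. a
/// `Connection: keep-alive, Upgrade` header contains "upgrade").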
1247fn header_contains_ignore_case<T>(
1248 request: &tide::Request<T>,
1249 header_name: HeaderName,
1250 value: &str,
1251) -> bool {
1252 request
1253 .header(header_name)
1254 .map(|h| {
1255 h.as_str()
1256 .split(',')
1257 .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
1258 })
1259 .unwrap_or(false)
1260}
1261
1262#[cfg(test)]
1263mod tests {
1264 use super::*;
1265 use crate::{
1266 auth,
1267 db::{tests::TestDb, UserId},
1268 github, AppState, Config,
1269 };
1270 use async_std::{sync::RwLockReadGuard, task};
1271 use gpui::{ModelHandle, TestAppContext};
1272 use parking_lot::Mutex;
1273 use postage::{mpsc, watch};
1274 use serde_json::json;
1275 use sqlx::types::time::OffsetDateTime;
1276 use std::{
1277 path::Path,
1278 sync::{
1279 atomic::{AtomicBool, Ordering::SeqCst},
1280 Arc,
1281 },
1282 time::Duration,
1283 };
1284 use zed::{
1285 channel::{Channel, ChannelDetails, ChannelList},
1286 editor::{Editor, Insert},
1287 fs::{FakeFs, Fs as _},
1288 language::LanguageRegistry,
1289 rpc::{self, Client, Credentials, EstablishConnectionError},
1290 settings,
1291 test::FakeHttpClient,
1292 user::UserStore,
1293 worktree::Worktree,
1294 };
1295 use zrpc::Peer;
1296
1297 #[gpui::test]
1298 async fn test_share_worktree(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1299 let (window_b, _) = cx_b.add_window(|_| EmptyView);
1300 let settings = cx_b.read(settings::test).1;
1301 let lang_registry = Arc::new(LanguageRegistry::new());
1302
1303 // Connect to a server as 2 clients.
1304 let mut server = TestServer::start().await;
1305 let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1306 let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1307
1308 cx_a.foreground().forbid_parking();
1309
1310 // Share a local worktree as client A
1311 let fs = Arc::new(FakeFs::new());
1312 fs.insert_tree(
1313 "/a",
1314 json!({
1315 ".zed.toml": r#"collaborators = ["user_b"]"#,
1316 "a.txt": "a-contents",
1317 "b.txt": "b-contents",
1318 }),
1319 )
1320 .await;
1321 let worktree_a = Worktree::open_local(
1322 client_a.clone(),
1323 "/a".as_ref(),
1324 fs,
1325 lang_registry.clone(),
1326 &mut cx_a.to_async(),
1327 )
1328 .await
1329 .unwrap();
1330 worktree_a
1331 .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1332 .await;
1333 let worktree_id = worktree_a
1334 .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1335 .await
1336 .unwrap();
1337
1338 // Join that worktree as client B, and see that a guest has joined as client A.
1339 let worktree_b = Worktree::open_remote(
1340 client_b.clone(),
1341 worktree_id,
1342 lang_registry.clone(),
1343 &mut cx_b.to_async(),
1344 )
1345 .await
1346 .unwrap();
1347 let replica_id_b = worktree_b.read_with(&cx_b, |tree, _| tree.replica_id());
1348 worktree_a
1349 .condition(&cx_a, |tree, _| {
1350 tree.peers()
1351 .values()
1352 .any(|replica_id| *replica_id == replica_id_b)
1353 })
1354 .await;
1355
1356 // Open the same file as client B and client A.
1357 let buffer_b = worktree_b
1358 .update(&mut cx_b, |worktree, cx| worktree.open_buffer("b.txt", cx))
1359 .await
1360 .unwrap();
1361 buffer_b.read_with(&cx_b, |buf, _| assert_eq!(buf.text(), "b-contents"));
1362 worktree_a.read_with(&cx_a, |tree, cx| assert!(tree.has_open_buffer("b.txt", cx)));
1363 let buffer_a = worktree_a
1364 .update(&mut cx_a, |tree, cx| tree.open_buffer("b.txt", cx))
1365 .await
1366 .unwrap();
1367
1368 // Create a selection set as client B and see that selection set as client A.
1369 let editor_b = cx_b.add_view(window_b, |cx| Editor::for_buffer(buffer_b, settings, cx));
1370 buffer_a
1371 .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 1)
1372 .await;
1373
1374 // Edit the buffer as client B and see that edit as client A.
1375 editor_b.update(&mut cx_b, |editor, cx| {
1376 editor.insert(&Insert("ok, ".into()), cx)
1377 });
1378 buffer_a
1379 .condition(&cx_a, |buffer, _| buffer.text() == "ok, b-contents")
1380 .await;
1381
1382 // Remove the selection set as client B, see those selections disappear as client A.
1383 cx_b.update(move |_| drop(editor_b));
1384 buffer_a
1385 .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 0)
1386 .await;
1387
1388 // Close the buffer as client A, see that the buffer is closed.
1389 cx_a.update(move |_| drop(buffer_a));
1390 worktree_a
1391 .condition(&cx_a, |tree, cx| !tree.has_open_buffer("b.txt", cx))
1392 .await;
1393
1394 // Dropping the worktree removes client B from client A's peers.
1395 cx_b.update(move |_| drop(worktree_b));
1396 worktree_a
1397 .condition(&cx_a, |tree, _| tree.peers().is_empty())
1398 .await;
1399 }
1400
1401 #[gpui::test]
1402 async fn test_propagate_saves_and_fs_changes_in_shared_worktree(
1403 mut cx_a: TestAppContext,
1404 mut cx_b: TestAppContext,
1405 mut cx_c: TestAppContext,
1406 ) {
1407 cx_a.foreground().forbid_parking();
1408 let lang_registry = Arc::new(LanguageRegistry::new());
1409
1410 // Connect to a server as 3 clients.
1411 let mut server = TestServer::start().await;
1412 let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1413 let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1414 let (client_c, _) = server.create_client(&mut cx_c, "user_c").await;
1415
1416 let fs = Arc::new(FakeFs::new());
1417
1418 // Share a worktree as client A.
1419 fs.insert_tree(
1420 "/a",
1421 json!({
1422 ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
1423 "file1": "",
1424 "file2": ""
1425 }),
1426 )
1427 .await;
1428
1429 let worktree_a = Worktree::open_local(
1430 client_a.clone(),
1431 "/a".as_ref(),
1432 fs.clone(),
1433 lang_registry.clone(),
1434 &mut cx_a.to_async(),
1435 )
1436 .await
1437 .unwrap();
1438 worktree_a
1439 .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1440 .await;
1441 let worktree_id = worktree_a
1442 .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1443 .await
1444 .unwrap();
1445
1446 // Join that worktree as clients B and C.
1447 let worktree_b = Worktree::open_remote(
1448 client_b.clone(),
1449 worktree_id,
1450 lang_registry.clone(),
1451 &mut cx_b.to_async(),
1452 )
1453 .await
1454 .unwrap();
1455 let worktree_c = Worktree::open_remote(
1456 client_c.clone(),
1457 worktree_id,
1458 lang_registry.clone(),
1459 &mut cx_c.to_async(),
1460 )
1461 .await
1462 .unwrap();
1463
1464 // Open and edit a buffer as both guests B and C.
1465 let buffer_b = worktree_b
1466 .update(&mut cx_b, |tree, cx| tree.open_buffer("file1", cx))
1467 .await
1468 .unwrap();
1469 let buffer_c = worktree_c
1470 .update(&mut cx_c, |tree, cx| tree.open_buffer("file1", cx))
1471 .await
1472 .unwrap();
1473 buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "i-am-b, ", cx));
1474 buffer_c.update(&mut cx_c, |buf, cx| buf.edit([0..0], "i-am-c, ", cx));
1475
1476 // Open and edit that buffer as the host.
1477 let buffer_a = worktree_a
1478 .update(&mut cx_a, |tree, cx| tree.open_buffer("file1", cx))
1479 .await
1480 .unwrap();
1481
1482 buffer_a
1483 .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, ")
1484 .await;
1485 buffer_a.update(&mut cx_a, |buf, cx| {
1486 buf.edit([buf.len()..buf.len()], "i-am-a", cx)
1487 });
1488
1489 // Wait for edits to propagate
1490 buffer_a
1491 .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
1492 .await;
1493 buffer_b
1494 .condition(&mut cx_b, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
1495 .await;
1496 buffer_c
1497 .condition(&mut cx_c, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
1498 .await;
1499
1500 // Edit the buffer as the host and concurrently save as guest B.
1501 let save_b = buffer_b.update(&mut cx_b, |buf, cx| buf.save(cx).unwrap());
1502 buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "hi-a, ", cx));
1503 save_b.await.unwrap();
1504 assert_eq!(
1505 fs.load("/a/file1".as_ref()).await.unwrap(),
1506 "hi-a, i-am-c, i-am-b, i-am-a"
1507 );
1508 buffer_a.read_with(&cx_a, |buf, _| assert!(!buf.is_dirty()));
1509 buffer_b.read_with(&cx_b, |buf, _| assert!(!buf.is_dirty()));
1510 buffer_c.condition(&cx_c, |buf, _| !buf.is_dirty()).await;
1511
1512 // Make changes on host's file system, see those changes on the guests.
1513 fs.rename("/a/file2".as_ref(), "/a/file3".as_ref())
1514 .await
1515 .unwrap();
1516 fs.insert_file(Path::new("/a/file4"), "4".into())
1517 .await
1518 .unwrap();
1519
1520 worktree_b
1521 .condition(&cx_b, |tree, _| tree.file_count() == 4)
1522 .await;
1523 worktree_c
1524 .condition(&cx_c, |tree, _| tree.file_count() == 4)
1525 .await;
1526 worktree_b.read_with(&cx_b, |tree, _| {
1527 assert_eq!(
1528 tree.paths()
1529 .map(|p| p.to_string_lossy())
1530 .collect::<Vec<_>>(),
1531 &[".zed.toml", "file1", "file3", "file4"]
1532 )
1533 });
1534 worktree_c.read_with(&cx_c, |tree, _| {
1535 assert_eq!(
1536 tree.paths()
1537 .map(|p| p.to_string_lossy())
1538 .collect::<Vec<_>>(),
1539 &[".zed.toml", "file1", "file3", "file4"]
1540 )
1541 });
1542 }
1543
1544 #[gpui::test]
1545 async fn test_buffer_conflict_after_save(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1546 cx_a.foreground().forbid_parking();
1547 let lang_registry = Arc::new(LanguageRegistry::new());
1548
1549 // Connect to a server as 2 clients.
1550 let mut server = TestServer::start().await;
1551 let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1552 let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1553
1554 // Share a local worktree as client A
1555 let fs = Arc::new(FakeFs::new());
1556 fs.insert_tree(
1557 "/dir",
1558 json!({
1559 ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
1560 "a.txt": "a-contents",
1561 }),
1562 )
1563 .await;
1564
1565 let worktree_a = Worktree::open_local(
1566 client_a.clone(),
1567 "/dir".as_ref(),
1568 fs,
1569 lang_registry.clone(),
1570 &mut cx_a.to_async(),
1571 )
1572 .await
1573 .unwrap();
1574 worktree_a
1575 .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1576 .await;
1577 let worktree_id = worktree_a
1578 .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1579 .await
1580 .unwrap();
1581
1582 // Join that worktree as client B, and see that a guest has joined as client A.
1583 let worktree_b = Worktree::open_remote(
1584 client_b.clone(),
1585 worktree_id,
1586 lang_registry.clone(),
1587 &mut cx_b.to_async(),
1588 )
1589 .await
1590 .unwrap();
1591
1592 let buffer_b = worktree_b
1593 .update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx))
1594 .await
1595 .unwrap();
1596 let mtime = buffer_b.read_with(&cx_b, |buf, _| buf.file().unwrap().mtime);
1597
1598 buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "world ", cx));
1599 buffer_b.read_with(&cx_b, |buf, _| {
1600 assert!(buf.is_dirty());
1601 assert!(!buf.has_conflict());
1602 });
1603
1604 buffer_b
1605 .update(&mut cx_b, |buf, cx| buf.save(cx))
1606 .unwrap()
1607 .await
1608 .unwrap();
1609 worktree_b
1610 .condition(&cx_b, |_, cx| {
1611 buffer_b.read(cx).file().unwrap().mtime != mtime
1612 })
1613 .await;
1614 buffer_b.read_with(&cx_b, |buf, _| {
1615 assert!(!buf.is_dirty());
1616 assert!(!buf.has_conflict());
1617 });
1618
1619 buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "hello ", cx));
1620 buffer_b.read_with(&cx_b, |buf, _| {
1621 assert!(buf.is_dirty());
1622 assert!(!buf.has_conflict());
1623 });
1624 }
1625
1626 #[gpui::test]
1627 async fn test_editing_while_guest_opens_buffer(
1628 mut cx_a: TestAppContext,
1629 mut cx_b: TestAppContext,
1630 ) {
1631 cx_a.foreground().forbid_parking();
1632 let lang_registry = Arc::new(LanguageRegistry::new());
1633
1634 // Connect to a server as 2 clients.
1635 let mut server = TestServer::start().await;
1636 let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1637 let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1638
1639 // Share a local worktree as client A
1640 let fs = Arc::new(FakeFs::new());
1641 fs.insert_tree(
1642 "/dir",
1643 json!({
1644 ".zed.toml": r#"collaborators = ["user_b"]"#,
1645 "a.txt": "a-contents",
1646 }),
1647 )
1648 .await;
1649 let worktree_a = Worktree::open_local(
1650 client_a.clone(),
1651 "/dir".as_ref(),
1652 fs,
1653 lang_registry.clone(),
1654 &mut cx_a.to_async(),
1655 )
1656 .await
1657 .unwrap();
1658 worktree_a
1659 .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1660 .await;
1661 let worktree_id = worktree_a
1662 .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1663 .await
1664 .unwrap();
1665
1666 // Join that worktree as client B, and see that a guest has joined as client A.
1667 let worktree_b = Worktree::open_remote(
1668 client_b.clone(),
1669 worktree_id,
1670 lang_registry.clone(),
1671 &mut cx_b.to_async(),
1672 )
1673 .await
1674 .unwrap();
1675
1676 let buffer_a = worktree_a
1677 .update(&mut cx_a, |tree, cx| tree.open_buffer("a.txt", cx))
1678 .await
1679 .unwrap();
1680 let buffer_b = cx_b
1681 .background()
1682 .spawn(worktree_b.update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx)));
1683
1684 task::yield_now().await;
1685 buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "z", cx));
1686
1687 let text = buffer_a.read_with(&cx_a, |buf, _| buf.text());
1688 let buffer_b = buffer_b.await.unwrap();
1689 buffer_b.condition(&cx_b, |buf, _| buf.text() == text).await;
1690 }
1691
1692 #[gpui::test]
1693 async fn test_peer_disconnection(mut cx_a: TestAppContext, cx_b: TestAppContext) {
1694 cx_a.foreground().forbid_parking();
1695 let lang_registry = Arc::new(LanguageRegistry::new());
1696
1697 // Connect to a server as 2 clients.
1698 let mut server = TestServer::start().await;
1699 let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1700 let (client_b, _) = server.create_client(&mut cx_a, "user_b").await;
1701
1702 // Share a local worktree as client A
1703 let fs = Arc::new(FakeFs::new());
1704 fs.insert_tree(
1705 "/a",
1706 json!({
1707 ".zed.toml": r#"collaborators = ["user_b"]"#,
1708 "a.txt": "a-contents",
1709 "b.txt": "b-contents",
1710 }),
1711 )
1712 .await;
1713 let worktree_a = Worktree::open_local(
1714 client_a.clone(),
1715 "/a".as_ref(),
1716 fs,
1717 lang_registry.clone(),
1718 &mut cx_a.to_async(),
1719 )
1720 .await
1721 .unwrap();
1722 worktree_a
1723 .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1724 .await;
1725 let worktree_id = worktree_a
1726 .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1727 .await
1728 .unwrap();
1729
1730 // Join that worktree as client B, and see that a guest has joined as client A.
1731 let _worktree_b = Worktree::open_remote(
1732 client_b.clone(),
1733 worktree_id,
1734 lang_registry.clone(),
1735 &mut cx_b.to_async(),
1736 )
1737 .await
1738 .unwrap();
1739 worktree_a
1740 .condition(&cx_a, |tree, _| tree.peers().len() == 1)
1741 .await;
1742
1743 // Drop client B's connection and ensure client A observes client B leaving the worktree.
1744 client_b.disconnect(&cx_b.to_async()).await.unwrap();
1745 worktree_a
1746 .condition(&cx_a, |tree, _| tree.peers().len() == 0)
1747 .await;
1748 }
1749
1750 #[gpui::test]
1751 async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1752 cx_a.foreground().forbid_parking();
1753
1754 // Connect to a server as 2 clients.
1755 let mut server = TestServer::start().await;
1756 let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1757 let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
1758
1759 // Create an org that includes these 2 users.
1760 let db = &server.app_state.db;
1761 let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1762 db.add_org_member(org_id, current_user_id(&user_store_a, &cx_a), false)
1763 .await
1764 .unwrap();
1765 db.add_org_member(org_id, current_user_id(&user_store_b, &cx_b), false)
1766 .await
1767 .unwrap();
1768
1769 // Create a channel that includes all the users.
1770 let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1771 db.add_channel_member(channel_id, current_user_id(&user_store_a, &cx_a), false)
1772 .await
1773 .unwrap();
1774 db.add_channel_member(channel_id, current_user_id(&user_store_b, &cx_b), false)
1775 .await
1776 .unwrap();
1777 db.create_channel_message(
1778 channel_id,
1779 current_user_id(&user_store_b, &cx_b),
1780 "hello A, it's B.",
1781 OffsetDateTime::now_utc(),
1782 1,
1783 )
1784 .await
1785 .unwrap();
1786
1787 let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1788 channels_a
1789 .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1790 .await;
1791 channels_a.read_with(&cx_a, |list, _| {
1792 assert_eq!(
1793 list.available_channels().unwrap(),
1794 &[ChannelDetails {
1795 id: channel_id.to_proto(),
1796 name: "test-channel".to_string()
1797 }]
1798 )
1799 });
1800 let channel_a = channels_a.update(&mut cx_a, |this, cx| {
1801 this.get_channel(channel_id.to_proto(), cx).unwrap()
1802 });
1803 channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
1804 channel_a
1805 .condition(&cx_a, |channel, _| {
1806 channel_messages(channel)
1807 == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
1808 })
1809 .await;
1810
1811 let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
1812 channels_b
1813 .condition(&mut cx_b, |list, _| list.available_channels().is_some())
1814 .await;
1815 channels_b.read_with(&cx_b, |list, _| {
1816 assert_eq!(
1817 list.available_channels().unwrap(),
1818 &[ChannelDetails {
1819 id: channel_id.to_proto(),
1820 name: "test-channel".to_string()
1821 }]
1822 )
1823 });
1824
1825 let channel_b = channels_b.update(&mut cx_b, |this, cx| {
1826 this.get_channel(channel_id.to_proto(), cx).unwrap()
1827 });
1828 channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
1829 channel_b
1830 .condition(&cx_b, |channel, _| {
1831 channel_messages(channel)
1832 == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
1833 })
1834 .await;
1835
1836 channel_a
1837 .update(&mut cx_a, |channel, cx| {
1838 channel
1839 .send_message("oh, hi B.".to_string(), cx)
1840 .unwrap()
1841 .detach();
1842 let task = channel.send_message("sup".to_string(), cx).unwrap();
1843 assert_eq!(
1844 channel_messages(channel),
1845 &[
1846 ("user_b".to_string(), "hello A, it's B.".to_string(), false),
1847 ("user_a".to_string(), "oh, hi B.".to_string(), true),
1848 ("user_a".to_string(), "sup".to_string(), true)
1849 ]
1850 );
1851 task
1852 })
1853 .await
1854 .unwrap();
1855
1856 channel_b
1857 .condition(&cx_b, |channel, _| {
1858 channel_messages(channel)
1859 == [
1860 ("user_b".to_string(), "hello A, it's B.".to_string(), false),
1861 ("user_a".to_string(), "oh, hi B.".to_string(), false),
1862 ("user_a".to_string(), "sup".to_string(), false),
1863 ]
1864 })
1865 .await;
1866
1867 assert_eq!(
1868 server.state().await.channels[&channel_id]
1869 .connection_ids
1870 .len(),
1871 2
1872 );
1873 cx_b.update(|_| drop(channel_b));
1874 server
1875 .condition(|state| state.channels[&channel_id].connection_ids.len() == 1)
1876 .await;
1877
1878 cx_a.update(|_| drop(channel_a));
1879 server
1880 .condition(|state| !state.channels.contains_key(&channel_id))
1881 .await;
1882 }
1883
1884 #[gpui::test]
1885 async fn test_chat_message_validation(mut cx_a: TestAppContext) {
1886 cx_a.foreground().forbid_parking();
1887
1888 let mut server = TestServer::start().await;
1889 let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1890
1891 let db = &server.app_state.db;
1892 let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1893 let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1894 db.add_org_member(org_id, current_user_id(&user_store_a, &cx_a), false)
1895 .await
1896 .unwrap();
1897 db.add_channel_member(channel_id, current_user_id(&user_store_a, &cx_a), false)
1898 .await
1899 .unwrap();
1900
1901 let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1902 channels_a
1903 .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1904 .await;
1905 let channel_a = channels_a.update(&mut cx_a, |this, cx| {
1906 this.get_channel(channel_id.to_proto(), cx).unwrap()
1907 });
1908
1909 // Messages aren't allowed to be too long.
1910 channel_a
1911 .update(&mut cx_a, |channel, cx| {
1912 let long_body = "this is long.\n".repeat(1024);
1913 channel.send_message(long_body, cx).unwrap()
1914 })
1915 .await
1916 .unwrap_err();
1917
1918 // Messages aren't allowed to be blank.
1919 channel_a.update(&mut cx_a, |channel, cx| {
1920 channel.send_message(String::new(), cx).unwrap_err()
1921 });
1922
1923 // Leading and trailing whitespace are trimmed.
1924 channel_a
1925 .update(&mut cx_a, |channel, cx| {
1926 channel
1927 .send_message("\n surrounded by whitespace \n".to_string(), cx)
1928 .unwrap()
1929 })
1930 .await
1931 .unwrap();
1932 assert_eq!(
1933 db.get_channel_messages(channel_id, 10, None)
1934 .await
1935 .unwrap()
1936 .iter()
1937 .map(|m| &m.body)
1938 .collect::<Vec<_>>(),
1939 &["surrounded by whitespace"]
1940 );
1941 }
1942
1943 #[gpui::test]
1944 async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1945 cx_a.foreground().forbid_parking();
1946
1947 // Connect to a server as 2 clients.
1948 let mut server = TestServer::start().await;
1949 let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1950 let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
1951 let mut status_b = client_b.status();
1952
1953 // Create an org that includes these 2 users.
1954 let db = &server.app_state.db;
1955 let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1956 db.add_org_member(org_id, current_user_id(&user_store_a, &cx_a), false)
1957 .await
1958 .unwrap();
1959 db.add_org_member(org_id, current_user_id(&user_store_b, &cx_b), false)
1960 .await
1961 .unwrap();
1962
1963 // Create a channel that includes all the users.
1964 let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1965 db.add_channel_member(channel_id, current_user_id(&user_store_a, &cx_a), false)
1966 .await
1967 .unwrap();
1968 db.add_channel_member(channel_id, current_user_id(&user_store_b, &cx_b), false)
1969 .await
1970 .unwrap();
1971 db.create_channel_message(
1972 channel_id,
1973 current_user_id(&user_store_b, &cx_b),
1974 "hello A, it's B.",
1975 OffsetDateTime::now_utc(),
1976 2,
1977 )
1978 .await
1979 .unwrap();
1980
1981 let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1982 channels_a
1983 .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1984 .await;
1985
1986 channels_a.read_with(&cx_a, |list, _| {
1987 assert_eq!(
1988 list.available_channels().unwrap(),
1989 &[ChannelDetails {
1990 id: channel_id.to_proto(),
1991 name: "test-channel".to_string()
1992 }]
1993 )
1994 });
1995 let channel_a = channels_a.update(&mut cx_a, |this, cx| {
1996 this.get_channel(channel_id.to_proto(), cx).unwrap()
1997 });
1998 channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
1999 channel_a
2000 .condition(&cx_a, |channel, _| {
2001 channel_messages(channel)
2002 == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
2003 })
2004 .await;
2005
2006 let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b.clone(), client_b, cx));
2007 channels_b
2008 .condition(&mut cx_b, |list, _| list.available_channels().is_some())
2009 .await;
2010 channels_b.read_with(&cx_b, |list, _| {
2011 assert_eq!(
2012 list.available_channels().unwrap(),
2013 &[ChannelDetails {
2014 id: channel_id.to_proto(),
2015 name: "test-channel".to_string()
2016 }]
2017 )
2018 });
2019
2020 let channel_b = channels_b.update(&mut cx_b, |this, cx| {
2021 this.get_channel(channel_id.to_proto(), cx).unwrap()
2022 });
2023 channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
2024 channel_b
2025 .condition(&cx_b, |channel, _| {
2026 channel_messages(channel)
2027 == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
2028 })
2029 .await;
2030
2031 // Disconnect client B, ensuring we can still access its cached channel data.
2032 server.forbid_connections();
2033 server.disconnect_client(current_user_id(&user_store_b, &cx_b));
2034 while !matches!(
2035 status_b.recv().await,
2036 Some(rpc::Status::ReconnectionError { .. })
2037 ) {}
2038
2039 channels_b.read_with(&cx_b, |channels, _| {
2040 assert_eq!(
2041 channels.available_channels().unwrap(),
2042 [ChannelDetails {
2043 id: channel_id.to_proto(),
2044 name: "test-channel".to_string()
2045 }]
2046 )
2047 });
2048 channel_b.read_with(&cx_b, |channel, _| {
2049 assert_eq!(
2050 channel_messages(channel),
2051 [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
2052 )
2053 });
2054
2055 // Send a message from client B while it is disconnected.
2056 channel_b
2057 .update(&mut cx_b, |channel, cx| {
2058 let task = channel
2059 .send_message("can you see this?".to_string(), cx)
2060 .unwrap();
2061 assert_eq!(
2062 channel_messages(channel),
2063 &[
2064 ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2065 ("user_b".to_string(), "can you see this?".to_string(), true)
2066 ]
2067 );
2068 task
2069 })
2070 .await
2071 .unwrap_err();
2072
2073 // Send a message from client A while B is disconnected.
2074 channel_a
2075 .update(&mut cx_a, |channel, cx| {
2076 channel
2077 .send_message("oh, hi B.".to_string(), cx)
2078 .unwrap()
2079 .detach();
2080 let task = channel.send_message("sup".to_string(), cx).unwrap();
2081 assert_eq!(
2082 channel_messages(channel),
2083 &[
2084 ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2085 ("user_a".to_string(), "oh, hi B.".to_string(), true),
2086 ("user_a".to_string(), "sup".to_string(), true)
2087 ]
2088 );
2089 task
2090 })
2091 .await
2092 .unwrap();
2093
2094 // Give client B a chance to reconnect.
2095 server.allow_connections();
2096 cx_b.foreground().advance_clock(Duration::from_secs(10));
2097
2098 // Verify that, upon reconnecting, client B sees the messages A sent while B was
2099 // offline, as well as the message B itself sent while disconnected.
2100 channel_b
2101 .condition(&cx_b, |channel, _| {
2102 channel_messages(channel)
2103 == [
2104 ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2105 ("user_a".to_string(), "oh, hi B.".to_string(), false),
2106 ("user_a".to_string(), "sup".to_string(), false),
2107 ("user_b".to_string(), "can you see this?".to_string(), false),
2108 ]
2109 })
2110 .await;
2111
2112 // Ensure clients A and B can communicate normally after reconnection.
2113 channel_a
2114 .update(&mut cx_a, |channel, cx| {
2115 channel.send_message("you online?".to_string(), cx).unwrap()
2116 })
2117 .await
2118 .unwrap();
2119 channel_b
2120 .condition(&cx_b, |channel, _| {
2121 channel_messages(channel)
2122 == [
2123 ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2124 ("user_a".to_string(), "oh, hi B.".to_string(), false),
2125 ("user_a".to_string(), "sup".to_string(), false),
2126 ("user_b".to_string(), "can you see this?".to_string(), false),
2127 ("user_a".to_string(), "you online?".to_string(), false),
2128 ]
2129 })
2130 .await;
2131
2132 channel_b
2133 .update(&mut cx_b, |channel, cx| {
2134 channel.send_message("yep".to_string(), cx).unwrap()
2135 })
2136 .await
2137 .unwrap();
2138 channel_a
2139 .condition(&cx_a, |channel, _| {
2140 channel_messages(channel)
2141 == [
2142 ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2143 ("user_a".to_string(), "oh, hi B.".to_string(), false),
2144 ("user_a".to_string(), "sup".to_string(), false),
2145 ("user_b".to_string(), "can you see this?".to_string(), false),
2146 ("user_a".to_string(), "you online?".to_string(), false),
2147 ("user_b".to_string(), "yep".to_string(), false),
2148 ]
2149 })
2150 .await;
2151 }
2152
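// Test fixture that runs a `Server` in-process against a fresh `TestDb`. It lets
// tests create connected clients, sever or forbid their connections, and wait on
// server-side state, while keeping the test database alive for the test's duration.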
2153 struct TestServer {
2154 peer: Arc<Peer>,
2155 app_state: Arc<AppState>,
2156 server: Arc<Server>,
2157 notifications: mpsc::Receiver<()>,
2158 connection_killers: Arc<Mutex<HashMap<UserId, watch::Sender<Option<()>>>>>,
2159 forbid_connections: Arc<AtomicBool>,
2160 _test_db: TestDb,
2161 }
2162
2163 impl TestServer {
2164 async fn start() -> Self {
2165 let test_db = TestDb::new();
2166 let app_state = Self::build_app_state(&test_db).await;
2167 let peer = Peer::new();
2168 let notifications = mpsc::channel(128);
2169 let server = Server::new(app_state.clone(), peer.clone(), Some(notifications.0));
2170 Self {
2171 peer,
2172 app_state,
2173 server,
2174 notifications: notifications.1,
2175 connection_killers: Default::default(),
2176 forbid_connections: Default::default(),
2177 _test_db: test_db,
2178 }
2179 }
2180
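// Creates a client wired to this in-process server over an in-memory connection.
// Authentication is stubbed to return fixed credentials, and the call waits until
// the client's `UserStore` has observed the signed-in user before returning.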
2181 async fn create_client(
2182 &mut self,
2183 cx: &mut TestAppContext,
2184 name: &str,
2185 ) -> (Arc<Client>, ModelHandle<UserStore>) {
2186 let user_id = self.app_state.db.create_user(name, false).await.unwrap();
2187 let client_name = name.to_string();
2188 let mut client = Client::new();
2189 let server = self.server.clone();
2190 let connection_killers = self.connection_killers.clone();
2191 let forbid_connections = self.forbid_connections.clone();
2192 Arc::get_mut(&mut client)
2193 .unwrap()
2194 .override_authenticate(move |cx| {
2195 cx.spawn(|_| async move {
2196 let access_token = "the-token".to_string();
2197 Ok(Credentials {
2198 user_id: user_id.0 as u64,
2199 access_token,
2200 })
2201 })
2202 })
2203 .override_establish_connection(move |credentials, cx| {
2204 assert_eq!(credentials.user_id, user_id.0 as u64);
2205 assert_eq!(credentials.access_token, "the-token");
2206
2207 let server = server.clone();
2208 let connection_killers = connection_killers.clone();
2209 let forbid_connections = forbid_connections.clone();
2210 let client_name = client_name.clone();
2211 cx.spawn(move |cx| async move {
2212 if forbid_connections.load(SeqCst) {
2213 Err(EstablishConnectionError::other(anyhow!(
2214 "server is forbidding connections"
2215 )))
2216 } else {
2217 let (client_conn, server_conn, kill_conn) = Connection::in_memory();
2218 connection_killers.lock().insert(user_id, kill_conn);
2219 cx.background()
2220 .spawn(server.handle_connection(server_conn, client_name, user_id))
2221 .detach();
2222 Ok(client_conn)
2223 }
2224 })
2225 });
2226
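// These tests don't exercise real HTTP (e.g. avatar fetches), so a fake client
// that answers every request with a 404 is sufficient for the `UserStore`.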
2227 let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) });
2228 client
2229 .authenticate_and_connect(&cx.to_async())
2230 .await
2231 .unwrap();
2232
2233 let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx));
2234 let mut authed_user =
2235 user_store.read_with(cx, |user_store, _| user_store.watch_current_user());
2236 while authed_user.recv().await.unwrap().is_none() {}
2237
2238 (client, user_store)
2239 }
2240
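// Severs the in-memory connection for the given user, simulating an abrupt
// disconnect from the server.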
2241 fn disconnect_client(&self, user_id: UserId) {
2242 if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) {
2243 let _ = kill_conn.try_send(Some(()));
2244 }
2245 }
2246
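// While connections are forbidden, `override_establish_connection` above returns
// an error, so clients can neither connect nor reconnect until `allow_connections`
// is called.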
2247 fn forbid_connections(&self) {
2248 self.forbid_connections.store(true, SeqCst);
2249 }
2250
2251 fn allow_connections(&self) {
2252 self.forbid_connections.store(false, SeqCst);
2253 }
2254
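// Builds an `AppState` backed by the test database, with test GitHub clients and
// an auth client constructed from empty credentials.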
2255 async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
2256 let mut config = Config::default();
2257 config.session_secret = "a".repeat(32);
2258 config.database_url = test_db.url.clone();
2259 let github_client = github::AppClient::test();
2260 Arc::new(AppState {
2261 db: test_db.db().clone(),
2262 handlebars: Default::default(),
2263 auth_client: auth::build_client("", ""),
2264 repo_client: github::RepoClient::test(&github_client),
2265 github_client,
2266 config,
2267 })
2268 }
2269
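// Grabs a read lock on the live server state so a test can inspect it directly.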
2270 async fn state<'a>(&'a self) -> RwLockReadGuard<'a, ServerState> {
2271 self.server.state.read().await
2272 }
2273
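// Waits up to 500ms for the server state to satisfy `predicate`, re-checking each
// time the server emits a notification. For example, a test could wait for all
// connections to drop with something like
// `server.condition(|state| state.connections.is_empty()).await` (an illustrative
// predicate; any `FnMut(&ServerState) -> bool` works).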
2274 async fn condition<F>(&mut self, mut predicate: F)
2275 where
2276 F: FnMut(&ServerState) -> bool,
2277 {
2278 async_std::future::timeout(Duration::from_millis(500), async {
2279 while !(predicate)(&*self.server.state.read().await) {
2280 self.notifications.recv().await;
2281 }
2282 })
2283 .await
2284 .expect("condition timed out");
2285 }
2286 }
2287
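// Reset the in-process peer when the test server is dropped, so state from one
// test doesn't leak into the next.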
2288 impl Drop for TestServer {
2289 fn drop(&mut self) {
2290 task::block_on(self.peer.reset());
2291 }
2292 }
2293
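// Reads the signed-in user's id out of a `UserStore`, panicking if no user is
// currently authenticated.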
2294 fn current_user_id(user_store: &ModelHandle<UserStore>, cx: &TestAppContext) -> UserId {
2295 UserId::from_proto(
2296 user_store.read_with(cx, |user_store, _| user_store.current_user().unwrap().id),
2297 )
2298 }
2299
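// Flattens a channel's messages into (sender github login, body, is_pending)
// tuples so tests can compare them against expected values directly.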
2300 fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> {
2301 channel
2302 .messages()
2303 .cursor::<(), ()>()
2304 .map(|m| {
2305 (
2306 m.sender.github_login.clone(),
2307 m.body.clone(),
2308 m.is_pending(),
2309 )
2310 })
2311 .collect()
2312 }
2313
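// A minimal `gpui` view that renders an empty element, for tests that need a view
// or window without any real UI.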
2314 struct EmptyView;
2315
2316 impl gpui::Entity for EmptyView {
2317 type Event = ();
2318 }
2319
2320 impl gpui::View for EmptyView {
2321 fn ui_name() -> &'static str {
2322 "empty view"
2323 }
2324
2325 fn render(&mut self, _: &mut gpui::RenderContext<Self>) -> gpui::ElementBox {
2326 gpui::Element::boxed(gpui::elements::Empty)
2327 }
2328 }
2329}